1212from discord .ext import commands
1313from env import GITHUB_REPO , GITHUB_TOKEN
1414from github import Github
15- from leaderboard_eval import cu_eval , py_eval
15+ from leaderboard_eval import amd_requirements , nvidia_requirements
1616from report import generate_report
1717from run_eval import CompileResult , FullResult , RunResult
1818from utils import get_github_branch_name , send_discord_message , setup_logging
@@ -68,34 +68,12 @@ async def run_github(
6868 else :
6969 reference_content = None
7070
71- if gpu_type .value == "nvidia" :
72- run_id = await self .trigger_github_nvidia (
73- lang = lang ,
74- script_content = script_content ,
75- reference_content = reference_content ,
76- )
77- else :
78- ##########
79- # OLD CODE
80- filename = "train.py" if script .filename .endswith (".py" ) else "train.cu"
81- if reference_script is not None or reference_code is not None :
82- reference_content = (
83- reference_code
84- if reference_code is not None
85- else (await reference_script .read ()).decode ("utf-8" )
86- )
87- eval_code = py_eval if script .filename .endswith (".py" ) else cu_eval
88-
89- run_id = await self .trigger_github_amd (
90- script_content ,
91- filename ,
92- selected_gpu ,
93- reference_content ,
94- eval_code ,
95- )
96- else :
97- run_id = await self .trigger_github_amd (script_content , filename , selected_gpu )
98- ##########
71+ run_id = await self .trigger_github_run (
72+ lang = lang ,
73+ gpu_type = selected_gpu ,
74+ script_content = script_content ,
75+ reference_content = reference_content ,
76+ )
9977
10078 if run_id :
10179 await thread .send (
@@ -128,88 +106,42 @@ async def run_github(
128106 await thread .send (f"Error processing request: { str (e )} " )
129107 raise
130108
131- async def trigger_github_nvidia (
132- self , lang : str , script_content : str , reference_content : Optional [str ]
109+ async def trigger_github_run (
110+ self , lang : str , gpu_type : GPUType , script_content : str , reference_content : Optional [str ]
133111 ):
112+ if lang == "cu" and gpu_type == GPUType .AMD :
113+ # TODO implement HIP
114+ raise ValueError ("Cannot use CUDA runs with AMD GPUs" )
115+
134116 eval_name = {"py" : "eval.py" , "cu" : "eval.cu" }[lang ]
135117 ref_name = {"py" : "reference.py" , "cu" : "reference.cuh" }[lang ]
136118 sub_name = {"py" : "submission.py" , "cu" : "submission.cuh" }[lang ]
119+ lang_name = {"py" : "Python" , "cu" : "CUDA" }[lang ]
137120
138121 if reference_content is None :
139122 config = {eval_name : script_content , "lang" : lang }
140123 else :
141124 config = {ref_name : reference_content , sub_name : script_content , "lang" : lang }
142125
143- logger .info ("Attempting to trigger GitHub action for NVIDIA " )
126+ logger .info (f "Attempting to trigger GitHub action for { lang_name } on { gpu_type . name } " )
144127 gh = Github (GITHUB_TOKEN )
145128 repo = gh .get_repo (GITHUB_REPO )
146129
147130 try :
148131 trigger_time = datetime .now (timezone .utc )
149- workflow_file = "nvidia_workflow.yml"
132+ workflow_file = gpu_type . value
150133 workflow = repo .get_workflow (workflow_file )
151134
152135 payload = json .dumps (config )
153136
154137 inputs = {"payload" : payload }
155138 if lang == "py" :
156- inputs ["requirements" ] = "numpy\n torch\n setuptools\n ninja\n triton"
157-
158- success = workflow .create_dispatch (
159- get_github_branch_name (),
160- inputs = inputs
161- )
162- if success :
163- await asyncio .sleep (2 )
164- runs = list (workflow .get_runs ())
165-
166- for run in runs :
167- if run .created_at .replace (tzinfo = timezone .utc ) > trigger_time :
168- return run .id
169- return None
170-
171- except Exception as e :
172- logger .error (f"Error in trigger_github_action: { str (e )} " , exc_info = True )
173- return None
174-
175- async def trigger_github_amd (
176- self ,
177- script_content ,
178- filename ,
179- gpu_type ,
180- reference_content = None ,
181- eval_content = None ,
182- ):
183- logger .info (f"Attempting to trigger GitHub action for { gpu_type .name } GPU" )
184- gh = Github (GITHUB_TOKEN )
185- repo = gh .get_repo (GITHUB_REPO )
186-
187- try :
188- trigger_time = datetime .now (timezone .utc )
189- workflow_file = gpu_type .value
190- workflow = repo .get_workflow (workflow_file )
191-
192- if reference_content is not None :
193- eval_filename = "eval.py" if filename .endswith (".py" ) else "eval.cu"
194- reference_filename = "reference.py" if filename .endswith (".py" ) else "reference.cuh"
195- filename = "train.py" if filename .endswith (".py" ) else "train.cuh"
196- success = workflow .create_dispatch (
197- get_github_branch_name (),
198- {
199- "script_content" : script_content ,
200- "filename" : filename ,
201- "reference_content" : reference_content ,
202- "reference_filename" : reference_filename ,
203- "eval_content" : eval_content ,
204- "eval_filename" : eval_filename ,
205- },
206- )
207- else :
208- success = workflow .create_dispatch (
209- get_github_branch_name (),
210- {"script_content" : script_content , "filename" : filename },
211- )
139+ if gpu_type == GPUType .NVIDIA :
140+ inputs ["requirements" ] = nvidia_requirements
141+ else :
142+ inputs ["requirements" ] = amd_requirements
212143
144+ success = workflow .create_dispatch (get_github_branch_name (), inputs = inputs )
213145 if success :
214146 await asyncio .sleep (2 )
215147 runs = list (workflow .get_runs ())
@@ -258,10 +190,7 @@ async def check_workflow_status(self, run_id, thread, gpu_type):
258190 )
259191
260192 if run .status == "completed" :
261- if gpu_type .value == "nvidia" :
262- result = await self .download_results (run_id )
263- else :
264- result = await self .handle_training_log (run_id )
193+ result = await self .download_results (run_id )
265194 return run .conclusion , result , run .html_url
266195
267196 await thread .send (
@@ -271,6 +200,7 @@ async def check_workflow_status(self, run_id, thread, gpu_type):
271200 )
272201 await asyncio .sleep (20 )
273202 except Exception as e :
203+ logger .error ("Error" , exc_info = e )
274204 return "error" , str (e ), None
275205
276206 async def download_results (self , run_id ) -> FullResult :
@@ -285,21 +215,14 @@ async def download_results(self, run_id) -> FullResult:
285215 run = RunResult (** data ["run" ])
286216 return FullResult (success = True , error = "" , compile = comp , run = run )
287217 except Exception as e :
218+ logger .error ("Error downloading artifacts" , exc_info = e )
288219 return FullResult (
289220 success = False ,
290- error = f"Error downloading artifacts: { str (e )} " ,
221+ error = f"Error downloading artifacts: { repr (e )} " ,
291222 compile = None ,
292223 run = None ,
293224 )
294225
295- async def handle_training_log (self , run_id ):
296- try :
297- data = await self .download_artifact (run_id , name = "training-artifacts" )
298- logs = data ["training.log" ].decode ("utf-8" )
299- return logs
300- except Exception as e :
301- return f"Error downloading artifacts: { str (e )} "
302-
303226 async def download_artifact (self , run_id , name : str ):
304227 logger .info (f"Attempting to download artifact { name } for run { run_id } " )
305228 gh = Github (GITHUB_TOKEN )
@@ -330,4 +253,4 @@ async def download_artifact(self, run_id, name: str):
330253 raise RuntimeError (
331254 f"Failed to download artifact. Status code: { response .status_code } "
332255 )
333- return RuntimeError (f"Could not find artifact { name } " )
256+ raise RuntimeError (f"Could not find artifact { name } " )
0 commit comments