33import shlex
44import subprocess
55import time
6- from typing import Optional
6+ from pathlib import Path
77
88from consts import CUDA_FLAGS
99
@@ -155,19 +155,19 @@ def run_program(args: list[str]) -> RunResult:
155155
156156
157157def run_cuda_script ( # # noqa: C901
158- script_content : str ,
159- reference_content : str = None ,
160- submission_content : str = None ,
158+ sources : dict [str , str ],
159+ headers : dict [str , str ] = None ,
161160 arch : int = None ,
162161 include_dirs : list [str ] = None ,
163162) -> tuple [CompileResult , RunResult ]:
164163 """
165164 Executes the provided CUDA kernel in an isolated environment with a timeout
166165
167166 Args:
168- script_content: The CUDA script containing the GPU kernel
169- reference_content: The (optional) reference code, used for leaderboards.
170- submission_content: The (optional) submission code, used for leaderboards.
167+ sources: The source files to compile. Mapping file name to content.
168+ headers: Additional header files to create for the compile run.
169+ Mapping of file name to file contents. These files will _not_ be added to the
170+ compile command.
171171 arch: The arch code for the compute/sm versions. If None, native arch is used.
172172 include_dirs: Additional include directories, e.g., for thunderkittens/cutlass etc
173173
@@ -179,80 +179,66 @@ def run_cuda_script( # # noqa: C901
179179
180180 try :
181181 # Write submission files to directory
182- if reference_content is not None :
183- with open ("reference.cuh" , "w" ) as f :
184- f .write (reference_content )
182+ for source , content in sources .items ():
183+ Path (source ).write_text (content )
185184
186- if submission_content is not None :
187- with open ("train.cuh" , "w" ) as f :
188- f .write (submission_content )
189-
190- with open ("eval.cu" , "w" ) as f :
191- f .write (script_content )
185+ for header , content in headers .items ():
186+ Path (header ).write_text (content )
192187
193188 compile_result = compile_cuda_script (
194- files = [ "eval.cu" ] ,
189+ files = list ( sources . keys ()) ,
195190 arch = arch ,
196191 include_dirs = include_dirs ,
197192 verbose = True ,
198193 )
199-
200- if not compile_result .success :
201- return compile_result , RunResult (
202- success = False ,
203- command = "" ,
204- stdout = "" ,
205- stderr = "" ,
206- exit_code = - 1 ,
207- duration = 0.0 ,
208- result = {},
209- )
210-
211- run_result = run_program (["./eval.out" ])
212- return compile_result , run_result
213-
194+ # cleaning up all source files _before_ we let the user code run, just in
195+ # case there's something in there that the user isn't supposed to snoop
214196 finally :
215- tmp_files = [ "reference.cuh" , "train.cuh" , "eval.cu" , "eval.out" ]
197+ tmp_files = list ( sources . keys ()) + list ( headers . keys ())
216198 for f in tmp_files :
217199 if os .path .exists (f ):
218200 os .remove (f )
219201
202+ if not compile_result .success :
203+ return compile_result , RunResult (
204+ success = False ,
205+ command = "" ,
206+ stdout = "" ,
207+ stderr = "" ,
208+ exit_code = - 1 ,
209+ duration = 0.0 ,
210+ result = {},
211+ )
212+
213+ run_result = run_program (["./eval.out" ])
214+ return compile_result , run_result
215+
220216
221217def run_pytorch_script ( # noqa: C901
222- script_content : str ,
223- reference_content : Optional [str ] = None ,
224- submission_content : Optional [str ] = None ,
218+ sources : dict [str , str ],
219+ main : str ,
225220 arch : int = None ,
226221) -> RunResult :
227222 """
228223 Executes the provided PyTorch GPU kernel in an isolated environment
229224
230225 Args:
231- script_content: The PyTorch script containing the GPU kernel to benchmark
232- reference_content: The (optional) reference code, used for leaderboards.
233- submission_content: The (optional) submission code, used for leaderboards.
226+ sources: Files to generate
227+ main: Which file to run. Must be one of the keys in sources.
234228 arch: The arch code for the compute/sm versions.
235229
236230 Returns:
237231 tuple[str, float]: (Kernel output, execution time in milliseconds)
238232 """
239233 try :
240- # Write submission files to directory
241- if reference_content is not None :
242- with open ("reference.py" , "w" ) as f :
243- f .write (reference_content )
244-
245- if submission_content is not None :
246- with open ("train.py" , "w" ) as f :
247- f .write (submission_content )
234+ assert main in sources .keys ()
248235
249- with open ( "eval.py" , "w" ) as f :
250- f . write ( script_content )
251-
252- return run_program (["python" , "eval.py" ])
236+ # Write submission files to directory
237+ for source , content in sources . items ():
238+ Path ( source ). write_text ( content )
239+ return run_program (["python" , main ])
253240
254241 finally :
255- tmp_files = ["eval.py" , "reference.py" , "train.py" ]
256- for f in tmp_files :
242+ for f in sources .keys ():
257243 if os .path .exists (f ):
258244 os .remove (f )
0 commit comments