33import shlex
44import subprocess
55import time
6- from typing import Optional
6+ from pathlib import Path
77
88from consts import CUDA_FLAGS , ExitCode
99
@@ -153,19 +153,19 @@ def run_program(args: list[str]) -> RunResult:
153153
154154
155155def run_cuda_script ( # # noqa: C901
156- script_content : str ,
157- reference_content : str = None ,
158- submission_content : str = None ,
156+ sources : dict [str , str ],
157+ headers : dict [str , str ] = None ,
159158 arch : int = None ,
160159 include_dirs : list [str ] = None ,
161160) -> tuple [CompileResult , RunResult ]:
162161 """
163162 Executes the provided CUDA kernel in an isolated environment with a timeout
164163
165164 Args:
166- script_content: The CUDA script containing the GPU kernel
167- reference_content: The (optional) reference code, used for leaderboards.
168- submission_content: The (optional) submission code, used for leaderboards.
165+ sources: The source files to compile. Mapping file name to content.
166+ headers: Additional header files to create for the compile run.
167+ Mapping of file name to file contents. These files will _not_ be added to the
168+ compile command.
169169 arch: The arch code for the compute/sm versions. If None, native arch is used.
170170 include_dirs: Additional include directories, e.g., for thunderkittens/cutlass etc
171171
@@ -177,80 +177,66 @@ def run_cuda_script( # # noqa: C901
177177
178178 try :
179179 # Write submission files to directory
180- if reference_content is not None :
181- with open ("reference.cuh" , "w" ) as f :
182- f .write (reference_content )
180+ for source , content in sources .items ():
181+ Path (source ).write_text (content )
183182
184- if submission_content is not None :
185- with open ("train.cuh" , "w" ) as f :
186- f .write (submission_content )
187-
188- with open ("eval.cu" , "w" ) as f :
189- f .write (script_content )
183+ for header , content in headers .items ():
184+ Path (header ).write_text (content )
190185
191186 compile_result = compile_cuda_script (
192- files = [ "eval.cu" ] ,
187+ files = list ( sources . keys ()) ,
193188 arch = arch ,
194189 include_dirs = include_dirs ,
195190 verbose = True ,
196191 )
197-
198- if not compile_result .success :
199- return compile_result , RunResult (
200- success = False ,
201- command = "" ,
202- stdout = "" ,
203- stderr = "" ,
204- exit_code = - 1 ,
205- duration = 0.0 ,
206- result = {},
207- )
208-
209- run_result = run_program (["./eval.out" ])
210- return compile_result , run_result
211-
192+ # cleaning up all source files _before_ we let the user code run, just in
193+ # case there's something in there that the user isn't supposed to snoop
212194 finally :
213- tmp_files = [ "reference.cuh" , "train.cuh" , "eval.cu" , "eval.out" ]
195+ tmp_files = list ( sources . keys ()) + list ( headers . keys ())
214196 for f in tmp_files :
215197 if os .path .exists (f ):
216198 os .remove (f )
217199
200+ if not compile_result .success :
201+ return compile_result , RunResult (
202+ success = False ,
203+ command = "" ,
204+ stdout = "" ,
205+ stderr = "" ,
206+ exit_code = - 1 ,
207+ duration = 0.0 ,
208+ result = {},
209+ )
210+
211+ run_result = run_program (["./eval.out" ])
212+ return compile_result , run_result
213+
218214
219215def run_pytorch_script ( # noqa: C901
220- script_content : str ,
221- reference_content : Optional [str ] = None ,
222- submission_content : Optional [str ] = None ,
216+ sources : dict [str , str ],
217+ main : str ,
223218 arch : int = None ,
224219) -> RunResult :
225220 """
226221 Executes the provided PyTorch GPU kernel in an isolated environment
227222
228223 Args:
229- script_content: The PyTorch script containing the GPU kernel to benchmark
230- reference_content: The (optional) reference code, used for leaderboards.
231- submission_content: The (optional) submission code, used for leaderboards.
224+ sources: Files to generate
225+ main: Which file to run. Must be one of the keys in sources.
232226 arch: The arch code for the compute/sm versions.
233227
234228 Returns:
235229 tuple[str, float]: (Kernel output, execution time in milliseconds)
236230 """
237231 try :
238- # Write submission files to directory
239- if reference_content is not None :
240- with open ("reference.py" , "w" ) as f :
241- f .write (reference_content )
242-
243- if submission_content is not None :
244- with open ("train.py" , "w" ) as f :
245- f .write (submission_content )
232+ assert main in sources .keys ()
246233
247- with open ( "eval.py" , "w" ) as f :
248- f . write ( script_content )
249-
250- return run_program (["python" , "eval.py" ])
234+ # Write submission files to directory
235+ for source , content in sources . items ():
236+ Path ( source ). write_text ( content )
237+ return run_program (["python" , main ])
251238
252239 finally :
253- tmp_files = ["eval.py" , "reference.py" , "train.py" ]
254- for f in tmp_files :
240+ for f in sources .keys ():
255241 if os .path .exists (f ):
256242 os .remove (f )
0 commit comments