Skip to content

Commit 48b7a5a

Browse files
committed
allow arbitrary sources and headers
1 parent d26faaf commit 48b7a5a

File tree

5 files changed

+73
-68
lines changed

5 files changed

+73
-68
lines changed

scripts/ci_test_cuda.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ def test_does_not_compile():
2020
output_t custom_kernel(input_tt data) { }
2121
"""
2222

23-
comp, run = run_cuda_script(cu_eval, ref.read_text(), sub, arch=None)
23+
comp, run = run_cuda_script(
24+
{"eval.cu": cu_eval}, {"reference.cuh": ref.read_text(), "train.cuh": sub}, arch=None
25+
)
2426
assert comp.success is False
2527
assert run.success is False
2628
assert comp.nvcc_found is True
@@ -52,7 +54,9 @@ def test_cuda_runtime_error():
5254
}
5355
5456
"""
55-
comp, run = run_cuda_script(cu_eval, ref.read_text(), sub, arch=None)
57+
comp, run = run_cuda_script(
58+
{"eval.cu": cu_eval}, {"reference.cuh": ref.read_text(), "train.cuh": sub}, arch=None
59+
)
5660
assert comp.success is True
5761
assert run.success is False
5862
assert run.command == "./eval.out"
@@ -80,7 +84,9 @@ def test_cuda_validation_fail():
8084
}
8185
8286
"""
83-
comp, run = run_cuda_script(cu_eval, ref.read_text(), sub, arch=None)
87+
comp, run = run_cuda_script(
88+
{"eval.cu": cu_eval}, {"reference.cuh": ref.read_text(), "train.cuh": sub}, arch=None
89+
)
8490
assert comp.success is True
8591
assert run.success is False
8692
assert run.command == "./eval.out"
@@ -94,7 +100,9 @@ def test_cuda_validation_fail():
94100
def test_cuda_correct():
95101
sub = Path("examples/identity_cuda/submission.cuh").read_text()
96102

97-
comp, run = run_cuda_script(cu_eval, ref.read_text(), sub, arch=None)
103+
comp, run = run_cuda_script(
104+
{"eval.cu": cu_eval}, {"reference.cuh": ref.read_text(), "train.cuh": sub}, arch=None
105+
)
98106
assert comp.success is True
99107
assert run.success is True
100108
assert "warming up..." in run.stdout

scripts/ci_test_python.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ def test_does_not_import():
2020
this is a syntax error
2121
"""
2222

23-
run = run_pytorch_script(py_eval, ref.read_text(), sub, arch=None)
23+
run = run_pytorch_script(
24+
{"eval.py": py_eval, "reference.py": ref.read_text(), "train.py": sub}, "eval.py"
25+
)
2426
assert run.success is False
2527
assert run.exit_code != ExitCode.SUCCESS
2628
assert "IndentationError: unexpected indent\n" in run.stderr
@@ -33,7 +35,9 @@ def test_error():
3335
def custom_kernel(input):
3436
return [torch.zeros_like(i) for i in input]
3537
"""
36-
run = run_pytorch_script(py_eval, ref.read_text(), sub, arch=None)
38+
run = run_pytorch_script(
39+
{"eval.py": py_eval, "reference.py": ref.read_text(), "train.py": sub}, "eval.py"
40+
)
3741
assert run.success is False
3842
assert run.command == "python eval.py"
3943
# we never reach the benchmark part, because the test fails
@@ -46,7 +50,9 @@ def custom_kernel(input):
4650
def test_correct():
4751
sub = Path("examples/identity_py/submission.py").read_text()
4852

49-
run = run_pytorch_script(py_eval, ref.read_text(), sub, arch=None)
53+
run = run_pytorch_script(
54+
{"eval.py": py_eval, "reference.py": ref.read_text(), "train.py": sub}, "eval.py"
55+
)
5056
assert run.success is True
5157
assert "warming up..." in run.stdout
5258
assert run.exit_code == ExitCode.SUCCESS

scripts/local-test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
ref = Path("examples/identity_cuda/reference.cuh")
1010
sub = Path("examples/identity_cuda/submission.cuh")
1111

12-
cout, score = run_cuda_script(cu_eval, ref.read_text(), sub.read_text(), arch=None)
12+
cout, score = run_cuda_script(
13+
{"eval.cu": cu_eval},
14+
{"reference.cuh": ref.read_text(), "train.cuh": sub.read_text()},
15+
arch=None,
16+
)
1317
print(cout)
1418
print(score)
1519
exit(0 if score > 0 else 1)

src/discord-cluster-manager/modal_runner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,12 @@ def modal_run_pytorch_script( # noqa: C901
8080
try:
8181
with timeout(timeout_seconds):
8282
run_result = run_pytorch_script(
83-
script_content=script_content,
84-
reference_content=reference_content,
85-
submission_content=submission_content,
86-
arch=arch,
83+
{
84+
"eval.py": script_content,
85+
"reference.py": reference_content,
86+
"train.py": submission_content,
87+
},
88+
"eval.py",
8789
)
8890
if not run_result.success:
8991
# exit code 1 encodes failed tests
@@ -126,9 +128,8 @@ def modal_run_cuda_script( # # noqa: C901
126128
try:
127129
with timeout(timeout_seconds):
128130
compile_result, run_result = run_cuda_script(
129-
script_content,
130-
reference_content=reference_content,
131-
submission_content=submission_content,
131+
{"eval.cu": script_content},
132+
{"reference.cuh": reference_content, "train.cuh": submission_content},
132133
arch=arch,
133134
include_dirs=MODAL_CUDA_INCLUDE_DIRS,
134135
)

src/discord-cluster-manager/run_eval.py

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import shlex
44
import subprocess
55
import time
6-
from typing import Optional
6+
from pathlib import Path
77

88
from consts import CUDA_FLAGS, ExitCode
99

@@ -153,19 +153,19 @@ def run_program(args: list[str]) -> RunResult:
153153

154154

155155
def run_cuda_script(  # # noqa: C901
    sources: dict[str, str],
    headers: dict[str, str] = None,
    arch: int = None,
    include_dirs: list[str] = None,
) -> tuple[CompileResult, RunResult]:
    """
    Executes the provided CUDA kernel in an isolated environment with a timeout

    Args:
        sources: The source files to compile. Mapping file name to content.
        headers: Additional header files to create for the compile run.
            Mapping of file name to file contents. These files will _not_ be added to the
            compile command. May be omitted (defaults to no extra headers).
        arch: The arch code for the compute/sm versions. If None, native arch is used.
        include_dirs: Additional include directories, e.g., for thunderkittens/cutlass etc

    Returns:
        tuple[CompileResult, RunResult]: The result of the compilation step, and the
        result of running ``./eval.out`` (a dummy failed RunResult if compilation failed).
    """
    # Normalize optional arguments so the write loop and the cleanup in `finally`
    # never trip over `None` (previously a call without headers raised TypeError).
    if headers is None:
        headers = {}

    try:
        # Write submission files to directory
        for source, content in sources.items():
            Path(source).write_text(content)

        for header, content in headers.items():
            Path(header).write_text(content)

        compile_result = compile_cuda_script(
            files=list(sources.keys()),
            arch=arch,
            include_dirs=include_dirs,
            verbose=True,
        )
        # cleaning up all source files _before_ we let the user code run, just in
        # case there's something in there that the user isn't supposed to snoop
    finally:
        tmp_files = list(sources.keys()) + list(headers.keys())
        for f in tmp_files:
            if os.path.exists(f):
                os.remove(f)

    if not compile_result.success:
        return compile_result, RunResult(
            success=False,
            command="",
            stdout="",
            stderr="",
            exit_code=-1,
            duration=0.0,
            result={},
        )

    run_result = run_program(["./eval.out"])
    return compile_result, run_result
213+
218214

219215
def run_pytorch_script(  # noqa: C901
    sources: dict[str, str],
    main: str,
    arch: int = None,
) -> RunResult:
    """
    Executes the provided PyTorch GPU kernel in an isolated environment

    Args:
        sources: Files to generate. Mapping of file name to file contents.
        main: Which file to run. Must be one of the keys in sources.
        arch: The arch code for the compute/sm versions.

    Returns:
        RunResult: The result of running ``python <main>``.

    Raises:
        ValueError: If `main` is not one of the provided source files.
    """
    # Validate explicitly rather than with `assert`, which is stripped under
    # `python -O` and would let a bad `main` slip through to subprocess.
    if main not in sources:
        raise ValueError(f"main file {main!r} is not among the provided sources")

    try:
        # Write submission files to directory
        for source, content in sources.items():
            Path(source).write_text(content)
        return run_program(["python", main])

    finally:
        # Remove every generated file, even if writing or running failed.
        for f in sources:
            if os.path.exists(f):
                os.remove(f)

0 commit comments

Comments
 (0)