Skip to content

Commit 2366d78

Browse files
committed
allow arbitrary sources and headers
1 parent 6d030d0 commit 2366d78

File tree

5 files changed

+69
-68
lines changed

(file-tree summary as above: 5 files changed, +69 −68 lines)

scripts/ci_test_cuda.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def test_does_not_compile():
1919
output_t custom_kernel(input_tt data) { }
2020
"""
2121

22-
comp, run = run_cuda_script(cu_eval, ref.read_text(), sub, arch=None)
22+
comp, run = run_cuda_script(
23+
{"eval.cu": cu_eval}, {"reference.cuh": ref.read_text(), "train.cuh": sub}, arch=None
24+
)
2325
assert comp.success is False
2426
assert run.success is False
2527
assert comp.nvcc_found is True
@@ -50,7 +52,9 @@ def test_cuda_runtime_error():
5052
}
5153
5254
"""
53-
comp, run = run_cuda_script(cu_eval, ref.read_text(), sub, arch=None)
55+
comp, run = run_cuda_script(
56+
{"eval.cu": cu_eval}, {"reference.cuh": ref.read_text(), "train.cuh": sub}, arch=None
57+
)
5458
assert comp.success is True
5559
assert run.success is False
5660
assert run.command == "./eval.out"
@@ -78,7 +82,9 @@ def test_cuda_validation_fail():
7882
}
7983
8084
"""
81-
comp, run = run_cuda_script(cu_eval, ref.read_text(), sub, arch=None)
85+
comp, run = run_cuda_script(
86+
{"eval.cu": cu_eval}, {"reference.cuh": ref.read_text(), "train.cuh": sub}, arch=None
87+
)
8288
assert comp.success is True
8389
assert run.success is False
8490
assert run.command == "./eval.out"
@@ -92,7 +98,9 @@ def test_cuda_validation_fail():
9298
def test_cuda_correct():
9399
sub = Path("examples/identity_cuda/submission.cuh").read_text()
94100

95-
comp, run = run_cuda_script(cu_eval, ref.read_text(), sub, arch=None)
101+
comp, run = run_cuda_script(
102+
{"eval.cu": cu_eval}, {"reference.cuh": ref.read_text(), "train.cuh": sub}, arch=None
103+
)
96104
assert comp.success is True
97105
assert run.success is True
98106
assert "warming up..." in run.stdout

scripts/ci_test_python.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def test_does_not_import():
1919
this is a syntax error
2020
"""
2121

22-
run = run_pytorch_script(py_eval, ref.read_text(), sub, arch=None)
22+
run = run_pytorch_script(
23+
{"eval.py": py_eval, "reference.py": ref.read_text(), "train.py": sub}, "eval.py"
24+
)
2325
assert run.success is False
2426
assert run.exit_code == 1
2527
assert "IndentationError: unexpected indent\n" in run.stderr
@@ -32,7 +34,9 @@ def test_error():
3234
def custom_kernel(input):
3335
return [torch.zeros_like(i) for i in input]
3436
"""
35-
run = run_pytorch_script(py_eval, ref.read_text(), sub, arch=None)
37+
run = run_pytorch_script(
38+
{"eval.py": py_eval, "reference.py": ref.read_text(), "train.py": sub}, "eval.py"
39+
)
3640
assert run.success is False
3741
assert run.command == "python eval.py"
3842
# we never reach the benchmark part, because the test fails
@@ -45,7 +49,9 @@ def custom_kernel(input):
4549
def test_correct():
4650
sub = Path("examples/identity_py/submission.py").read_text()
4751

48-
run = run_pytorch_script(py_eval, ref.read_text(), sub, arch=None)
52+
run = run_pytorch_script(
53+
{"eval.py": py_eval, "reference.py": ref.read_text(), "train.py": sub}, "eval.py"
54+
)
4955
assert run.success is True
5056
assert "warming up..." in run.stdout
5157
assert run.exit_code == 0

scripts/local-test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
ref = Path("examples/identity_cuda/reference.cuh")
1010
sub = Path("examples/identity_cuda/submission.cuh")
1111

12-
cout, score = run_cuda_script(cu_eval, ref.read_text(), sub.read_text(), arch=None)
12+
cout, score = run_cuda_script(
13+
{"eval.cu": cu_eval},
14+
{"reference.cuh": ref.read_text(), "train.cuh": sub.read_text()},
15+
arch=None,
16+
)
1317
print(cout)
1418
print(score)
1519
exit(0 if score > 0 else 1)

src/discord-cluster-manager/modal_runner.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,8 @@ def modal_run_pytorch_script( # noqa: C901
8080
try:
8181
with timeout(timeout_seconds):
8282
run_result = run_pytorch_script(
83-
script_content=script_content,
84-
reference_content=reference_content,
85-
submission_content=submission_content,
86-
arch=arch,
83+
{"eval.py": script_content, "reference.py": reference_content, "train.py": submission_content},
84+
"eval.py"
8785
)
8886
if not run_result.success:
8987
# exit code 1 encodes failed tests
@@ -126,9 +124,8 @@ def modal_run_cuda_script( # # noqa: C901
126124
try:
127125
with timeout(timeout_seconds):
128126
compile_result, run_result = run_cuda_script(
129-
script_content,
130-
reference_content=reference_content,
131-
submission_content=submission_content,
127+
{"eval.cu": script_content},
128+
{"reference.cuh": reference_content, "train.cuh": submission_content},
132129
arch=arch,
133130
include_dirs=MODAL_CUDA_INCLUDE_DIRS,
134131
)

src/discord-cluster-manager/run_eval.py

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import shlex
44
import subprocess
55
import time
6-
from typing import Optional
6+
from pathlib import Path
77

88
from consts import CUDA_FLAGS
99

@@ -155,19 +155,19 @@ def run_program(args: list[str]) -> RunResult:
155155

156156

157157
def run_cuda_script( # # noqa: C901
158-
script_content: str,
159-
reference_content: str = None,
160-
submission_content: str = None,
158+
sources: dict[str, str],
159+
headers: dict[str, str] = None,
161160
arch: int = None,
162161
include_dirs: list[str] = None,
163162
) -> tuple[CompileResult, RunResult]:
164163
"""
165164
Executes the provided CUDA kernel in an isolated environment with a timeout
166165
167166
Args:
168-
script_content: The CUDA script containing the GPU kernel
169-
reference_content: The (optional) reference code, used for leaderboards.
170-
submission_content: The (optional) submission code, used for leaderboards.
167+
sources: The source files to compile. Mapping file name to content.
168+
headers: Additional header files to create for the compile run.
169+
Mapping of file name to file contents. These files will _not_ be added to the
170+
compile command.
171171
arch: The arch code for the compute/sm versions. If None, native arch is used.
172172
include_dirs: Additional include directories, e.g., for thunderkittens/cutlass etc
173173
@@ -179,80 +179,66 @@ def run_cuda_script( # # noqa: C901
179179

180180
try:
181181
# Write submission files to directory
182-
if reference_content is not None:
183-
with open("reference.cuh", "w") as f:
184-
f.write(reference_content)
182+
for source, content in sources.items():
183+
Path(source).write_text(content)
185184

186-
if submission_content is not None:
187-
with open("train.cuh", "w") as f:
188-
f.write(submission_content)
189-
190-
with open("eval.cu", "w") as f:
191-
f.write(script_content)
185+
for header, content in headers.items():
186+
Path(header).write_text(content)
192187

193188
compile_result = compile_cuda_script(
194-
files=["eval.cu"],
189+
files=list(sources.keys()),
195190
arch=arch,
196191
include_dirs=include_dirs,
197192
verbose=True,
198193
)
199-
200-
if not compile_result.success:
201-
return compile_result, RunResult(
202-
success=False,
203-
command="",
204-
stdout="",
205-
stderr="",
206-
exit_code=-1,
207-
duration=0.0,
208-
result={},
209-
)
210-
211-
run_result = run_program(["./eval.out"])
212-
return compile_result, run_result
213-
194+
# cleaning up all source files _before_ we let the user code run, just in
195+
# case there's something in there that the user isn't supposed to snoop
214196
finally:
215-
tmp_files = ["reference.cuh", "train.cuh", "eval.cu", "eval.out"]
197+
tmp_files = list(sources.keys()) + list(headers.keys())
216198
for f in tmp_files:
217199
if os.path.exists(f):
218200
os.remove(f)
219201

202+
if not compile_result.success:
203+
return compile_result, RunResult(
204+
success=False,
205+
command="",
206+
stdout="",
207+
stderr="",
208+
exit_code=-1,
209+
duration=0.0,
210+
result={},
211+
)
212+
213+
run_result = run_program(["./eval.out"])
214+
return compile_result, run_result
215+
220216

221217
def run_pytorch_script( # noqa: C901
222-
script_content: str,
223-
reference_content: Optional[str] = None,
224-
submission_content: Optional[str] = None,
218+
sources: dict[str, str],
219+
main: str,
225220
arch: int = None,
226221
) -> RunResult:
227222
"""
228223
Executes the provided PyTorch GPU kernel in an isolated environment
229224
230225
Args:
231-
script_content: The PyTorch script containing the GPU kernel to benchmark
232-
reference_content: The (optional) reference code, used for leaderboards.
233-
submission_content: The (optional) submission code, used for leaderboards.
226+
sources: Files to generate
227+
main: Which file to run. Must be one of the keys in sources.
234228
arch: The arch code for the compute/sm versions.
235229
236230
Returns:
237231
tuple[str, float]: (Kernel output, execution time in milliseconds)
238232
"""
239233
try:
240-
# Write submission files to directory
241-
if reference_content is not None:
242-
with open("reference.py", "w") as f:
243-
f.write(reference_content)
244-
245-
if submission_content is not None:
246-
with open("train.py", "w") as f:
247-
f.write(submission_content)
234+
assert main in sources.keys()
248235

249-
with open("eval.py", "w") as f:
250-
f.write(script_content)
251-
252-
return run_program(["python", "eval.py"])
236+
# Write submission files to directory
237+
for source, content in sources.items():
238+
Path(source).write_text(content)
239+
return run_program(["python", main])
253240

254241
finally:
255-
tmp_files = ["eval.py", "reference.py", "train.py"]
256-
for f in tmp_files:
242+
for f in sources.keys():
257243
if os.path.exists(f):
258244
os.remove(f)

0 commit comments

Comments (0)