Skip to content

Commit b1cedc2

Browse files
committed
update pytorch script runner to use pipe
1 parent 95ab010 commit b1cedc2

File tree

5 files changed: +133 −106 lines

src/discord-cluster-manager/cogs/verify_run_cog.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import re
3+
from pathlib import Path
34
from unittest.mock import AsyncMock
45

56
import discord
@@ -12,19 +13,16 @@
1213
logger = setup_logging()
1314

1415

15-
def create_mock_attachment(file_name: str, content: str) -> AsyncMock:
    """Create an AsyncMock simulating a discord.Attachment.

    The mock reports the given ``file_name`` and exposes ``content`` (UTF-8
    encoded) through an awaitable ``read()``, matching the real API.
    """

    mock_attachment = AsyncMock(spec=discord.Attachment)
    mock_attachment.filename = file_name
    mock_attachment.content_type = "text/plain"
    # Attachment.read() is a coroutine on the real object, so the stub must
    # be an AsyncMock as well.
    mock_attachment.read = AsyncMock(return_value=content.encode("utf-8"))
    return mock_attachment
2324

2425

25-
script_file = create_mock_attachment()
26-
27-
2826
class VerifyRunCog(commands.Cog):
2927
"""
3028
A Discord cog for verifying the success of training runs.
@@ -45,6 +43,7 @@ async def verify_github_run(
4543
interaction: discord.Interaction,
4644
) -> bool:
4745
github_command = github_cog.run_github
46+
script_file = create_mock_attachment("test_script.py", "print('Hello, world!')")
4847
github_thread = await github_command.callback(github_cog, interaction, script_file, choice)
4948

5049
message_contents = [msg.content async for msg in github_thread.history(limit=None)]
@@ -86,7 +85,13 @@ async def verify_modal_run(self, modal_cog: ModalCog, interaction: discord.Inter
8685
t4 = app_commands.Choice(name="T4", value="t4")
8786
modal_command = modal_cog.run_modal
8887

89-
modal_thread = await modal_command.callback(modal_cog, interaction, script_file, t4)
88+
sub_code = create_mock_attachment(
89+
"submission.py", Path("examples/identity_py/submission.py").read_text()
90+
)
91+
ref_code = Path("examples/identity_py/reference.py").read_text()
92+
modal_thread = await modal_command.callback(
93+
modal_cog, interaction, sub_code, t4, reference_code=ref_code
94+
)
9095

9196
message_contents = [msg.content async for msg in modal_thread.history(limit=None)]
9297

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import math
2+
import os
3+
import time
4+
5+
import torch
6+
from reference import check_implementation, generate_input, ref_kernel
7+
from train import custom_kernel
8+
9+
10+
class PopcornLogger:
    """Writes machine-readable ``key: value`` result lines to a file
    descriptor handed over by the runner (the write end of a pipe)."""

    def __init__(self, fd):
        # Line-buffered so each record reaches the pipe as soon as it is
        # logged, even if the process later exits abruptly.
        self.channel = open(fd, "w", buffering=1)

    def log(self, key: str, value):
        # One record per line. print() already appends the newline; the
        # previous explicit "\n" inside the f-string produced a spurious
        # blank line after every record.
        print(f"{key}: {value}", file=self.channel)
16+
17+
18+
def correctness() -> bool:
    """Check the submission against the reference on several random inputs.

    Returns True when all trials match, False at the first mismatch.
    """
    for _trial in range(10):  # repeat to catch flaky / nondeterministic kernels
        data = generate_input()
        submitted = custom_kernel(data)
        expected = ref_kernel(data)
        if not check_implementation(submitted, expected):
            return False
    print("custom implementation matches the reference implementation.")
    return True
30+
31+
32+
def metric(logger: PopcornLogger):
    """Benchmark ``custom_kernel`` and report timing statistics via ``logger``.

    Runs ``warmup_runs`` untimed calls, then ``timed_runs`` timed calls.
    Every timed output is also checked against the reference kernel; on a
    mismatch this logs ``check: fail`` and exits with status 1 (the runner
    interprets exit code 1 as a failed test).  Durations are logged in
    nanoseconds.
    """
    warmup_runs = 10
    timed_runs = 100

    # Warmup: absorb lazy CUDA initialization / compilation cost so it does
    # not pollute the timed runs.
    print("warming up...")
    for _ in range(warmup_runs):
        inputs = generate_input()
        _ = custom_kernel(inputs)
    torch.cuda.synchronize()

    # Timing code.  perf_counter is monotonic and high-resolution; time.time
    # is wall-clock and can jump if the system clock is adjusted mid-run.
    times = []

    for _ in range(timed_runs):
        inputs = generate_input()

        start_time = time.perf_counter()
        custom_output = custom_kernel(inputs)
        torch.cuda.synchronize()  # ensure the kernel has actually finished
        end_time = time.perf_counter()
        times.append(end_time - start_time)

        ref_output = ref_kernel(inputs)
        torch.cuda.synchronize()
        if not check_implementation(custom_output, ref_output):
            logger.log("check", "fail")
            exit(1)

    total_time = sum(times)
    average_duration = total_time / timed_runs
    variance = sum((x - average_duration) ** 2 for x in times)
    # Sample standard deviation (Bessel's correction) and standard error.
    standard_deviation = math.sqrt(variance / (timed_runs - 1))
    standard_error = standard_deviation / math.sqrt(timed_runs)

    logger.log("check", "pass")
    # All durations are reported in nanoseconds.
    logger.log("duration.mean", average_duration * 1e9)
    logger.log("duration.std", standard_deviation * 1e9)
    logger.log("duration.err", standard_error * 1e9)
    logger.log("duration.best", min(times) * 1e9)
    logger.log("duration.worst", max(times) * 1e9)

    print(f"Submitted kernel runtime: {average_duration:.4f} ± {standard_error:.4} seconds")
75+
76+
77+
def main():
    """Harness entry point.

    The runner passes the write end of a result pipe via the POPCORN_FD
    environment variable; all machine-readable output goes through it.
    """
    result_logger = PopcornLogger(int(os.environ["POPCORN_FD"]))
    if correctness():
        metric(result_logger)
    else:
        result_logger.log("check", "fail")
        exit(1)
83+
84+
85+
# Entry guard: run the evaluation harness only when executed as a script,
# not when imported.
if __name__ == "__main__":
    main()

src/discord-cluster-manager/leaderboard_eval.py

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -4,71 +4,5 @@
44

55
from pathlib import Path
66

7-
py_eval = """
8-
import torch
9-
import time
10-
from reference import ref_kernel, generate_input, check_implementation
11-
from train import custom_kernel
12-
13-
14-
def correctness() -> bool:
15-
for _ in range(10): # check multiple times
16-
inputs = generate_input()
17-
18-
custom_output = custom_kernel(inputs)
19-
ref_output = ref_kernel(inputs)
20-
21-
if not check_implementation(custom_output, ref_output):
22-
return False
23-
24-
print('custom implementation matches the reference implementation.')
25-
return True
26-
27-
28-
def metric():
29-
warmup_runs = 10
30-
timed_runs = 100
31-
32-
# Warmup Code
33-
print('warming up...')
34-
for _ in range(warmup_runs):
35-
inputs = generate_input()
36-
_ = custom_kernel(inputs)
37-
torch.cuda.synchronize()
38-
39-
# Timing Code
40-
total_time = 0.0
41-
42-
for _ in range(timed_runs):
43-
inputs = generate_input()
44-
45-
start_time = time.time()
46-
custom_output = custom_kernel(inputs)
47-
torch.cuda.synchronize()
48-
end_time = time.time()
49-
total_time += (end_time - start_time)
50-
51-
ref_output = ref_kernel(inputs)
52-
torch.cuda.synchronize()
53-
if not check_implementation(custom_output, ref_output):
54-
return -1
55-
56-
57-
custom_duration = total_time / timed_runs
58-
59-
print(f'Submitted kernel runtime: {custom_duration:.4f} seconds')
60-
61-
return custom_duration
62-
63-
def main():
64-
assert (correctness())
65-
s = metric()
66-
67-
print(f'score:{s}')
68-
69-
if __name__ == '__main__':
70-
main()
71-
72-
"""
73-
7+
py_eval = Path.read_text(Path(__file__).parent / "eval.py")
748
cu_eval = Path.read_text(Path(__file__).parent / "eval.cu")

src/discord-cluster-manager/modal_runner.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,40 @@ def modal_run_pytorch_script( # noqa: C901
7979
"""Modal version of run_pytorch_script, handling timeouts"""
8080
try:
8181
with timeout(timeout_seconds):
82-
return run_pytorch_script(
82+
run_result = run_pytorch_script(
8383
script_content=script_content,
8484
reference_content=reference_content,
8585
submission_content=submission_content,
8686
arch=arch,
8787
)
88+
if not run_result.success:
89+
# exit code 1 encodes failed tests
90+
if run_result.exit_code == 1:
91+
return f"check_implementation failed:\n{run_result.stderr}", 0.0
92+
else:
93+
return (
94+
f"Script failed with exit code "
95+
f"({run_result.exit_code}):\n{run_result.stderr}",
96+
0.0,
97+
)
98+
99+
print("run process stdout:", run_result.stdout)
100+
print("run process stderr:", run_result.stderr)
101+
102+
score = float(run_result.result.get("duration.mean", "0.0")) / 1e9
103+
passed = run_result.result.get("check", "") == "pass"
104+
if not passed:
105+
return "check_implementation failed", 0.0
106+
107+
if score is None:
108+
return run_result.stdout, run_result.duration
109+
110+
return run_result.stdout, score
88111

89112
except TimeoutException as e:
90113
return f"Timeout Error: {str(e)}", 0.0
114+
except Exception as e:
115+
return f"Error executing script: {str(e)}", 0.0
91116

92117

93118
def modal_run_cuda_script( # # noqa: C901

src/discord-cluster-manager/run_eval.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def compile_cuda_script( # # noqa: C901
114114
)
115115

116116

117-
def run_cuda_program(args: list[str]) -> RunResult:
117+
def run_program(args: list[str]) -> RunResult:
118118
# set up a pipe so the tester can communicate its verdict with us
119119
env = os.environ.copy()
120120
pipe_read, pipe_write = os.pipe()
@@ -142,7 +142,9 @@ def run_cuda_program(args: list[str]) -> RunResult:
142142
result_dict[key.strip()] = value.strip()
143143

144144
return RunResult(
145-
success=run_process.returncode == 0,
145+
# TODO should we return 0 also on test failure?
146+
# TODO check what return codes python uses, e.g. on uncaught exception
147+
success=(run_process.returncode == 0 or run_process.returncode == 1),
146148
command=_make_cmd(run_process.args),
147149
stdout=run_process.stdout,
148150
stderr=run_process.stderr,
@@ -206,7 +208,7 @@ def run_cuda_script( # # noqa: C901
206208
result={},
207209
)
208210

209-
run_result = run_cuda_program(["./eval.out"])
211+
run_result = run_program(["./eval.out"])
210212
return compile_result, run_result
211213

212214
finally:
@@ -221,9 +223,9 @@ def run_pytorch_script( # noqa: C901
221223
reference_content: Optional[str] = None,
222224
submission_content: Optional[str] = None,
223225
arch: int = None,
224-
) -> tuple[str, float]:
226+
) -> RunResult:
225227
"""
226-
Executes the provided PyTorch GPU kernel in an isolated environment with a timeout
228+
Executes the provided PyTorch GPU kernel in an isolated environment
227229
228230
Args:
229231
script_content: The PyTorch script containing the GPU kernel to benchmark
@@ -247,33 +249,8 @@ def run_pytorch_script( # noqa: C901
247249
with open("eval.py", "w") as f:
248250
f.write(script_content)
249251

250-
execution_start_time = time.perf_counter()
251-
result = subprocess.run(
252-
["python", "eval.py"],
253-
stdout=subprocess.PIPE,
254-
stderr=subprocess.PIPE,
255-
text=True,
256-
)
257-
258-
if result.returncode != 0:
259-
raise RuntimeError(
260-
"Script execution failed with return code "
261-
+ f"{result.returncode}:\n{result.stderr}"
262-
)
263-
264-
score = None
265-
for line in result.stdout.splitlines():
266-
if line.startswith("score:"):
267-
score = float(line.split(":")[1].strip())
268-
return "score", score
269-
270-
if score is None:
271-
execution_end_time = time.perf_counter()
272-
score = execution_end_time - execution_start_time
252+
return run_program(["python", "eval.py"])
273253

274-
return result.stdout, score
275-
except Exception as e:
276-
return f"Error executing script: {str(e)}", 0.0
277254
finally:
278255
tmp_files = ["eval.py", "reference.py", "train.py"]
279256
for f in tmp_files:

0 commit comments

Comments
 (0)