Skip to content

Commit 00f0573

Browse files
committed
docs: add slurm page
1 parent b7a59b8 commit 00f0573

File tree

6 files changed

+172
-78
lines changed

6 files changed

+172
-78
lines changed

docs/src/.vitepress/config.mts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ export default defineConfig({
131131
text: 'Advanced',
132132
items: [
133133
{ text: 'Tuning', link: '/tuning' },
134+
{ text: 'Slurm', link: '/slurm' },
134135
{ text: 'Backend', link: '/backend' },
135136
]
136137
}

docs/src/slurm.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Slurm (Multi-node)
2+
3+
PySR supports running across multiple nodes on Slurm via `cluster_manager="slurm"`.
4+
This backend is **allocation-based**: you request resources with Slurm (`sbatch`/`salloc`), then PySR launches Julia workers inside that allocation (using `SlurmClusterManager.jl`).
5+
6+
Here is a minimal `sbatch` example using 3 workers on each of 2 nodes (6 workers total).
7+
8+
Save this as `pysr_job.sh`:
9+
10+
```bash
11+
#!/bin/bash
12+
#SBATCH --job-name=pysr
13+
#SBATCH --partition=normal
14+
#SBATCH --nodes=2
15+
#SBATCH --ntasks-per-node=3
16+
#SBATCH --time=01:00:00
17+
18+
set -euo pipefail
19+
python pysr_script.py
20+
```
21+
22+
Save this as `pysr_script.py` in the same directory:
23+
24+
```python
25+
import numpy as np
26+
from pysr import PySRRegressor
27+
28+
X = np.random.RandomState(0).randn(1000, 2)
29+
y = X[:, 0] + 2 * X[:, 1]
30+
31+
model = PySRRegressor(
32+
niterations=200,
33+
populations=2,
34+
parallelism="multiprocessing",
35+
cluster_manager="slurm",
36+
procs=6, # must match the Slurm allocation's total task count
37+
)
38+
model.fit(X, y)
39+
print(model)
40+
```
41+
42+
Submit it with:
43+
44+
```bash
45+
sbatch pysr_job.sh
46+
```
47+
48+
## Notes
49+
50+
- `procs` is the number of Julia worker processes. It must match the total number of tasks in the Slurm allocation (i.e., `--ntasks`, or `--nodes` multiplied by `--ntasks-per-node`).
51+
- Run the Python script once (as the master) inside the allocation; do not wrap it in `srun`, since the worker launcher manages task placement itself and an outer `srun` would start duplicate copies of the script.

docs/src/tuning.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ First, my general tips would be to avoid using redundant operators, like how `po
66

77
When running PySR, I usually do the following:
88

9-
I run from IPython (Jupyter Notebooks don't work as well[^1]) on the head node of a slurm cluster. Passing `cluster_manager="slurm"` will make PySR set up a run over the entire allocation. I set `procs` equal to the total number of cores over my entire allocation.
9+
I run from IPython (Jupyter Notebooks don't work as well[^1]) on the head node of a Slurm cluster. Passing `cluster_manager="slurm"` will make PySR set up a run over the entire allocation. I set `procs` equal to the total number of tasks across my entire allocation (see the [Slurm page](slurm.md) for a complete multi-node example).
1010

1111
I use the [tensorboard feature](https://ai.damtp.cam.ac.uk/pysr/examples/#12-using-tensorboard-for-logging) for experiment tracking.
1212

pysr/test/slurm_docker_cluster/config/slurm.conf

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ SlurmctldLogFile=/var/log/slurm/slurmctld.log
3737
SlurmdDebug=info
3838
SlurmdLogFile=/var/log/slurm/slurmd.log
3939

40-
NodeName=c1 CPUs=2 RealMemory=1000 State=UNKNOWN
41-
NodeName=c2 CPUs=2 RealMemory=1000 State=UNKNOWN
40+
NodeName=c1 CPUs=4 RealMemory=1000 State=UNKNOWN
41+
NodeName=c2 CPUs=4 RealMemory=1000 State=UNKNOWN
4242
PartitionName=normal Nodes=c1,c2 Default=YES MaxTime=INFINITE State=UP

pysr/test/slurm_docker_cluster/docker-compose.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ services:
1717
command: ["slurmd"]
1818
hostname: c1
1919
working_dir: /data
20-
cpus: 2
20+
cpus: 4
2121
privileged: true
2222
volumes:
2323
- etc_munge:/etc/munge
@@ -32,7 +32,7 @@ services:
3232
command: ["slurmd"]
3333
hostname: c2
3434
working_dir: /data
35-
cpus: 2
35+
cpus: 4
3636
privileged: true
3737
volumes:
3838
- etc_munge:/etc/munge

pysr/test/test_slurm.py

Lines changed: 115 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -330,82 +330,124 @@ def _wait_for_cluster_ready(self, *, timeout_s: int):
330330
time.sleep(2)
331331

332332
def test_pysr_slurm_cluster_manager(self):
333-
slurm_job = self.data_dir / "pysr_slurm_job.sh"
334-
slurm_job.write_text(
335-
"\n".join(
336-
[
337-
"#!/bin/bash",
338-
"#SBATCH --job-name=pysr-slurm-test",
339-
"#SBATCH --partition=normal",
340-
"#SBATCH --nodes=2",
341-
"#SBATCH --ntasks-per-node=1",
342-
"#SBATCH --time=40:00",
343-
"set -euo pipefail",
344-
'jobid="${SLURM_JOB_ID:-${SLURM_JOBID:-}}"',
345-
'if [ -z "$jobid" ]; then echo "Missing SLURM_JOB_ID/SLURM_JOBID" >&2; exit 1; fi',
346-
"monitor_steps() {",
347-
" while true; do",
348-
" echo PYSR_SCONTROL_STEP_SAMPLE",
349-
' scontrol show step "$jobid" -o 2>/dev/null || true',
350-
" sleep 1",
351-
" done",
352-
"}",
353-
"monitor_steps &",
354-
"MONITOR_PID=$!",
355-
"trap 'kill $MONITOR_PID 2>/dev/null || true' EXIT",
356-
"python3 - <<'PY'",
357-
"import os",
358-
"os.environ['JULIA_DEBUG'] = 'SlurmClusterManager'",
359-
"import numpy as np",
360-
"from pysr import PySRRegressor",
361-
"X = np.random.RandomState(0).randn(30, 2)",
362-
"y = X[:, 0] + 1.0",
363-
"model = PySRRegressor(",
364-
" niterations=2,",
365-
" populations=2,",
366-
" progress=False,",
367-
" temp_equation_file=True,",
368-
" parallelism='multiprocessing',",
369-
" procs=2,",
370-
" cluster_manager='slurm',",
371-
" verbosity=0,",
372-
")",
373-
"model.fit(X, y)",
374-
"print('PYSR_SLURM_OK:slurm')",
375-
"PY",
376-
]
333+
def _assert_worker_distribution(
334+
output: str,
335+
*,
336+
expected_procs: int,
337+
expected_nodes: int,
338+
expected_per_node: int,
339+
) -> None:
340+
worker_hosts = re.findall(
341+
r"Worker \d+ ready on host ([^,]+), port \d+",
342+
output,
343+
)
344+
self.assertEqual(
345+
len(worker_hosts),
346+
expected_procs,
347+
msg=f"Expected {expected_procs} workers.\n\n{output}",
348+
)
349+
counts: dict[str, int] = {}
350+
for host in worker_hosts:
351+
counts[host] = counts.get(host, 0) + 1
352+
self.assertEqual(
353+
len(counts),
354+
expected_nodes,
355+
msg=f"Expected workers on {expected_nodes} nodes.\n\n{output}",
356+
)
357+
self.assertTrue(
358+
all(v == expected_per_node for v in counts.values()),
359+
msg=f"Expected exactly {expected_per_node} workers per node.\n\n{output}",
377360
)
378-
+ "\n"
379-
)
380-
slurm_job.chmod(0o755)
381-
382-
slurm_output = self._run_sbatch(slurm_job)
383-
self.assertIn("PYSR_SLURM_OK:slurm", slurm_output)
384-
self._assert_scontrol_step_usage(
385-
slurm_output,
386-
expected_tasks=2,
387-
expected_nodes=2,
388-
expected_nodelist={"c1,c2", "c[1-2]"},
389-
label="slurm",
390-
)
391-
self.assertEqual(
392-
len(re.findall(r"^PYSR_SLURM_OK:slurm$", slurm_output, flags=re.MULTILINE)),
393-
1,
394-
msg=f"Expected slurm marker exactly once.\n\n{slurm_output}",
395-
)
396361

397-
self.assertEqual(
398-
len(
399-
re.findall(
400-
r"^\[ Info: Starting SLURM job .*", slurm_output, re.MULTILINE
362+
def _run_case(*, ntasks_per_node: int, procs: int, seed: int) -> str:
363+
marker = f"PYSR_SLURM_OK:slurm:{procs}"
364+
job = self.data_dir / f"pysr_slurm_job_{procs}.sh"
365+
job.write_text(
366+
"\n".join(
367+
[
368+
"#!/bin/bash",
369+
f"#SBATCH --job-name=pysr-slurm-test-{procs}",
370+
"#SBATCH --partition=normal",
371+
"#SBATCH --nodes=2",
372+
f"#SBATCH --ntasks-per-node={ntasks_per_node}",
373+
"#SBATCH --time=40:00",
374+
"set -euo pipefail",
375+
'jobid="${SLURM_JOB_ID:-${SLURM_JOBID:-}}"',
376+
'if [ -z "$jobid" ]; then echo "Missing SLURM_JOB_ID/SLURM_JOBID" >&2; exit 1; fi',
377+
"monitor_steps() {",
378+
" while true; do",
379+
" echo PYSR_SCONTROL_STEP_SAMPLE",
380+
' scontrol show step "$jobid" -o 2>/dev/null || true',
381+
" sleep 1",
382+
" done",
383+
"}",
384+
"monitor_steps &",
385+
"MONITOR_PID=$!",
386+
"trap 'kill $MONITOR_PID 2>/dev/null || true' EXIT",
387+
"python3 - <<'PY'",
388+
"import os",
389+
"os.environ['JULIA_DEBUG'] = 'SlurmClusterManager'",
390+
"import numpy as np",
391+
"from pysr import PySRRegressor",
392+
f"X = np.random.RandomState({seed}).randn(30, 2)",
393+
"y = X[:, 0] + 1.0",
394+
"model = PySRRegressor(",
395+
" niterations=2,",
396+
" populations=2,",
397+
" progress=False,",
398+
" temp_equation_file=True,",
399+
" parallelism='multiprocessing',",
400+
f" procs={procs},",
401+
" cluster_manager='slurm',",
402+
" verbosity=0,",
403+
")",
404+
"model.fit(X, y)",
405+
f"print('{marker}')",
406+
"PY",
407+
]
401408
)
402-
),
403-
0,
404-
msg=(
405-
"Expected Slurm backend to use SlurmClusterManager (allocation-based), "
406-
"not ClusterManagers.\n\n" + slurm_output
407-
),
408-
)
409+
+ "\n"
410+
)
411+
job.chmod(0o755)
412+
413+
output = self._run_sbatch(job)
414+
self.assertIn(marker, output)
415+
self.assertEqual(
416+
len(re.findall(rf"^{re.escape(marker)}$", output, flags=re.MULTILINE)),
417+
1,
418+
msg=f"Expected marker exactly once.\n\n{output}",
419+
)
420+
self.assertEqual(
421+
len(
422+
re.findall(r"^\[ Info: Starting SLURM job .*", output, re.MULTILINE)
423+
),
424+
0,
425+
msg=(
426+
"Expected Slurm backend to use SlurmClusterManager (allocation-based), "
427+
"not ClusterManagers.\n\n" + output
428+
),
429+
)
430+
return output
431+
432+
cases = [
433+
dict(ntasks_per_node=1, procs=2, seed=0),
434+
dict(ntasks_per_node=3, procs=6, seed=2),
435+
]
436+
for case in cases:
437+
output = _run_case(**case)
438+
self._assert_scontrol_step_usage(
439+
output,
440+
expected_tasks=case["procs"],
441+
expected_nodes=2,
442+
expected_nodelist={"c1,c2", "c[1-2]"},
443+
label=f"slurm({case['procs']})",
444+
)
445+
_assert_worker_distribution(
446+
output,
447+
expected_procs=case["procs"],
448+
expected_nodes=2,
449+
expected_per_node=case["ntasks_per_node"],
450+
)
409451

410452

411453
def runtests(just_tests=False):

0 commit comments

Comments
 (0)