Skip to content

Commit b5ae1b1

Browse files
committed
allow resubmit failed, non-running jobs
1 parent 7706bfc commit b5ae1b1

File tree

5 files changed

+220
-17
lines changed

5 files changed

+220
-17
lines changed

babs/base.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
read_yaml,
2323
results_branch_dataframe,
2424
results_status_columns,
25+
scheduler_status_columns,
2526
status_dtypes,
2627
update_job_batch_status,
2728
update_results_status,
@@ -444,15 +445,51 @@ def get_currently_running_jobs_df(self):
444445
Index: []
445446
446447
"""
448+
449+
def _empty_running():
450+
cols = scheduler_status_columns + ['sub_id']
451+
if self.processing_level == 'session':
452+
cols = cols + ['ses_id']
453+
return pd.DataFrame(columns=cols)
454+
455+
job_status_df = self.get_job_status_df()
447456
last_submitted_jobs_df = self.get_latest_submitted_jobs_df()
448-
if last_submitted_jobs_df.empty:
449-
return EMPTY_JOB_SUBMIT_DF
450-
job_ids = last_submitted_jobs_df['job_id'].unique()
451-
if not len(job_ids) == 1:
452-
raise Exception(f'Expected 1 job id, got {len(job_ids)}')
453-
job_id = job_ids[0]
454-
currently_running_df = request_all_job_status(self.queue, job_id)
455-
return identify_running_jobs(last_submitted_jobs_df, currently_running_df)
457+
458+
# Rows that are submitted but don't have results yet (candidates for "running")
459+
if not job_status_df.empty:
460+
sub = job_status_df['submitted'].fillna(False)
461+
no_res = ~job_status_df['has_results'].fillna(False)
462+
job_status_df = job_status_df.loc[sub & no_res].copy()
463+
464+
# Use status rows (submitted, no results) or last submit file for job_id -> sub/ses
465+
mapping_df = job_status_df if not job_status_df.empty else last_submitted_jobs_df.copy()
466+
if mapping_df.empty:
467+
return _empty_running()
468+
469+
# Keep only columns needed to join scheduler output with subject/session
470+
mapping_cols = ['job_id', 'task_id', 'sub_id']
471+
if 'ses_id' in mapping_df:
472+
mapping_cols.append('ses_id')
473+
mapping_df = mapping_df[mapping_cols].copy()
474+
# Drop rows with missing or invalid job/task ids so we only query real jobs
475+
mapping_df = mapping_df[
476+
mapping_df['job_id'].notna()
477+
& mapping_df['task_id'].notna()
478+
& (mapping_df['job_id'] > 0)
479+
& (mapping_df['task_id'] > 0)
480+
]
481+
if mapping_df.empty:
482+
return _empty_running()
483+
484+
# Ask scheduler for each distinct job_id, keep only non-empty responses
485+
job_ids = sorted({int(j) for j in mapping_df['job_id'].unique()})
486+
running_dfs = [request_all_job_status(self.queue, j) for j in job_ids]
487+
running_dfs = [d for d in running_dfs if not d.empty]
488+
if not running_dfs:
489+
return _empty_running()
490+
491+
# Attach sub_id (and ses_id) to scheduler rows and return
492+
return identify_running_jobs(mapping_df, pd.concat(running_dfs, ignore_index=True))
456493

457494
def get_job_status_df(self):
458495
"""

babs/cli.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,14 @@ def _parse_submit():
351351
' If this flag is specified, it will override the `--select` flag.',
352352
type=PathExists,
353353
)
354+
parser.add_argument(
355+
'--skip-running-jobs',
356+
action='store_true',
357+
help=(
358+
'Allow submission when there are running/pending jobs by skipping '
359+
'those jobs instead of raising errors.'
360+
),
361+
)
354362

355363
return parser
356364

@@ -376,6 +384,7 @@ def babs_submit_main(
376384
count: int | None,
377385
select: list | None,
378386
inclusion_file: Path | None,
387+
skip_running_jobs: bool = False,
379388
):
380389
"""This is the core function of ``babs submit``.
381390
@@ -389,6 +398,8 @@ def babs_submit_main(
389398
list of subject IDs and session IDs to be submitted.
390399
inclusion_file: Path
391400
path to a CSV file that lists the subjects (and sessions) to analyze.
401+
skip_running_jobs: bool
402+
whether to allow submission when there are running/pending jobs
392403
"""
393404
import pandas as pd
394405

@@ -406,7 +417,11 @@ def babs_submit_main(
406417
else:
407418
df_job_specified = None
408419

409-
babs_proj.babs_submit(count=count, submit_df=df_job_specified)
420+
babs_proj.babs_submit(
421+
count=count,
422+
submit_df=df_job_specified,
423+
skip_running_jobs=skip_running_jobs,
424+
)
410425

411426

412427
def _parse_status():

babs/interaction.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class BABSInteraction(BABS):
1616
"""Implement interactions with a BABS project - submitting jobs and checking status."""
1717

18-
def babs_submit(self, count=None, submit_df=None, skip_failed=False):
18+
def babs_submit(self, count=None, submit_df=None, skip_failed=False, skip_running_jobs=False):
1919
"""
2020
This function submits jobs that don't have results yet and prints out job status.
2121
@@ -28,22 +28,34 @@ def babs_submit(self, count=None, submit_df=None, skip_failed=False):
2828
submit_df: pd.DataFrame
2929
dataframe of jobs to be submitted
3030
default: None
31+
skip_running_jobs: bool
32+
whether to allow submission when there are running/pending jobs
3133
"""
3234

3335
# Check if there are still jobs running
3436
currently_running_df = self.get_currently_running_jobs_df()
37+
running_pending_df = currently_running_df.copy()
3538
if currently_running_df.shape[0] > 0:
3639
non_cg_states = (
3740
currently_running_df['state'].fillna('').ne('CG')
3841
if 'state' in currently_running_df
3942
else np.array([True] * currently_running_df.shape[0])
4043
)
4144
if non_cg_states.any():
42-
raise Exception(
43-
'There are still jobs running. Please wait for them to finish or cancel them.'
44-
f' Current running jobs:\n{currently_running_df}'
45-
)
46-
print('All currently running jobs are in CG state; proceeding with submission.')
45+
if not skip_running_jobs:
46+
raise Exception(
47+
'There are still jobs running. '
48+
'Please wait for them to finish or cancel them. '
49+
'Current running jobs:\n'
50+
f'{currently_running_df}'
51+
)
52+
if 'state' in currently_running_df:
53+
running_pending_df = currently_running_df[
54+
currently_running_df['state'].isin(['PD', 'R'])
55+
]
56+
else:
57+
running_pending_df = currently_running_df.iloc[0:0]
58+
print('All currently running jobs are in CG state; proceeding with submission.')
4759

4860
# Find the rows that don't have results yet
4961
status_df = self.get_job_status_df()
@@ -54,6 +66,38 @@ def babs_submit(self, count=None, submit_df=None, skip_failed=False):
5466
if submit_df is not None:
5567
df_needs_submit = submit_df
5668

69+
if skip_running_jobs and not running_pending_df.empty:
70+
# Build (sub_id,) or (sub_id, ses_id) keys for set lookup
71+
if self.processing_level == 'session':
72+
running_keys = set(
73+
zip(
74+
running_pending_df['sub_id'],
75+
running_pending_df['ses_id'],
76+
strict=False,
77+
)
78+
)
79+
submit_keys = list(
80+
zip(df_needs_submit['sub_id'], df_needs_submit['ses_id'], strict=False)
81+
)
82+
else:
83+
running_keys = set(running_pending_df['sub_id'].tolist())
84+
submit_keys = df_needs_submit['sub_id'].tolist()
85+
86+
# Mark which of the to-submit rows are still running/pending
87+
if running_keys:
88+
skip_mask = [key in running_keys for key in submit_keys]
89+
else:
90+
skip_mask = [False] * len(submit_keys)
91+
92+
if any(skip_mask):
93+
# Report skipped job IDs and filter them out of the submission list
94+
skip_job_ids = sorted(running_pending_df['job_id'].dropna().unique().tolist())
95+
print(
96+
'Skipping running/pending jobs from job IDs: '
97+
+ ', '.join(str(job_id) for job_id in skip_job_ids)
98+
)
99+
df_needs_submit = df_needs_submit.loc[~np.array(skip_mask)].reset_index(drop=True)
100+
57101
# only run `babs submit` when there are subjects/sessions not yet submitted
58102
if df_needs_submit.empty:
59103
print('No jobs to submit')

tests/test_babs_workflow.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ def test_babs_init_raw_bids(
9999

100100
# babs submit:
101101
babs_submit_opts = argparse.Namespace(
102-
project_root=project_root, select=None, inclusion_file=None, count=1
102+
project_root=project_root,
103+
select=None,
104+
inclusion_file=None,
105+
count=1,
106+
skip_running_jobs=False,
103107
)
104108
with mock.patch.object(argparse.ArgumentParser, 'parse_args', return_value=babs_submit_opts):
105109
_enter_submit()
@@ -124,7 +128,11 @@ def test_babs_init_raw_bids(
124128

125129
# Submit the last job:
126130
babs_submit_opts = argparse.Namespace(
127-
project_root=project_root, select=None, inclusion_file=None, count=None
131+
project_root=project_root,
132+
select=None,
133+
inclusion_file=None,
134+
count=None,
135+
skip_running_jobs=False,
128136
)
129137
with mock.patch.object(argparse.ArgumentParser, 'parse_args', return_value=babs_submit_opts):
130138
_enter_submit()

tests/test_interaction.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from babs.interaction import BABSInteraction
7+
from babs.utils import scheduler_status_columns
78

89

910
def _minimal_status_df():
@@ -26,6 +27,26 @@ def _minimal_status_df():
2627
)
2728

2829

30+
def _status_df_for_submit():
31+
return pd.DataFrame(
32+
{
33+
'sub_id': ['sub-01', 'sub-02', 'sub-03'],
34+
'submitted': [True, True, False],
35+
'has_results': [False, False, False],
36+
'is_failed': [False, True, False],
37+
'job_id': [10, 11, -1],
38+
'task_id': [1, 1, -1],
39+
'state': ['R', '', ''],
40+
'time_used': ['0:01', '', ''],
41+
'time_limit': ['5-00:00:00', '', ''],
42+
'nodes': [1, 0, 0],
43+
'cpus': [1, 0, 0],
44+
'partition': ['normal', '', ''],
45+
'name': ['test_array_job', '', ''],
46+
}
47+
)
48+
49+
2950
def test_babs_submit_blocks_non_cg_jobs(babs_project_subjectlevel, monkeypatch):
3051
babs_proj = BABSInteraction(project_root=babs_project_subjectlevel)
3152
running_df = pd.DataFrame(
@@ -76,3 +97,81 @@ def _mock_submit_array(analysis_path, queue, total_jobs):
7697
babs_proj.babs_submit(count=1)
7798

7899
assert submit_calls
100+
101+
102+
def test_babs_submit_allows_running_skips_jobs(babs_project_subjectlevel, monkeypatch, capsys):
103+
babs_proj = BABSInteraction(project_root=babs_project_subjectlevel)
104+
running_df = pd.DataFrame(
105+
{
106+
'job_id': [10],
107+
'task_id': [1],
108+
'state': ['R'],
109+
'time_used': ['0:01'],
110+
'time_limit': ['5-00:00:00'],
111+
'nodes': [1],
112+
'cpus': [1],
113+
'partition': ['normal'],
114+
'name': ['test_array_job'],
115+
'sub_id': ['sub-01'],
116+
}
117+
)
118+
monkeypatch.setattr(babs_proj, 'get_currently_running_jobs_df', lambda: running_df)
119+
monkeypatch.setattr(babs_proj, 'get_job_status_df', _status_df_for_submit)
120+
121+
submit_calls = []
122+
123+
def _mock_submit_array(analysis_path, queue, total_jobs):
124+
submit_calls.append((analysis_path, queue, total_jobs))
125+
return 123
126+
127+
monkeypatch.setattr('babs.interaction.submit_array', _mock_submit_array)
128+
129+
babs_proj.babs_submit(skip_running_jobs=True)
130+
131+
captured = capsys.readouterr()
132+
assert submit_calls
133+
assert submit_calls[0][2] == 2
134+
assert 'Skipping running/pending jobs from job IDs' in captured.out
135+
assert '10' in captured.out
136+
137+
138+
def test_get_currently_running_jobs_df_multiple_job_ids(babs_project_subjectlevel, monkeypatch):
139+
babs_proj = BABSInteraction(project_root=babs_project_subjectlevel)
140+
status_df = pd.DataFrame(
141+
{
142+
'sub_id': ['sub-01', 'sub-02'],
143+
'submitted': [True, True],
144+
'has_results': [False, False],
145+
'is_failed': [False, False],
146+
'job_id': [10, 20],
147+
'task_id': [1, 2],
148+
}
149+
)
150+
monkeypatch.setattr(babs_proj, 'get_job_status_df', lambda: status_df)
151+
monkeypatch.setattr(babs_proj, 'get_latest_submitted_jobs_df', pd.DataFrame)
152+
153+
calls = []
154+
155+
def _mock_request_all_job_status(queue, job_id):
156+
calls.append(job_id)
157+
task_id = 1 if job_id == 10 else 2
158+
return pd.DataFrame(
159+
{
160+
'job_id': [job_id],
161+
'task_id': [task_id],
162+
'state': ['R'],
163+
'time_used': ['0:01'],
164+
'time_limit': ['5-00:00:00'],
165+
'nodes': [1],
166+
'cpus': [1],
167+
'partition': ['normal'],
168+
'name': ['test_array_job'],
169+
}
170+
)[scheduler_status_columns]
171+
172+
monkeypatch.setattr('babs.base.request_all_job_status', _mock_request_all_job_status)
173+
174+
running_df = babs_proj.get_currently_running_jobs_df()
175+
176+
assert set(calls) == {10, 20}
177+
assert set(running_df['sub_id']) == {'sub-01', 'sub-02'}

0 commit comments

Comments
 (0)