Skip to content

Commit 64861fc

Browse files
committed
redo subprocess determinism tests
1 parent 3ddeb6d commit 64861fc

File tree

11 files changed

+766
-656
lines changed

11 files changed

+766
-656
lines changed

packages/llama-agents-dbos/tests/fixtures/__init__.py

Whitespace-only changes.
Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
# SPDX-FileCopyrightText: 2025-present LlamaIndex Inc. <[email protected]>
2+
# SPDX-License-Identifier: MIT
3+
"""Subprocess runner for workflow tests with DBOS isolation.
4+
5+
This module provides a CLI runner for executing workflows in isolated
6+
subprocesses, supporting interrupt/resume testing and human-in-the-loop
7+
response simulation.
8+
9+
Usage:
10+
python /path/to/packages/llama-agents-dbos/tests/fixtures/runner.py \
11+
--workflow "tests.fixtures.workflows.hitl:TestWorkflow" \
12+
--db-url "sqlite+pysqlite:///path/to/db" \
13+
--run-id "test-001" \
14+
--config '{"interrupt_on": "AskInputEvent"}'
15+
16+
Config modes:
17+
- interrupt_on: Interrupt when event type is seen (uses os._exit(0))
18+
- String form: "EventName" - interrupt on any instance of EventName
19+
- Dict form: {"event": "EventName", "condition": {"field": value}}
20+
- interrupt only when type matches AND all condition fields match
21+
- respond: Respond to InputRequiredEvent subtypes with specified events
22+
- run-to-completion: Empty config or omit both fields
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import argparse
28+
import asyncio
29+
import importlib
30+
import json
31+
import os
32+
import sys
33+
from pathlib import Path
34+
from types import ModuleType
35+
from typing import Any
36+
37+
# Add package source directories to sys.path for imports.
# Runner is at: packages/llama-agents-dbos/tests/fixtures/runner.py
# We need to add:
# - packages/llama-agents-dbos/src for llama_agents.dbos
# - packages/llama-index-workflows/src for workflows.*
# - packages/llama-agents-dbos (parent of tests/) so tests.fixtures.workflows.* can be imported
TESTS_DIR = Path(__file__).parent.parent
DBOS_PACKAGE_DIR = TESTS_DIR.parent
DBOS_PACKAGE_SRC_PATH = str(DBOS_PACKAGE_DIR / "src")
WORKFLOWS_PACKAGE_SRC_PATH = str(
    DBOS_PACKAGE_DIR.parent / "llama-index-workflows" / "src"
)

# Insert at front of path so these packages take precedence over any
# installed versions. Add the parent of tests/ so
# "import tests.fixtures.workflows..." works.
# NOTE: each insert pushes earlier ones back, so the final search order is
# workflows src > dbos src > package dir > the rest of sys.path.
sys.path.insert(0, str(DBOS_PACKAGE_DIR))
sys.path.insert(0, DBOS_PACKAGE_SRC_PATH)
sys.path.insert(0, WORKFLOWS_PACKAGE_SRC_PATH)
55+
56+
from dbos import DBOS, DBOSConfig # noqa: E402
57+
from llama_agents.dbos import DBOSRuntime # noqa: E402
58+
from workflows.context import Context # noqa: E402
59+
from workflows.events import Event, InputRequiredEvent, StartEvent # noqa: E402
60+
from workflows.workflow import Workflow # noqa: E402
61+
62+
63+
def import_workflow(path: str) -> tuple[type[Workflow], ModuleType]:
    """Resolve a ``module.path:ClassName`` spec to a Workflow class.

    Args:
        path: Module path plus class name, e.g.
            "tests.fixtures.workflows.hitl:TestWorkflow".

    Returns:
        Tuple of (workflow_class, module). The module is returned too so
        callers can look up event classes defined next to the workflow.

    Raises:
        ValueError: If ``path`` lacks the ':' separator.
        ImportError: If the module cannot be imported.
        AttributeError: If the class is missing from the module.
        TypeError: If the named attribute is not a Workflow subclass.
    """
    if ":" not in path:
        raise ValueError(
            f"Invalid workflow path format: {path}. Expected 'module.path:ClassName'"
        )
    # Split on the LAST colon so module paths themselves may contain none.
    module_path, _, class_name = path.rpartition(":")
    module = importlib.import_module(module_path)
    candidate = getattr(module, class_name)
    is_workflow_type = isinstance(candidate, type) and issubclass(candidate, Workflow)
    if not is_workflow_type:
        raise TypeError(f"{class_name} is not a Workflow subclass")
    return candidate, module
87+
88+
89+
def get_event_class_by_name(module: ModuleType, name: str) -> type[Event] | None:
    """Look up an Event subclass in *module* by class name.

    Scans every attribute of the module and returns the first class whose
    ``__name__`` equals *name* and which subclasses Event.

    Args:
        module: The module to scan.
        name: The target class name.

    Returns:
        The matching event class, or None when nothing qualifies.
    """
    candidates = (getattr(module, attr_name) for attr_name in dir(module))
    for candidate in candidates:
        if not isinstance(candidate, type):
            continue
        if issubclass(candidate, Event) and candidate.__name__ == name:
            return candidate
    return None
107+
108+
109+
def parse_config(config_json: str | None) -> dict[str, Any]:
    """Decode the runner's JSON configuration string.

    Args:
        config_json: JSON document with configuration, or None/empty.

    Returns:
        The decoded configuration mapping; falsy input yields an empty dict.
    """
    return json.loads(config_json) if config_json else {}
121+
122+
123+
def setup_dbos(db_url: str, app_name: str = "test-workflow") -> DBOSRuntime:
    """Initialise DBOS against *db_url* and return a fast-polling runtime.

    Args:
        db_url: SQLite database URL.
        app_name: Application name recorded in the DBOS config.

    Returns:
        A DBOSRuntime configured to poll every 10ms, keeping subprocess
        tests quick.
    """
    dbos_config: DBOSConfig = {
        "name": app_name,
        "system_database_url": db_url,
        "run_admin_server": False,
        # Tight internal polling so workflow progress is observed promptly.
        "internal_polling_interval_sec": 0.01,  # type: ignore[typeddict-unknown-key]
    }
    DBOS(config=dbos_config)
    return DBOSRuntime(polling_interval_sec=0.01)
141+
142+
143+
async def run_workflow(
    workflow_path: str,
    db_url: str,
    run_id: str,
    config: dict[str, Any],
) -> None:
    """Run the workflow with the specified configuration.

    Emits machine-readable markers on stdout (``EVENT:<name>``,
    ``INTERRUPTING``, ``RESULT:<result>``, ``SUCCESS``,
    ``ERROR:<type>:<message>``) — presumably parsed by the parent test
    process from the subprocess output.

    Args:
        workflow_path: Module path with class name.
        db_url: SQLite database URL.
        run_id: Unique run ID for the workflow.
        config: Configuration dict with interrupt_on and/or respond settings.
    """
    # Import workflow and get module for event class lookup
    workflow_class, module = import_workflow(workflow_path)

    # Parse config options
    interrupt_on_config = config.get("interrupt_on")
    respond_config = config.get("respond", {})

    # Resolve interrupt config (can be string or dict with condition)
    interrupt_event_class: type[Event] | None = None
    interrupt_condition: dict[str, Any] | None = None
    if interrupt_on_config:
        if isinstance(interrupt_on_config, str):
            interrupt_event_name = interrupt_on_config
        else:
            interrupt_event_name = interrupt_on_config.get("event")
            interrupt_condition = interrupt_on_config.get("condition")
        interrupt_event_class = get_event_class_by_name(module, interrupt_event_name)
        if interrupt_event_class is None:
            # Print the error in the ERROR:<type>:<msg> protocol form, then
            # exit nonzero so the parent sees a failed run.
            print(
                f"ERROR:ValueError:Event class '{interrupt_event_name}' not found in module"
            )
            sys.exit(1)

    # Build response event mapping: {trigger_class: (response_class, fields)}
    response_map: dict[type[Event], tuple[type[Event], dict[str, Any]]] = {}
    for trigger_name, response_info in respond_config.items():
        trigger_class = get_event_class_by_name(module, trigger_name)
        if trigger_class is None:
            print(
                f"ERROR:ValueError:Trigger event class '{trigger_name}' not found in module"
            )
            sys.exit(1)
        response_event_name = response_info.get("event")
        response_fields = response_info.get("fields", {})
        response_class = get_event_class_by_name(module, response_event_name)
        if response_class is None:
            print(
                f"ERROR:ValueError:Response event class '{response_event_name}' not found in module"
            )
            sys.exit(1)
        # Both trigger_class and response_class are narrowed after sys.exit(1) guards
        assert trigger_class is not None
        assert response_class is not None
        response_map[trigger_class] = (response_class, response_fields)

    # Set up DBOS and runtime
    runtime = setup_dbos(db_url)

    # Create workflow instance and launch
    wf = workflow_class(runtime=runtime)
    runtime.launch()

    try:
        ctx = Context(wf)
        # NOTE(review): uses the private _workflow_run to pin an explicit
        # run_id — presumably so a later process can resume the same run.
        handler = ctx._workflow_run(wf, StartEvent(), run_id=run_id)

        async for event in handler.stream_events():
            event_name = type(event).__name__
            print(f"EVENT:{event_name}", flush=True)

            # Check for interrupt condition
            if interrupt_event_class is not None and isinstance(
                event, interrupt_event_class
            ):
                # Check condition fields if present: every configured field
                # must match the event's attribute value.
                should_interrupt = True
                if interrupt_condition:
                    for field, expected_value in interrupt_condition.items():
                        actual_value = getattr(event, field, None)
                        if actual_value != expected_value:
                            should_interrupt = False
                            break
                if should_interrupt:
                    print("INTERRUPTING", flush=True)
                    # os._exit terminates immediately WITHOUT running the
                    # finally block below (runtime.destroy) — simulating a
                    # hard crash mid-run. Output is flushed explicitly above
                    # because os._exit skips buffered-stream cleanup.
                    os._exit(0)

            # Check for response condition (InputRequiredEvent subtypes).
            # Only the first matching trigger responds (break after send).
            if isinstance(event, InputRequiredEvent):
                for trigger_class, (response_class, fields) in response_map.items():
                    if isinstance(event, trigger_class):
                        if handler.ctx:
                            response_event = response_class(**fields)
                            handler.ctx.send_event(response_event)
                        break

        result = await handler
        print(f"RESULT:{result}", flush=True)
        print("SUCCESS", flush=True)

    except Exception as e:
        # Surface the failure in protocol form, then re-raise so the
        # subprocess exits with a traceback and nonzero status.
        print(f"ERROR:{type(e).__name__}:{e}", flush=True)
        raise

    finally:
        # Always tear down the runtime on normal exit or exception
        # (skipped only by the deliberate os._exit interrupt above).
        runtime.destroy()
252+
253+
254+
def main() -> None:
    """Entry point for the subprocess runner.

    Parses CLI flags, decodes the optional JSON config, and drives the
    async workflow run to completion.
    """
    parser = argparse.ArgumentParser(
        description="Run workflows in isolated subprocesses for testing"
    )
    # (flag, options) specs; order determines the usage/help listing.
    arg_specs = [
        (
            "--workflow",
            {
                "required": True,
                "help": "Module path with class name (e.g., 'tests.fixtures.workflows.hitl:TestWorkflow')",
            },
        ),
        ("--db-url", {"required": True, "help": "SQLite database URL"}),
        ("--run-id", {"required": True, "help": "Unique run ID for the workflow"}),
        ("--config", {"default": None, "help": "JSON string with configuration"}),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    parsed_config = parse_config(args.config)

    asyncio.run(
        run_workflow(
            workflow_path=args.workflow,
            db_url=args.db_url,
            run_id=args.run_id,
            config=parsed_config,
        )
    )

packages/llama-agents-dbos/tests/fixtures/workflows/__init__.py

Whitespace-only changes.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# SPDX-FileCopyrightText: 2025-present LlamaIndex Inc. <[email protected]>
2+
# SPDX-License-Identifier: MIT
3+
"""Chained workflow fixture with StepOneEvent, StepTwoEvent, and ChainedWorkflow."""
4+
5+
from __future__ import annotations
6+
7+
from pydantic import Field
8+
from workflows.context import Context
9+
from workflows.decorators import step
10+
from workflows.events import Event, StartEvent, StopEvent
11+
from workflows.workflow import Workflow
12+
13+
14+
class StepOneEvent(Event):
    """Event returned by ChainedWorkflow.step_one; carries a marker value."""

    value: str = Field(default="one")
16+
17+
18+
class StepTwoEvent(Event):
    """Event returned by ChainedWorkflow.step_two; carries a marker value."""

    value: str = Field(default="two")
20+
21+
22+
class ChainedWorkflow(Workflow):
    """Three-step linear workflow fixture.

    Each step records a completion flag in the context store and prints a
    flushed ``STEP:<name>:complete`` marker — presumably so a parent test
    process can assert execution order from the subprocess stdout.
    """

    @step
    async def step_one(self, ctx: Context, ev: StartEvent) -> StepOneEvent:
        # Persist progress before signalling completion.
        await ctx.store.set("step_one", True)
        print("STEP:one:complete", flush=True)
        return StepOneEvent()

    @step
    async def step_two(self, ctx: Context, ev: StepOneEvent) -> StepTwoEvent:
        await ctx.store.set("step_two", True)
        print("STEP:two:complete", flush=True)
        return StepTwoEvent()

    @step
    async def step_three(self, ctx: Context, ev: StepTwoEvent) -> StopEvent:
        await ctx.store.set("step_three", True)
        print("STEP:three:complete", flush=True)
        # Final step: ends the workflow with result "done".
        return StopEvent(result="done")
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# SPDX-FileCopyrightText: 2025-present LlamaIndex Inc. <[email protected]>
2+
# SPDX-License-Identifier: MIT
3+
"""Concurrent workers workflow fixture with num_workers=2."""
4+
5+
from __future__ import annotations
6+
7+
import asyncio
8+
import random
9+
10+
from pydantic import Field
11+
from workflows.context import Context
12+
from workflows.decorators import step
13+
from workflows.events import Event, StartEvent, StopEvent
14+
from workflows.workflow import Workflow
15+
16+
17+
class WorkItem(Event):
    """Work unit dispatched to the worker step; item_id identifies it."""

    item_id: int = Field(default=0)
19+
20+
21+
class WorkDone(Event):
    """Completion event emitted by the worker step for one WorkItem."""

    item_id: int = Field(default=0)
23+
24+
25+
class ConcurrentWorkersWorkflow(Workflow):
    """Fan-out workflow fixture whose worker step uses num_workers=2.

    ``dispatch`` emits three WorkItem events (two via send_event plus the
    returned one); ``worker`` processes them with randomized latency, so
    completion order is deliberately nondeterministic; ``finish`` stops the
    workflow on the first WorkDone it receives.
    """

    @step
    async def dispatch(self, ctx: Context, ev: StartEvent) -> WorkItem:
        # Dispatch work items that will be processed by concurrent workers
        ctx.send_event(WorkItem(item_id=1))
        ctx.send_event(WorkItem(item_id=2))
        print("STEP:dispatch:complete", flush=True)
        # NOTE(review): the returned event is presumably routed to `worker`
        # like the two sent above, making item 0 a third work item — confirm
        # against the workflows runtime.
        return WorkItem(item_id=0)

    @step(num_workers=2)
    async def worker(self, ctx: Context, ev: WorkItem) -> WorkDone:
        # Variable processing time for each item — the random sleep makes
        # worker completion order nondeterministic on purpose.
        await asyncio.sleep(random.uniform(0.01, 0.05))
        print(f"STEP:worker:{ev.item_id}:complete", flush=True)
        return WorkDone(item_id=ev.item_id)

    @step
    async def finish(self, ctx: Context, ev: WorkDone) -> StopEvent:
        # First WorkDone to arrive ends the workflow
        print(f"STEP:finish:{ev.item_id}:complete", flush=True)
        return StopEvent(result={"first_done": ev.item_id})

0 commit comments

Comments
 (0)