|
| 1 | +# SPDX-FileCopyrightText: 2025-present LlamaIndex Inc. <[email protected]> |
| 2 | +# SPDX-License-Identifier: MIT |
| 3 | +"""Subprocess runner for workflow tests with DBOS isolation. |
| 4 | +
|
| 5 | +This module provides a CLI runner for executing workflows in isolated |
| 6 | +subprocesses, supporting interrupt/resume testing and human-in-the-loop |
| 7 | +response simulation. |
| 8 | +
|
| 9 | +Usage: |
| 10 | + python /path/to/packages/llama-agents-dbos/tests/fixtures/runner.py \ |
| 11 | + --workflow "tests.fixtures.workflows.hitl:TestWorkflow" \ |
| 12 | + --db-url "sqlite+pysqlite:///path/to/db" \ |
| 13 | + --run-id "test-001" \ |
| 14 | + --config '{"interrupt_on": "AskInputEvent"}' |
| 15 | +
|
| 16 | +Config modes: |
| 17 | + - interrupt_on: Interrupt when event type is seen (uses os._exit(0)) |
| 18 | + - String form: "EventName" - interrupt on any instance of EventName |
| 19 | + - Dict form: {"event": "EventName", "condition": {"field": value}} |
| 20 | + - interrupt only when type matches AND all condition fields match |
| 21 | + - respond: Respond to InputRequiredEvent subtypes with specified events |
| 22 | + - run-to-completion: Empty config or omit both fields |
| 23 | +""" |
| 24 | + |
| 25 | +from __future__ import annotations |
| 26 | + |
| 27 | +import argparse |
| 28 | +import asyncio |
| 29 | +import importlib |
| 30 | +import json |
| 31 | +import os |
| 32 | +import sys |
| 33 | +from pathlib import Path |
| 34 | +from types import ModuleType |
| 35 | +from typing import Any |
| 36 | + |
| 37 | +# Add package source directories to sys.path for imports |
| 38 | +# Runner is at: packages/llama-agents-dbos/tests/fixtures/runner.py |
| 39 | +# We need to add: |
| 40 | +# - packages/llama-agents-dbos/src for llama_agents.dbos |
| 41 | +# - packages/llama-index-workflows/src for workflows.* |
| 42 | +# - packages/llama-agents-dbos (parent of tests/) so tests.fixtures.workflows.* can be imported |
# Compute package directories relative to this file so the runner works no
# matter what the subprocess's working directory is.
TESTS_DIR = Path(__file__).parent.parent
DBOS_PACKAGE_DIR = TESTS_DIR.parent
DBOS_PACKAGE_SRC_PATH = str(DBOS_PACKAGE_DIR / "src")
WORKFLOWS_PACKAGE_SRC_PATH = str(
    DBOS_PACKAGE_DIR.parent / "llama-index-workflows" / "src"
)

# Insert at front of path so these packages take precedence over any
# installed versions.
# Add the parent of tests/ so "import tests.fixtures.workflows..." works
sys.path.insert(0, str(DBOS_PACKAGE_DIR))
sys.path.insert(0, DBOS_PACKAGE_SRC_PATH)
sys.path.insert(0, WORKFLOWS_PACKAGE_SRC_PATH)
| 55 | + |
| 56 | +from dbos import DBOS, DBOSConfig # noqa: E402 |
| 57 | +from llama_agents.dbos import DBOSRuntime # noqa: E402 |
| 58 | +from workflows.context import Context # noqa: E402 |
| 59 | +from workflows.events import Event, InputRequiredEvent, StartEvent # noqa: E402 |
| 60 | +from workflows.workflow import Workflow # noqa: E402 |
| 61 | + |
| 62 | + |
def import_workflow(path: str) -> tuple[type[Workflow], ModuleType]:
    """Resolve a ``module.path:ClassName`` string to a Workflow class.

    Args:
        path: Module path plus class name, e.g.
            "tests.fixtures.workflows.hitl:TestWorkflow".

    Returns:
        A ``(workflow_class, module)`` pair; the module is returned so that
        callers can look up event classes defined alongside the workflow.

    Raises:
        ValueError: If the path lacks the ':' separator.
        ImportError: If the module cannot be imported.
        AttributeError: If the class is not present in the module.
        TypeError: If the named attribute is not a Workflow subclass.
    """
    module_path, sep, class_name = path.rpartition(":")
    if not sep:
        raise ValueError(
            f"Invalid workflow path format: {path}. Expected 'module.path:ClassName'"
        )
    module = importlib.import_module(module_path)
    candidate = getattr(module, class_name)
    is_workflow_type = isinstance(candidate, type) and issubclass(candidate, Workflow)
    if not is_workflow_type:
        raise TypeError(f"{class_name} is not a Workflow subclass")
    return candidate, module
| 87 | + |
| 88 | + |
def get_event_class_by_name(module: ModuleType, name: str) -> type[Event] | None:
    """Look up an Event subclass defined in *module* by class name.

    Walks the module's attributes (via ``dir``) and returns the first Event
    subclass whose ``__name__`` equals *name*.

    Args:
        module: The module to scan.
        name: The class name to match.

    Returns:
        The matching event class, or None when nothing qualifies.
    """
    for attr_name in dir(module):
        candidate = getattr(module, attr_name)
        is_event_type = isinstance(candidate, type) and issubclass(candidate, Event)
        if is_event_type and candidate.__name__ == name:
            return candidate
    return None
| 107 | + |
| 108 | + |
def parse_config(config_json: str | None) -> dict[str, Any]:
    """Decode the runner's JSON configuration string.

    Args:
        config_json: JSON-encoded configuration, or None/empty string when
            no configuration was supplied.

    Returns:
        The decoded configuration mapping; an empty dict when absent.
    """
    return json.loads(config_json) if config_json else {}
| 121 | + |
| 122 | + |
def setup_dbos(db_url: str, app_name: str = "test-workflow") -> DBOSRuntime:
    """Set up DBOS with the given database URL.

    Note: constructing ``DBOS(config=...)`` configures process-global DBOS
    state as a side effect; the runtime is not launched here — callers
    invoke ``runtime.launch()`` themselves.

    Args:
        db_url: SQLite database URL.
        app_name: Application name for DBOS config.

    Returns:
        Configured DBOSRuntime instance.
    """
    config: DBOSConfig = {
        "name": app_name,
        "system_database_url": db_url,
        # Tests don't need the HTTP admin endpoint.
        "run_admin_server": False,
        # Aggressive polling keeps interrupt/resume tests fast; the key is
        # not declared in the DBOSConfig TypedDict, hence the ignore.
        "internal_polling_interval_sec": 0.01,  # type: ignore[typeddict-unknown-key]
    }
    DBOS(config=config)
    return DBOSRuntime(polling_interval_sec=0.01)
| 141 | + |
| 142 | + |
def _resolve_interrupt_config(
    module: ModuleType,
    interrupt_on_config: str | dict[str, Any] | None,
) -> tuple[type[Event] | None, dict[str, Any] | None]:
    """Resolve the ``interrupt_on`` config into ``(event_class, condition)``.

    Accepts the string form ("EventName") or the dict form
    ({"event": ..., "condition": {...}}). Prints a parent-parseable ERROR
    line and exits with status 1 when the event class cannot be found.
    """
    if not interrupt_on_config:
        return None, None
    interrupt_condition: dict[str, Any] | None = None
    if isinstance(interrupt_on_config, str):
        interrupt_event_name = interrupt_on_config
    else:
        interrupt_event_name = interrupt_on_config.get("event")
        interrupt_condition = interrupt_on_config.get("condition")
    interrupt_event_class = get_event_class_by_name(module, interrupt_event_name)
    if interrupt_event_class is None:
        print(
            f"ERROR:ValueError:Event class '{interrupt_event_name}' not found in module"
        )
        sys.exit(1)
    return interrupt_event_class, interrupt_condition


def _build_response_map(
    module: ModuleType,
    respond_config: dict[str, Any],
) -> dict[type[Event], tuple[type[Event], dict[str, Any]]]:
    """Build ``{trigger_class: (response_class, response_fields)}`` from config.

    Prints a parent-parseable ERROR line and exits with status 1 when any
    configured event class cannot be found in the workflow module.
    """
    response_map: dict[type[Event], tuple[type[Event], dict[str, Any]]] = {}
    for trigger_name, response_info in respond_config.items():
        trigger_class = get_event_class_by_name(module, trigger_name)
        if trigger_class is None:
            print(
                f"ERROR:ValueError:Trigger event class '{trigger_name}' not found in module"
            )
            sys.exit(1)
        response_event_name = response_info.get("event")
        response_fields = response_info.get("fields", {})
        response_class = get_event_class_by_name(module, response_event_name)
        if response_class is None:
            print(
                f"ERROR:ValueError:Response event class '{response_event_name}' not found in module"
            )
            sys.exit(1)
        response_map[trigger_class] = (response_class, response_fields)
    return response_map


def _matches_condition(event: Event, condition: dict[str, Any] | None) -> bool:
    """Return True when every condition field matches the event's attributes.

    An empty or absent condition matches unconditionally.
    """
    if not condition:
        return True
    return all(
        getattr(event, field, None) == expected
        for field, expected in condition.items()
    )


async def run_workflow(
    workflow_path: str,
    db_url: str,
    run_id: str,
    config: dict[str, Any],
) -> None:
    """Run the workflow with the specified configuration.

    Streams workflow events, printing an ``EVENT:<name>`` line per event so
    the parent test process can observe progress. Depending on *config*:
    hard-exits on a matching event (interrupt simulation), replies to
    InputRequiredEvent subtypes (HITL simulation), or runs to completion.

    Args:
        workflow_path: Module path with class name ("module.path:ClassName").
        db_url: SQLite database URL.
        run_id: Unique run ID for the workflow.
        config: Configuration dict with interrupt_on and/or respond settings.
    """
    # Import workflow and keep the module for event class lookup.
    workflow_class, module = import_workflow(workflow_path)

    interrupt_event_class, interrupt_condition = _resolve_interrupt_config(
        module, config.get("interrupt_on")
    )
    response_map = _build_response_map(module, config.get("respond", {}))

    # Set up DBOS, create the workflow instance, and launch the runtime.
    runtime = setup_dbos(db_url)
    wf = workflow_class(runtime=runtime)
    runtime.launch()

    try:
        ctx = Context(wf)
        handler = ctx._workflow_run(wf, StartEvent(), run_id=run_id)

        async for event in handler.stream_events():
            print(f"EVENT:{type(event).__name__}", flush=True)

            # Interrupt simulation: os._exit skips finally/atexit handlers,
            # mimicking an abrupt process death for DBOS recovery tests.
            if (
                interrupt_event_class is not None
                and isinstance(event, interrupt_event_class)
                and _matches_condition(event, interrupt_condition)
            ):
                print("INTERRUPTING", flush=True)
                os._exit(0)

            # HITL simulation: answer input requests with configured events.
            if isinstance(event, InputRequiredEvent):
                for trigger_class, (response_class, fields) in response_map.items():
                    if isinstance(event, trigger_class):
                        if handler.ctx:
                            handler.ctx.send_event(response_class(**fields))
                        break

        result = await handler
        print(f"RESULT:{result}", flush=True)
        print("SUCCESS", flush=True)

    except Exception as e:
        # Surface the failure to the parent in a parseable form, then
        # re-raise so the subprocess exits non-zero.
        print(f"ERROR:{type(e).__name__}:{e}", flush=True)
        raise

    finally:
        runtime.destroy()
| 252 | + |
| 253 | + |
def main() -> None:
    """CLI entry point: parse arguments, then execute the workflow run."""
    parser = argparse.ArgumentParser(
        description="Run workflows in isolated subprocesses for testing"
    )
    parser.add_argument(
        "--workflow",
        required=True,
        help="Module path with class name (e.g., 'tests.fixtures.workflows.hitl:TestWorkflow')",
    )
    parser.add_argument("--db-url", required=True, help="SQLite database URL")
    parser.add_argument(
        "--run-id", required=True, help="Unique run ID for the workflow"
    )
    parser.add_argument(
        "--config", default=None, help="JSON string with configuration"
    )
    args = parser.parse_args()

    asyncio.run(
        run_workflow(
            workflow_path=args.workflow,
            db_url=args.db_url,
            run_id=args.run_id,
            config=parse_config(args.config),
        )
    )
| 292 | + |
| 293 | + |
| 294 | +if __name__ == "__main__": |
| 295 | + main() |
0 commit comments