|
7 | 7 |
|
8 | 8 | import os |
9 | 9 | import sys |
10 | | -from dataclasses import dataclass |
| 10 | +from dataclasses import dataclass, field |
11 | 11 | from typing import Optional |
12 | 12 |
|
13 | 13 | import httpx |
@@ -42,6 +42,28 @@ def name(self) -> str: |
42 | 42 | return self.job_id |
43 | 43 |
|
44 | 44 |
|
| 45 | +@dataclass |
| 46 | +class ModalBudgetWindowBatchExecution: |
| 47 | + """ |
| 48 | + Represents a budget-window batch execution in the Modal simulation API. |
| 49 | + """ |
| 50 | + |
| 51 | + batch_job_id: str |
| 52 | + status: str |
| 53 | + progress: Optional[int] = None |
| 54 | + completed_years: list[str] = field(default_factory=list) |
| 55 | + running_years: list[str] = field(default_factory=list) |
| 56 | + queued_years: list[str] = field(default_factory=list) |
| 57 | + failed_years: list[str] = field(default_factory=list) |
| 58 | + result: Optional[dict] = None |
| 59 | + error: Optional[str] = None |
| 60 | + |
| 61 | + @property |
| 62 | + def name(self) -> str: |
| 63 | + """Alias for batch_job_id.""" |
| 64 | + return self.batch_job_id |
| 65 | + |
| 66 | + |
45 | 67 | class SimulationAPIModal: |
46 | 68 | """ |
47 | 69 | HTTP client for the Modal Simulation API. |
@@ -154,6 +176,57 @@ def run(self, payload: dict) -> ModalSimulationExecution: |
154 | 176 | ) |
155 | 177 | raise |
156 | 178 |
|
| 179 | + def run_budget_window_batch(self, payload: dict) -> ModalBudgetWindowBatchExecution: |
| 180 | + """ |
| 181 | + Submit a budget-window batch job to the Modal API. |
| 182 | + """ |
| 183 | + try: |
| 184 | + modal_payload = dict(payload) |
| 185 | + if "model_version" in modal_payload: |
| 186 | + modal_payload["version"] = modal_payload.pop("model_version") |
| 187 | + modal_payload.pop("data_version", None) |
| 188 | + |
| 189 | + response = self.client.post( |
| 190 | + f"{self.base_url}/simulate/economy/budget-window", |
| 191 | + json=modal_payload, |
| 192 | + ) |
| 193 | + response.raise_for_status() |
| 194 | + data = response.json() |
| 195 | + |
| 196 | + logger.log_struct( |
| 197 | + { |
| 198 | + "message": "Modal budget-window batch submitted", |
| 199 | + "batch_job_id": data.get("batch_job_id"), |
| 200 | + "status": data.get("status"), |
| 201 | + }, |
| 202 | + severity="INFO", |
| 203 | + ) |
| 204 | + |
| 205 | + return ModalBudgetWindowBatchExecution( |
| 206 | + batch_job_id=data["batch_job_id"], |
| 207 | + status=data["status"], |
| 208 | + ) |
| 209 | + |
| 210 | + except httpx.HTTPStatusError as e: |
| 211 | + logger.log_struct( |
| 212 | + { |
| 213 | + "message": f"Modal batch API HTTP error: {e.response.status_code}", |
| 214 | + "response_text": e.response.text[:500], |
| 215 | + }, |
| 216 | + severity="ERROR", |
| 217 | + ) |
| 218 | + raise |
| 219 | + |
| 220 | + except httpx.RequestError as e: |
| 221 | + logger.log_struct( |
| 222 | + { |
| 223 | + "message": f"Modal batch API request error: {str(e)}", |
| 224 | + "run_id": (payload.get("_telemetry") or {}).get("run_id"), |
| 225 | + }, |
| 226 | + severity="ERROR", |
| 227 | + ) |
| 228 | + raise |
| 229 | + |
157 | 230 | def resolve_app_name( |
158 | 231 | self, country: str, version: Optional[str] = None |
159 | 232 | ) -> tuple[str, str]: |
@@ -235,6 +308,51 @@ def get_execution_by_id(self, job_id: str) -> ModalSimulationExecution: |
235 | 308 | ) |
236 | 309 | raise |
237 | 310 |
|
| 311 | + def get_budget_window_batch_by_id( |
| 312 | + self, batch_job_id: str |
| 313 | + ) -> ModalBudgetWindowBatchExecution: |
| 314 | + """ |
| 315 | + Poll the Modal API for the current status of a budget-window batch. |
| 316 | + """ |
| 317 | + try: |
| 318 | + response = self.client.get( |
| 319 | + f"{self.base_url}/budget-window-jobs/{batch_job_id}" |
| 320 | + ) |
| 321 | + if response.status_code not in (200, 202, 500): |
| 322 | + response.raise_for_status() |
| 323 | + data = response.json() |
| 324 | + |
| 325 | + return ModalBudgetWindowBatchExecution( |
| 326 | + batch_job_id=batch_job_id, |
| 327 | + status=data["status"], |
| 328 | + progress=data.get("progress"), |
| 329 | + completed_years=data.get("completed_years", []), |
| 330 | + running_years=data.get("running_years", []), |
| 331 | + queued_years=data.get("queued_years", []), |
| 332 | + failed_years=data.get("failed_years", []), |
| 333 | + result=data.get("result"), |
| 334 | + error=data.get("error"), |
| 335 | + ) |
| 336 | + |
| 337 | + except httpx.HTTPStatusError as e: |
| 338 | + logger.log_struct( |
| 339 | + { |
| 340 | + "message": f"Modal batch API HTTP error polling job {batch_job_id}: {e.response.status_code}", |
| 341 | + "response_text": e.response.text[:500], |
| 342 | + }, |
| 343 | + severity="ERROR", |
| 344 | + ) |
| 345 | + raise |
| 346 | + |
| 347 | + except httpx.RequestError as e: |
| 348 | + logger.log_struct( |
| 349 | + { |
| 350 | + "message": f"Modal batch API request error polling job {batch_job_id}: {str(e)}", |
| 351 | + }, |
| 352 | + severity="ERROR", |
| 353 | + ) |
| 354 | + raise |
| 355 | + |
238 | 356 | def get_execution_status(self, execution: ModalSimulationExecution) -> str: |
239 | 357 | """ |
240 | 358 | Get the status string from an execution. |
|
0 commit comments