|
7 | 7 | from typing import Type |
8 | 8 | import pandas as pd |
9 | 9 | import numpy as np |
| 10 | +from google.cloud import workflows_v1 |
| 11 | +from google.cloud.workflows import executions_v1 |
| 12 | +from typing import Tuple |
10 | 13 |
|
11 | 14 | from policyengine_api.jobs import BaseJob |
12 | 15 | from policyengine_api.jobs.tasks import compute_general_economy |
|
23 | 26 |
|
24 | 27 | from policyengine_us import Microsimulation |
25 | 28 | from policyengine_uk import Microsimulation |
| 29 | +import logging |
26 | 30 |
|
27 | 31 | reform_impacts_service = ReformImpactsService() |
28 | 32 |
|
|
33 | 37 | CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5" |
34 | 38 | POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5" |
35 | 39 |
|
# The APIv2 cross-check is enabled only when the GOOGLE_APPLICATION_CREDENTIALS
# environment variable is set; otherwise the comparison steps are skipped.
check_against_api_v2 = (
    os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
)

if not check_against_api_v2:
    # logging.warn is a deprecated alias of logging.warning, so call the
    # supported function directly.
    logging.warning(
        "Didn't find any GOOGLE_APPLICATION_CREDENTIALS, so will not check results for matches against APIv2."
    )
| 48 | + |
36 | 49 |
|
37 | 50 | class CalculateEconomySimulationJob(BaseJob): |
38 | 51 | def __init__(self): |
39 | 52 | super().__init__() |
| 53 | + if check_against_api_v2: |
| 54 | + self.api_v2 = SimulationAPIv2() |
40 | 55 |
|
41 | 56 | def run( |
42 | 57 | self, |
@@ -136,6 +151,17 @@ def run( |
136 | 151 | comment = lambda x: set_comment_on_job(x, *identifiers) |
137 | 152 | comment("Computing baseline") |
138 | 153 |
|
| 154 | + # Kick off APIv2 job |
| 155 | + if check_against_api_v2: |
| 156 | + input_data = { |
| 157 | + "country": country_id, |
| 158 | + "scope": "macro", |
| 159 | + "reform": json.loads(reform_policy), |
| 160 | + "baseline": json.loads(baseline_policy), |
| 161 | + "time_period": time_period, |
| 162 | + } |
| 163 | + execution = self.api_v2.run(input_data) |
| 164 | + |
139 | 165 | # Compute baseline economy |
140 | 166 | baseline_economy = self._compute_economy( |
141 | 167 | country_id=country_id, |
@@ -164,6 +190,19 @@ def run( |
164 | 190 | baseline_economy, reform_economy, country_id=country_id |
165 | 191 | ) |
166 | 192 |
|
| 193 | + # Wait for APIv2 job to complete |
| 194 | + if check_against_api_v2: |
| 195 | + result = self.api_v2.wait_for_completion(execution) |
| 196 | + if result is None: |
| 197 | + print("APIv2 COMPARISON failed: result is not JSON.") |
| 198 | + else: |
| 199 | + try: |
| 200 | + print( |
| 201 | + f"APIv2 COMPARISON: match={is_similar(result, json.loads(json.dumps(impact)))}" |
| 202 | + ) |
| 203 | + except: |
| 204 | + print("APIv2 COMPARISON: ERROR COMPARING", result) |
| 205 | + |
167 | 206 | # Finally, update all reform impact rows with the same baseline and reform policy IDs |
168 | 207 | reform_impacts_service.set_complete_reform_impact( |
169 | 208 | country_id=country_id, |
@@ -360,6 +399,9 @@ def _create_simulation_us( |
360 | 399 | else: |
361 | 400 | sim_options["dataset"] = df[state_code == region.upper()] |
362 | 401 |
|
| 402 | + if dataset == "default" and region == "us": |
| 403 | + sim_options["dataset"] = CPS |
| 404 | + |
363 | 405 | # Return completed simulation |
364 | 406 | return Microsimulation(**sim_options) |
365 | 407 |
|
@@ -419,3 +461,178 @@ def _compute_cliff_impacts(self, simulation: Microsimulation) -> Dict: |
419 | 461 | "cliff_share": float(cliff_share), |
420 | 462 | "type": "cliff", |
421 | 463 | } |
| 464 | + |
| 465 | + |
def is_similar(x, y, parent_name: str = "") -> bool:
    """Recursively compare two JSON-like values, tolerating small numeric drift.

    The first mismatch found is printed together with its path and the
    comparison stops there.

    Parameters:
    -----------
    x, y :
        The values to compare (None, bool, int, float, str, dict, list, or
        anything else supporting ``==``).
    parent_name : str
        Path of the values inside the enclosing structure; used only in the
        printed mismatch messages.

    Returns:
    --------
    bool
        True when the values match. Numbers match when they are equal, differ
        by less than 1e-2 in absolute terms, or differ by less than 1%
        relative to ``x``. NaN never matches anything, including NaN.
    """
    # None only ever matches None.
    if x is None or y is None:
        equal = x is y
        if not equal:
            print(f"Not equal: {x} vs {y} in {parent_name}")
        return equal

    # Types must agree exactly, except that int and float are interchangeable.
    if type(x) != type(y) and {type(x), type(y)} != {int, float}:
        print(f"Different types: {type(x)} vs {type(y)} in {parent_name}")
        return False

    # bool is a subclass of int, so it has to be handled before the numeric
    # branch or this case would never be reached.
    if isinstance(x, bool):
        equal = x == y
        if not equal:
            print(f"Not equal: {x} vs {y} in {parent_name}")
        return equal

    if isinstance(x, (int, float)):
        difference = abs(y - x)
        # The absolute tolerance applies regardless of which side is zero, so
        # the comparison is symmetric around 0. The relative check is skipped
        # for x == 0 to avoid dividing by zero. The equality shortcut makes
        # identical infinities match (inf - inf is NaN).
        close = bool(
            x == y
            or difference < 1e-2
            or (x != 0 and difference / abs(x) < 0.01)
        )
        if not close:
            print(f"Not close: {x} vs {y} in {parent_name}")
        return close

    if isinstance(x, str):
        equal = x == y
        if not equal:
            print(f"Not equal: {x} vs {y} in {parent_name}")
        return equal

    if isinstance(x, dict):
        # Both dictionaries must have exactly the same keys.
        for key in x:
            if key not in y:
                print(f"Key {key} missing in second dict in {parent_name}")
                return False
        for key in y:
            if key not in x:
                print(f"Key {key} missing in first dict in {parent_name}")
                return False
        # An f-string keeps the path building safe for non-string keys.
        return all(
            is_similar(x[key], y[key], parent_name=f"{parent_name}/{key}")
            for key in x
        )

    if isinstance(x, list):
        if len(x) != len(y):
            print(f"Different lengths: {len(x)} vs {len(y)} in {parent_name}")
            return False
        return all(
            is_similar(left, right, parent_name=f"{parent_name}[{index}]")
            for index, (left, right) in enumerate(zip(x, y))
        )

    # Any other type falls back to plain equality.
    equal = x == y
    if not equal:
        print(f"Not equal: {x} vs {y} in {parent_name}")
    return equal
| 537 | + |
| 538 | + |
class SimulationAPIv2:
    """Client for running simulations on APIv2 through a Google Cloud workflow.

    The project, location and workflow name are fixed in ``__init__``. The
    Google Cloud clients are created when ``run`` is called, or on first use
    by the status/result getters.
    """

    project: str
    location: str
    workflow: str

    # Execution state names that mean the execution has not finished yet.
    _PENDING_STATES = ("ACTIVE", "QUEUED")

    def __init__(self):
        self.project = "prod-api-v2-c4d5"
        self.location = "us-central1"
        self.workflow = "simulation-workflow"
        # Populated by run(); initialised here so the getters can detect a
        # missing client instead of raising AttributeError.
        self.execution_client = None
        self.workflows_client = None

    def _get_execution_client(self):
        """Return the executions client, creating it if run() has not yet."""
        if getattr(self, "execution_client", None) is None:
            self.execution_client = executions_v1.ExecutionsClient()
        return self.execution_client

    def run(self, payload: dict) -> executions_v1.Execution:
        """
        Run a simulation using the v2 API

        Parameters:
        -----------
        payload : dict
            The payload to send to the API; it is serialised to JSON and
            passed as the workflow execution argument

        Returns:
        --------
        execution : executions_v1.Execution
            The execution object
        """
        # Fresh clients are created on every run, as before.
        self.execution_client = executions_v1.ExecutionsClient()
        self.workflows_client = workflows_v1.WorkflowsClient()
        json_input = json.dumps(payload)
        workflow_path = self.workflows_client.workflow_path(
            self.project, self.location, self.workflow
        )
        execution = self.execution_client.create_execution(
            parent=workflow_path,
            execution=executions_v1.Execution(argument=json_input),
        )
        return execution

    def get_execution_status(self, execution: executions_v1.Execution) -> str:
        """
        Get the status of an execution

        Parameters:
        -----------
        execution : executions_v1.Execution
            The execution object

        Returns:
        --------
        status : str
            The name of the execution's current state (e.g. "ACTIVE")
        """
        return (
            self._get_execution_client()
            .get_execution(name=execution.name)
            .state.name
        )

    def get_execution_result(
        self, execution: executions_v1.Execution
    ) -> dict | None:
        """
        Get the result of an execution

        Parameters:
        -----------
        execution : executions_v1.Execution
            The execution object

        Returns:
        --------
        result : dict | None
            The execution result parsed from JSON, or None when the result
            cannot be parsed as JSON
        """
        result = (
            self._get_execution_client()
            .get_execution(name=execution.name)
            .result
        )
        try:
            return json.loads(result)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError; an unparseable result is
            # reported to the caller as None.
            return None

    def wait_for_completion(
        self,
        execution: executions_v1.Execution,
        timeout: float | None = None,
        poll_interval: float = 5,
    ) -> dict | None:
        """
        Wait for an execution to complete

        Parameters:
        -----------
        execution : executions_v1.Execution
            The execution object
        timeout : float | None
            Maximum number of seconds to wait; None (the default) waits
            indefinitely
        poll_interval : float
            Seconds to sleep between status checks

        Returns:
        --------
        result : dict | None
            The execution result parsed from JSON, or None when the result
            cannot be parsed as JSON

        Raises:
        -------
        TimeoutError
            If ``timeout`` is given and the execution is still pending once
            it has elapsed
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while self.get_execution_status(execution) in self._PENDING_STATES:
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"APIv2 execution {execution.name} did not complete within {timeout} seconds."
                )
            time.sleep(poll_interval)
            print("Waiting for APIv2 job to complete...")

        return self.get_execution_result(execution)
0 commit comments