Skip to content

Commit 2c50fc0

Browse files
Move to Simulation API (#2262)
* Test with normal request * Versioning and scale down specs * Format * Move to Simulation API Fixes #2251 * Format * Add dep * Fix bug * Fix bug * Load JSON manually * Move to Simulation API Fixes #2251 * Format * Add comparison check * Add test * Update test * Add changes * Revert to originals * Format * Add safety check * Format * Revert changes to versioning * Move APIv2 interface to class * Allow float/int * Add changes to review * Format * Add type hinting * Make check against APIv2 * Add failsafe * Add extra failsafe * Use default creds * Format
1 parent 733753f commit 2c50fc0

7 files changed

Lines changed: 226 additions & 3 deletions

File tree

.github/workflows/pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ jobs:
7373
env:
7474
POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN: ${{ secrets.POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN }}
7575
HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }}
76-
POLICYENGINE_DB_PASSWORD: ${{ secrets.POLICYENGINE_DB_PASSWORD }}
76+
POLICYENGINE_DB_PASSWORD: ${{ secrets.POLICYENGINE_DB_PASSWORD }}
7777
test:
7878
name: Test
7979
runs-on: ubuntu-latest

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
install:
2-
pip install -e .[dev] --config-settings editable_mode=compat
2+
pip install -e ".[dev]" --config-settings editable_mode=compat
33

44
debug:
55
FLASK_APP=policyengine_api.api FLASK_DEBUG=1 flask run --without-threads

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: minor
2+
changes:
3+
changed:
4+
- Handle economy simulations in the Simulation API.

gcp/policyengine_api/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ ENV POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN .github_microdata_token
66
ENV ANTHROPIC_API_KEY .anthropic_api_key
77
ENV OPENAI_API_KEY .openai_api_key
88
ENV HUGGING_FACE_TOKEN .hugging_face_token
9+
ENV CREDENTIALS_JSON_API_V2 .credentials_json_api_v2
910

1011
WORKDIR /app
1112

policyengine_api/endpoints/economy/compare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ def uk_constituency_breakdown(
570570
reform_hnet = reform["household_net_income"]
571571

572572
constituency_weights_path = download_huggingface_dataset(
573-
repo="policyengine/policyengine-uk-data-public",
573+
repo="policyengine/policyengine-uk-data",
574574
repo_filename="parliamentary_constituency_weights.h5",
575575
)
576576
with h5py.File(constituency_weights_path, "r") as f:

policyengine_api/jobs/calculate_economy_simulation_job.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from typing import Type
88
import pandas as pd
99
import numpy as np
10+
from google.cloud import workflows_v1
11+
from google.cloud.workflows import executions_v1
12+
from typing import Tuple
1013

1114
from policyengine_api.jobs import BaseJob
1215
from policyengine_api.jobs.tasks import compute_general_economy
@@ -23,6 +26,7 @@
2326

2427
from policyengine_us import Microsimulation
2528
from policyengine_uk import Microsimulation
29+
import logging
2630

2731
reform_impacts_service = ReformImpactsService()
2832

@@ -33,10 +37,21 @@
3337
CPS = "hf://policyengine/policyengine-us-data/cps_2023.h5"
3438
POOLED_CPS = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5"
3539

40+
check_against_api_v2 = (
41+
os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") is not None
42+
)
43+
44+
if not check_against_api_v2:
45+
logging.warn(
46+
"Didn't find any GOOGLE_APPLICATION_CREDENTIALS, so will not check results for matches against APIv2."
47+
)
48+
3649

3750
class CalculateEconomySimulationJob(BaseJob):
3851
def __init__(self):
3952
super().__init__()
53+
if check_against_api_v2:
54+
self.api_v2 = SimulationAPIv2()
4055

4156
def run(
4257
self,
@@ -136,6 +151,17 @@ def run(
136151
comment = lambda x: set_comment_on_job(x, *identifiers)
137152
comment("Computing baseline")
138153

154+
# Kick off APIv2 job
155+
if check_against_api_v2:
156+
input_data = {
157+
"country": country_id,
158+
"scope": "macro",
159+
"reform": json.loads(reform_policy),
160+
"baseline": json.loads(baseline_policy),
161+
"time_period": time_period,
162+
}
163+
execution = self.api_v2.run(input_data)
164+
139165
# Compute baseline economy
140166
baseline_economy = self._compute_economy(
141167
country_id=country_id,
@@ -164,6 +190,19 @@ def run(
164190
baseline_economy, reform_economy, country_id=country_id
165191
)
166192

193+
# Wait for APIv2 job to complete
194+
if check_against_api_v2:
195+
result = self.api_v2.wait_for_completion(execution)
196+
if result is None:
197+
print("APIv2 COMPARISON failed: result is not JSON.")
198+
else:
199+
try:
200+
print(
201+
f"APIv2 COMPARISON: match={is_similar(result, json.loads(json.dumps(impact)))}"
202+
)
203+
except:
204+
print("APIv2 COMPARISON: ERROR COMPARING", result)
205+
167206
# Finally, update all reform impact rows with the same baseline and reform policy IDs
168207
reform_impacts_service.set_complete_reform_impact(
169208
country_id=country_id,
@@ -360,6 +399,9 @@ def _create_simulation_us(
360399
else:
361400
sim_options["dataset"] = df[state_code == region.upper()]
362401

402+
if dataset == "default" and region == "us":
403+
sim_options["dataset"] = CPS
404+
363405
# Return completed simulation
364406
return Microsimulation(**sim_options)
365407

@@ -419,3 +461,178 @@ def _compute_cliff_impacts(self, simulation: Microsimulation) -> Dict:
419461
"cliff_share": float(cliff_share),
420462
"type": "cliff",
421463
}
464+
465+
466+
def is_similar(x, y, parent_name: str = "") -> bool:
467+
# Handle None values
468+
if x is None or y is None:
469+
equal = x is y
470+
if not equal:
471+
print(f"Not equal: {x} vs {y} in {parent_name}")
472+
return equal
473+
474+
# Handle different types
475+
if type(x) != type(y):
476+
if float in ((type(x), type(y))) and int in ((type(x), type(y))):
477+
pass
478+
else:
479+
print(f"Different types: {type(x)} vs {type(y)} in {parent_name}")
480+
return False
481+
482+
# Handle numeric values
483+
if isinstance(x, (int, float)):
484+
if x == 0:
485+
close = y == 0
486+
else:
487+
close = (abs(y - x) / abs(x) < 0.01) or (abs(y - x) < 1e-2)
488+
if not close:
489+
print(f"Not close: {x} vs {y} in {parent_name}")
490+
return close
491+
492+
# Handle boolean values
493+
elif isinstance(x, bool):
494+
equal = x == y
495+
if not equal:
496+
print(f"Not equal: {x} vs {y} in {parent_name}")
497+
return equal
498+
499+
# Handle string values
500+
elif isinstance(x, str):
501+
equal = x == y
502+
if not equal:
503+
print(f"Not equal: {x} vs {y} in {parent_name}")
504+
return equal
505+
506+
# Handle dictionaries
507+
elif isinstance(x, dict):
508+
# Check for keys in both dictionaries
509+
all_keys = set(x.keys()) | set(y.keys())
510+
for k in all_keys:
511+
if k not in x:
512+
print(f"Key {k} missing in first dict in {parent_name}")
513+
return False
514+
if k not in y:
515+
print(f"Key {k} missing in second dict in {parent_name}")
516+
return False
517+
if not is_similar(x[k], y[k], parent_name=parent_name + "/" + k):
518+
return False
519+
return True
520+
521+
# Handle lists
522+
elif isinstance(x, list):
523+
if len(x) != len(y):
524+
print(f"Different lengths: {len(x)} vs {len(y)} in {parent_name}")
525+
return False
526+
return all(
527+
is_similar(x[i], y[i], parent_name=parent_name + f"[{i}]")
528+
for i in range(len(x))
529+
)
530+
531+
# Handle other types
532+
else:
533+
equal = x == y
534+
if not equal:
535+
print(f"Not equal: {x} vs {y} in {parent_name}")
536+
return equal
537+
538+
539+
class SimulationAPIv2:
540+
project: str
541+
location: str
542+
workflow: str
543+
544+
def __init__(self):
545+
self.project = "prod-api-v2-c4d5"
546+
self.location = "us-central1"
547+
self.workflow = "simulation-workflow"
548+
549+
def run(self, payload: dict) -> executions_v1.Execution:
550+
"""
551+
Run a simulation using the v2 API
552+
553+
Parameters:
554+
-----------
555+
payload : dict
556+
The payload to send to the API
557+
558+
Returns:
559+
--------
560+
execution : executions_v1.Execution
561+
The execution object
562+
"""
563+
self.execution_client = executions_v1.ExecutionsClient()
564+
self.workflows_client = workflows_v1.WorkflowsClient()
565+
json_input = json.dumps(payload)
566+
workflow_path = self.workflows_client.workflow_path(
567+
self.project, self.location, self.workflow
568+
)
569+
execution = self.execution_client.create_execution(
570+
parent=workflow_path,
571+
execution=executions_v1.Execution(argument=json_input),
572+
)
573+
return execution
574+
575+
def get_execution_status(self, execution: executions_v1.Execution) -> str:
576+
"""
577+
Get the status of an execution
578+
579+
Parameters:
580+
-----------
581+
execution : executions_v1.Execution
582+
The execution object
583+
584+
Returns:
585+
--------
586+
status : str
587+
The status of the execution
588+
"""
589+
return self.execution_client.get_execution(
590+
name=execution.name
591+
).state.name
592+
593+
def get_execution_result(
594+
self, execution: executions_v1.Execution
595+
) -> dict | None:
596+
"""
597+
Get the result of an execution
598+
599+
Parameters:
600+
-----------
601+
execution : executions_v1.Execution
602+
The execution object
603+
604+
Returns:
605+
--------
606+
result : str
607+
The result of the execution
608+
"""
609+
result = self.execution_client.get_execution(
610+
name=execution.name
611+
).result
612+
try:
613+
return json.loads(result)
614+
except:
615+
return None
616+
return result
617+
618+
def wait_for_completion(
619+
self, execution: executions_v1.Execution
620+
) -> dict | None:
621+
"""
622+
Wait for an execution to complete
623+
624+
Parameters:
625+
-----------
626+
execution : executions_v1.Execution
627+
The execution object
628+
629+
Returns:
630+
--------
631+
result : str
632+
The result of the execution
633+
"""
634+
while self.get_execution_status(execution) == "ACTIVE":
635+
time.sleep(5)
636+
print("Waiting for APIv2 job to complete...")
637+
638+
return self.get_execution_result(execution)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"assertpy",
1414
"click>=8,<9",
1515
"cloud-sql-python-connector",
16+
"google-cloud-workflows",
1617
"faiss-cpu<1.8.0",
1718
"flask>=3,<4",
1819
"flask-cors>=5,<6",

0 commit comments

Comments
 (0)