|
28 | 28 | ResultBase, |
29 | 29 | ResultType, |
30 | 30 | ) |
| 31 | +from .dep import Dep, extract_dep |
31 | 32 | from .local_persistence import create_ccflow_model |
32 | 33 | from .validators import str_to_log_level |
33 | 34 |
|
@@ -128,7 +129,7 @@ def _check_result_type(cls, result_type): |
128 | 129 | @model_validator(mode="after") |
129 | 130 | def _check_signature(self): |
130 | 131 | sig_call = _cached_signature(self.__class__.__call__) |
131 | | - if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: # ("self", "context") |
| 132 | + if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: |
132 | 133 | raise ValueError("__call__ method must take a single argument, named 'context'") |
133 | 134 |
|
134 | 135 | sig_deps = _cached_signature(self.__class__.__deps__) |
@@ -195,6 +196,114 @@ def _get_logging_evaluator(log_level): |
195 | 196 | return LoggingEvaluator(log_level=log_level) |
196 | 197 |
|
197 | 198 |
|
| 199 | +def _get_dep_fields(model_class) -> Dict[str, Dep]: |
| 200 | + """Analyze class fields to find Dep-annotated fields. |
| 201 | +
|
| 202 | + Returns a dict mapping field name to Dep instance for fields that need resolution. |
| 203 | + """ |
| 204 | + dep_fields = {} |
| 205 | + |
| 206 | + # Get type hints from the class |
| 207 | + hints = {} |
| 208 | + for cls in model_class.__mro__: |
| 209 | + if hasattr(cls, "__annotations__"): |
| 210 | + for name, annotation in cls.__annotations__.items(): |
| 211 | + if name not in hints: # Don't override child class annotations |
| 212 | + hints[name] = annotation |
| 213 | + |
| 214 | + for name, annotation in hints.items(): |
| 215 | + base_type, dep = extract_dep(annotation) |
| 216 | + if dep is not None: |
| 217 | + dep_fields[name] = dep |
| 218 | + |
| 219 | + return dep_fields |
| 220 | + |
| 221 | + |
| 222 | +def _wrap_with_dep_resolution(fn): |
| 223 | + """Wrap a function to auto-resolve DepOf fields before calling. |
| 224 | +
|
| 225 | + For each Dep-annotated field on the model that contains a CallableModel, |
| 226 | + resolves it using __deps__ and temporarily sets the resolved value on self. |
| 227 | +
|
| 228 | + Note: This wrapper is only applied at runtime when the function is called, |
| 229 | + not during decoration. This avoids issues with functools.wraps flattening |
| 230 | + the __wrapped__ chain. |
| 231 | +
|
| 232 | + Args: |
| 233 | + fn: The original function |
| 234 | +
|
| 235 | + Returns: |
| 236 | + The original function unchanged - dep resolution happens at the call site |
| 237 | + """ |
| 238 | + # Don't modify the function - dep resolution is handled in ModelEvaluationContext |
| 239 | + return fn |
| 240 | + |
| 241 | + |
| 242 | +def _resolve_deps_and_call(model, context, fn): |
| 243 | + """Resolve DepOf fields and call the function. |
| 244 | +
|
| 245 | + This is called from ModelEvaluationContext.__call__ to handle dep resolution. |
| 246 | +
|
| 247 | + Args: |
| 248 | + model: The CallableModel instance |
| 249 | + context: The context to pass to the function |
| 250 | + fn: The function to call |
| 251 | +
|
| 252 | + Returns: |
| 253 | + The result of calling fn(model, context) |
| 254 | + """ |
| 255 | + # Don't resolve deps for __deps__ method |
| 256 | + if fn.__name__ == "__deps__": |
| 257 | + return fn(model, context) |
| 258 | + |
| 259 | + # Get Dep-annotated fields for this model class |
| 260 | + dep_fields = _get_dep_fields(model.__class__) |
| 261 | + |
| 262 | + if not dep_fields: |
| 263 | + return fn(model, context) |
| 264 | + |
| 265 | + # Get dependencies from __deps__ |
| 266 | + deps_result = model.__deps__(context) |
| 267 | + # Build a map from model instance id to (model, contexts) for lookup |
| 268 | + dep_map = {} |
| 269 | + for dep_model, contexts in deps_result: |
| 270 | + dep_map[id(dep_model)] = (dep_model, contexts) |
| 271 | + |
| 272 | + # Store original values and resolve |
| 273 | + originals = {} |
| 274 | + for field_name, dep in dep_fields.items(): |
| 275 | + field_value = getattr(model, field_name, None) |
| 276 | + if field_value is None: |
| 277 | + continue |
| 278 | + |
| 279 | + # Check if field is a CallableModel that needs resolution |
| 280 | + if not isinstance(field_value, _CallableModel): |
| 281 | + continue # Already a resolved value, skip |
| 282 | + |
| 283 | + originals[field_name] = field_value |
| 284 | + |
| 285 | + # Check if this field is in __deps__ (for custom transforms) |
| 286 | + if id(field_value) in dep_map: |
| 287 | + dep_model, contexts = dep_map[id(field_value)] |
| 288 | + # Call dependency with the (transformed) context |
| 289 | + resolved = dep_model(contexts[0]) if contexts else dep_model(context) |
| 290 | + else: |
| 291 | + # Not in __deps__, use Dep annotation transform directly |
| 292 | + transformed_ctx = dep.apply(context) |
| 293 | + resolved = field_value(transformed_ctx) |
| 294 | + |
| 295 | + # Temporarily set resolved value on model |
| 296 | + object.__setattr__(model, field_name, resolved) |
| 297 | + |
| 298 | + try: |
| 299 | + # Call original function |
| 300 | + return fn(model, context) |
| 301 | + finally: |
| 302 | + # Restore original CallableModel values |
| 303 | + for field_name, original_value in originals.items(): |
| 304 | + object.__setattr__(model, field_name, original_value) |
| 305 | + |
| 306 | + |
198 | 307 | class FlowOptions(BaseModel): |
199 | 308 | """Options for Flow evaluation. |
200 | 309 |
|
@@ -246,6 +355,9 @@ def get_evaluator(self, model: CallableModelType) -> "EvaluatorBase": |
246 | 355 | return self._get_evaluator_from_options(options) |
247 | 356 |
|
248 | 357 | def __call__(self, fn): |
| 358 | + # Wrap function with dependency resolution for DepOf fields |
| 359 | + fn = _wrap_with_dep_resolution(fn) |
| 360 | + |
249 | 361 | # Used for building a graph of model evaluation contexts without evaluating |
250 | 362 | def get_evaluation_context(model: CallableModelType, context: ContextType, as_dict: bool = False, *, _options: Optional[FlowOptions] = None): |
251 | 363 | # Create the evaluation context. |
@@ -451,6 +563,33 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult: |
451 | 563 |
|
452 | 564 | # The generated context inherits from DateContext, so it's compatible |
453 | 565 | # with infrastructure expecting DateContext instances. |
| 566 | +
|
| 567 | + Auto-Resolve Dependencies Example: |
| 568 | + When __call__ has parameters beyond 'self' and 'context' that match field |
| 569 | + names annotated with DepOf/Dep, those dependencies are automatically resolved |
| 570 | + using __deps__ (if defined) or auto-generated from Dep annotations. |
| 571 | +
|
| 572 | + class MyModel(CallableModel): |
| 573 | + data: Annotated[GenericResult[dict], Dep(transform=my_transform)] |
| 574 | +
|
| 575 | + @Flow.call |
| 576 | + def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: |
| 577 | + # data is automatically resolved - no manual calling needed |
| 578 | + return GenericResult(value=process(data.value)) |
| 579 | +
|
| 580 | + For transforms that need access to instance fields, define __deps__ manually: |
| 581 | +
|
| 582 | + class MyModel(CallableModel): |
| 583 | + data: DepOf[..., GenericResult[dict]] |
| 584 | + window: int = 7 |
| 585 | +
|
| 586 | + def __deps__(self, context): |
| 587 | + # Can access self.window here |
| 588 | + return [(self.data, [context.with_lookback(self.window)])] |
| 589 | +
|
| 590 | + @Flow.call |
| 591 | + def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: |
| 592 | + return GenericResult(value=process(data.value)) |
454 | 593 | """ |
455 | 594 | # Extract auto_context option (not part of FlowOptions) |
456 | 595 | # Can be: False, True, or a ContextBase subclass |
@@ -502,6 +641,78 @@ def deps(*args, **kwargs): |
502 | 641 | # Note that the code below is executed only once |
503 | 642 | return FlowOptionsDeps(**kwargs) |
504 | 643 |
|
| 644 | + @staticmethod |
| 645 | + def model(*args, **kwargs): |
| 646 | + """Decorator that generates a CallableModel class from a plain Python function. |
| 647 | +
|
| 648 | + This is syntactic sugar over CallableModel. The decorator generates a real |
| 649 | + CallableModel class with proper __call__ and __deps__ methods, so all existing |
| 650 | + features (caching, evaluation, registry, serialization) work unchanged. |
| 651 | +
|
| 652 | + Args: |
| 653 | + context_args: List of parameter names that come from context (for unpacked mode) |
| 654 | + cacheable: Enable caching of results (default: False) |
| 655 | + volatile: Mark as volatile (default: False) |
| 656 | + log_level: Logging verbosity (default: logging.DEBUG) |
| 657 | + validate_result: Validate return type (default: True) |
| 658 | + verbose: Verbose logging output (default: True) |
| 659 | + evaluator: Custom evaluator (default: None) |
| 660 | +
|
| 661 | + Two Context Modes: |
| 662 | +
|
| 663 | + Mode 1 - Explicit context parameter: |
| 664 | + Function has a 'context' parameter annotated with a ContextBase subclass. |
| 665 | +
|
| 666 | + @Flow.model |
| 667 | + def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: |
| 668 | + return GenericResult(value=query_db(source, context.start_date, context.end_date)) |
| 669 | +
|
| 670 | + Mode 2 - Unpacked context_args: |
| 671 | + Context fields are unpacked into function parameters. |
| 672 | +
|
| 673 | + @Flow.model(context_args=["start_date", "end_date"]) |
| 674 | + def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: |
| 675 | + return GenericResult(value=query_db(source, start_date, end_date)) |
| 676 | +
|
| 677 | + Dependencies: |
| 678 | + Use Dep() or DepOf to mark parameters that can accept CallableModel dependencies: |
| 679 | +
|
| 680 | + from ccflow import Dep, DepOf |
| 681 | + from typing import Annotated |
| 682 | +
|
| 683 | + @Flow.model |
| 684 | + def compute_returns( |
| 685 | + context: DateRangeContext, |
| 686 | + prices: Annotated[GenericResult[pl.DataFrame], Dep( |
| 687 | + transform=lambda ctx: ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) |
| 688 | + )] |
| 689 | + ) -> GenericResult[pl.DataFrame]: |
| 690 | + return GenericResult(value=prices.value.pct_change()) |
| 691 | +
|
| 692 | + # Or use DepOf shorthand for no transform: |
| 693 | + @Flow.model |
| 694 | + def compute_stats( |
| 695 | + context: DateRangeContext, |
| 696 | + data: DepOf[..., GenericResult[pl.DataFrame]] |
| 697 | + ) -> GenericResult[pl.DataFrame]: |
| 698 | + return GenericResult(value=data.value.describe()) |
| 699 | +
|
| 700 | + Usage: |
| 701 | + # Create model instances |
| 702 | + loader = load_prices(source="prod_db") |
| 703 | + returns = compute_returns(prices=loader) |
| 704 | +
|
| 705 | + # Execute |
| 706 | + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) |
| 707 | + result = returns(ctx) |
| 708 | +
|
| 709 | + Returns: |
| 710 | + A factory function that creates CallableModel instances |
| 711 | + """ |
| 712 | + from .flow_model import flow_model |
| 713 | + |
| 714 | + return flow_model(*args, **kwargs) |
| 715 | + |
505 | 716 |
|
506 | 717 | # ***************************************************************************** |
507 | 718 | # Define "Evaluators" and associated types |
@@ -555,7 +766,8 @@ def _context_validator(cls, values, handler, info): |
555 | 766 | def __call__(self) -> ResultType: |
556 | 767 | fn = getattr(self.model, self.fn) |
557 | 768 | if hasattr(fn, "__wrapped__"): |
558 | | - result = fn.__wrapped__(self.model, self.context) |
| 769 | + # Call through _resolve_deps_and_call to handle DepOf field resolution |
| 770 | + result = _resolve_deps_and_call(self.model, self.context, fn.__wrapped__) |
559 | 771 | # If it's a callable model, then we can validate the result |
560 | 772 | if self.options.get("validate_result", True): |
561 | 773 | if fn.__name__ == "__deps__": |
|
0 commit comments