Skip to content

Commit 989e278

Browse files
committed
Add @Flow.model decorator, new annotation that pulls from deps
Signed-off-by: Nijat Khanbabayev <[email protected]>
1 parent 95119e3 commit 989e278

File tree

10 files changed

+3163
-2
lines changed

10 files changed

+3163
-2
lines changed

ccflow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .compose import *
1111
from .callable import *
1212
from .context import *
13+
from .dep import *
1314
from .enums import Enum
1415
from .global_state import *
1516
from .local_persistence import *

ccflow/callable.py

Lines changed: 214 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ResultBase,
2929
ResultType,
3030
)
31+
from .dep import Dep, extract_dep
3132
from .local_persistence import create_ccflow_model
3233
from .validators import str_to_log_level
3334

@@ -128,7 +129,7 @@ def _check_result_type(cls, result_type):
128129
@model_validator(mode="after")
129130
def _check_signature(self):
130131
sig_call = _cached_signature(self.__class__.__call__)
131-
if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: # ("self", "context")
132+
if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters:
132133
raise ValueError("__call__ method must take a single argument, named 'context'")
133134

134135
sig_deps = _cached_signature(self.__class__.__deps__)
@@ -195,6 +196,114 @@ def _get_logging_evaluator(log_level):
195196
return LoggingEvaluator(log_level=log_level)
196197

197198

199+
def _get_dep_fields(model_class) -> Dict[str, Dep]:
200+
"""Analyze class fields to find Dep-annotated fields.
201+
202+
Returns a dict mapping field name to Dep instance for fields that need resolution.
203+
"""
204+
dep_fields = {}
205+
206+
# Get type hints from the class
207+
hints = {}
208+
for cls in model_class.__mro__:
209+
if hasattr(cls, "__annotations__"):
210+
for name, annotation in cls.__annotations__.items():
211+
if name not in hints: # Don't override child class annotations
212+
hints[name] = annotation
213+
214+
for name, annotation in hints.items():
215+
base_type, dep = extract_dep(annotation)
216+
if dep is not None:
217+
dep_fields[name] = dep
218+
219+
return dep_fields
220+
221+
222+
def _wrap_with_dep_resolution(fn):
223+
"""Wrap a function to auto-resolve DepOf fields before calling.
224+
225+
For each Dep-annotated field on the model that contains a CallableModel,
226+
resolves it using __deps__ and temporarily sets the resolved value on self.
227+
228+
Note: This wrapper is only applied at runtime when the function is called,
229+
not during decoration. This avoids issues with functools.wraps flattening
230+
the __wrapped__ chain.
231+
232+
Args:
233+
fn: The original function
234+
235+
Returns:
236+
The original function unchanged - dep resolution happens at the call site
237+
"""
238+
# Don't modify the function - dep resolution is handled in ModelEvaluationContext
239+
return fn
240+
241+
242+
def _resolve_deps_and_call(model, context, fn):
243+
"""Resolve DepOf fields and call the function.
244+
245+
This is called from ModelEvaluationContext.__call__ to handle dep resolution.
246+
247+
Args:
248+
model: The CallableModel instance
249+
context: The context to pass to the function
250+
fn: The function to call
251+
252+
Returns:
253+
The result of calling fn(model, context)
254+
"""
255+
# Don't resolve deps for __deps__ method
256+
if fn.__name__ == "__deps__":
257+
return fn(model, context)
258+
259+
# Get Dep-annotated fields for this model class
260+
dep_fields = _get_dep_fields(model.__class__)
261+
262+
if not dep_fields:
263+
return fn(model, context)
264+
265+
# Get dependencies from __deps__
266+
deps_result = model.__deps__(context)
267+
# Build a map from model instance id to (model, contexts) for lookup
268+
dep_map = {}
269+
for dep_model, contexts in deps_result:
270+
dep_map[id(dep_model)] = (dep_model, contexts)
271+
272+
# Store original values and resolve
273+
originals = {}
274+
for field_name, dep in dep_fields.items():
275+
field_value = getattr(model, field_name, None)
276+
if field_value is None:
277+
continue
278+
279+
# Check if field is a CallableModel that needs resolution
280+
if not isinstance(field_value, _CallableModel):
281+
continue # Already a resolved value, skip
282+
283+
originals[field_name] = field_value
284+
285+
# Check if this field is in __deps__ (for custom transforms)
286+
if id(field_value) in dep_map:
287+
dep_model, contexts = dep_map[id(field_value)]
288+
# Call dependency with the (transformed) context
289+
resolved = dep_model(contexts[0]) if contexts else dep_model(context)
290+
else:
291+
# Not in __deps__, use Dep annotation transform directly
292+
transformed_ctx = dep.apply(context)
293+
resolved = field_value(transformed_ctx)
294+
295+
# Temporarily set resolved value on model
296+
object.__setattr__(model, field_name, resolved)
297+
298+
try:
299+
# Call original function
300+
return fn(model, context)
301+
finally:
302+
# Restore original CallableModel values
303+
for field_name, original_value in originals.items():
304+
object.__setattr__(model, field_name, original_value)
305+
306+
198307
class FlowOptions(BaseModel):
199308
"""Options for Flow evaluation.
200309
@@ -246,6 +355,9 @@ def get_evaluator(self, model: CallableModelType) -> "EvaluatorBase":
246355
return self._get_evaluator_from_options(options)
247356

248357
def __call__(self, fn):
358+
# Wrap function with dependency resolution for DepOf fields
359+
fn = _wrap_with_dep_resolution(fn)
360+
249361
# Used for building a graph of model evaluation contexts without evaluating
250362
def get_evaluation_context(model: CallableModelType, context: ContextType, as_dict: bool = False, *, _options: Optional[FlowOptions] = None):
251363
# Create the evaluation context.
@@ -451,6 +563,33 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult:
451563
452564
# The generated context inherits from DateContext, so it's compatible
453565
# with infrastructure expecting DateContext instances.
566+
567+
Auto-Resolve Dependencies Example:
568+
When __call__ has parameters beyond 'self' and 'context' that match field
569+
names annotated with DepOf/Dep, those dependencies are automatically resolved
570+
using __deps__ (if defined) or auto-generated from Dep annotations.
571+
572+
class MyModel(CallableModel):
573+
data: Annotated[GenericResult[dict], Dep(transform=my_transform)]
574+
575+
@Flow.call
576+
def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]:
577+
# data is automatically resolved - no manual calling needed
578+
return GenericResult(value=process(data.value))
579+
580+
For transforms that need access to instance fields, define __deps__ manually:
581+
582+
class MyModel(CallableModel):
583+
data: DepOf[..., GenericResult[dict]]
584+
window: int = 7
585+
586+
def __deps__(self, context):
587+
# Can access self.window here
588+
return [(self.data, [context.with_lookback(self.window)])]
589+
590+
@Flow.call
591+
def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]:
592+
return GenericResult(value=process(data.value))
454593
"""
455594
# Extract auto_context option (not part of FlowOptions)
456595
# Can be: False, True, or a ContextBase subclass
@@ -502,6 +641,78 @@ def deps(*args, **kwargs):
502641
# Note that the code below is executed only once
503642
return FlowOptionsDeps(**kwargs)
504643

644+
@staticmethod
645+
def model(*args, **kwargs):
646+
"""Decorator that generates a CallableModel class from a plain Python function.
647+
648+
This is syntactic sugar over CallableModel. The decorator generates a real
649+
CallableModel class with proper __call__ and __deps__ methods, so all existing
650+
features (caching, evaluation, registry, serialization) work unchanged.
651+
652+
Args:
653+
context_args: List of parameter names that come from context (for unpacked mode)
654+
cacheable: Enable caching of results (default: False)
655+
volatile: Mark as volatile (default: False)
656+
log_level: Logging verbosity (default: logging.DEBUG)
657+
validate_result: Validate return type (default: True)
658+
verbose: Verbose logging output (default: True)
659+
evaluator: Custom evaluator (default: None)
660+
661+
Two Context Modes:
662+
663+
Mode 1 - Explicit context parameter:
664+
Function has a 'context' parameter annotated with a ContextBase subclass.
665+
666+
@Flow.model
667+
def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]:
668+
return GenericResult(value=query_db(source, context.start_date, context.end_date))
669+
670+
Mode 2 - Unpacked context_args:
671+
Context fields are unpacked into function parameters.
672+
673+
@Flow.model(context_args=["start_date", "end_date"])
674+
def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]:
675+
return GenericResult(value=query_db(source, start_date, end_date))
676+
677+
Dependencies:
678+
Use Dep() or DepOf to mark parameters that can accept CallableModel dependencies:
679+
680+
from ccflow import Dep, DepOf
681+
from typing import Annotated
682+
683+
@Flow.model
684+
def compute_returns(
685+
context: DateRangeContext,
686+
prices: Annotated[GenericResult[pl.DataFrame], Dep(
687+
transform=lambda ctx: ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)})
688+
)]
689+
) -> GenericResult[pl.DataFrame]:
690+
return GenericResult(value=prices.value.pct_change())
691+
692+
# Or use DepOf shorthand for no transform:
693+
@Flow.model
694+
def compute_stats(
695+
context: DateRangeContext,
696+
data: DepOf[..., GenericResult[pl.DataFrame]]
697+
) -> GenericResult[pl.DataFrame]:
698+
return GenericResult(value=data.value.describe())
699+
700+
Usage:
701+
# Create model instances
702+
loader = load_prices(source="prod_db")
703+
returns = compute_returns(prices=loader)
704+
705+
# Execute
706+
ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31))
707+
result = returns(ctx)
708+
709+
Returns:
710+
A factory function that creates CallableModel instances
711+
"""
712+
from .flow_model import flow_model
713+
714+
return flow_model(*args, **kwargs)
715+
505716

506717
# *****************************************************************************
507718
# Define "Evaluators" and associated types
@@ -555,7 +766,8 @@ def _context_validator(cls, values, handler, info):
555766
def __call__(self) -> ResultType:
556767
fn = getattr(self.model, self.fn)
557768
if hasattr(fn, "__wrapped__"):
558-
result = fn.__wrapped__(self.model, self.context)
769+
# Call through _resolve_deps_and_call to handle DepOf field resolution
770+
result = _resolve_deps_and_call(self.model, self.context, fn.__wrapped__)
559771
# If it's a callable model, then we can validate the result
560772
if self.options.get("validate_result", True):
561773
if fn.__name__ == "__deps__":

0 commit comments

Comments
 (0)