Skip to content

Commit 7ab2267

Browse files
committed
fix parent id for operations in virtual context
1 parent 9dd8b9e commit 7ab2267

9 files changed

Lines changed: 64 additions & 56 deletions

File tree

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from aws_durable_execution_sdk_python.config import (
2424
ChildConfig,
25-
CheckpointMode,
2625
NestingType,
2726
)
2827
from aws_durable_execution_sdk_python.exceptions import (
@@ -138,6 +137,7 @@ class ConcurrentExecutor(ABC, Generic[CallableType, ResultType]):
138137

139138
def __init__(
140139
self,
140+
operation_identifier: OperationIdentifier,
141141
executables: list[Executable[CallableType]],
142142
max_concurrency: int | None,
143143
completion_config: CompletionConfig,
@@ -158,6 +158,7 @@ def __init__(
158158
handle large BatchResult payloads efficiently. Matches TypeScript behavior in
159159
run-in-child-context-handler.ts.
160160
"""
161+
self.operation_identifier = operation_identifier
161162
self.executables = executables
162163
self.max_concurrency = max_concurrency
163164
self.completion_config = completion_config
@@ -412,17 +413,13 @@ def _execute_item_in_child_context(
412413
executable.index
413414
)
414415
name = f"{self.name_prefix}{executable.index}"
415-
child_context = executor_context.create_child_context(operation_id)
416+
non_virtual_parent_id = self.operation_identifier.operation_id if self.nesting_type is NestingType.FLAT else None
417+
child_context = executor_context.create_child_context(operation_id, non_virtual_parent_id)
416418
operation_identifier = OperationIdentifier(
417419
operation_id,
418420
executor_context._parent_id, # noqa: SLF001
419421
name,
420422
)
421-
checkpoint_mode = (
422-
CheckpointMode.NO_CHECKPOINT
423-
if self.nesting_type == NestingType.FLAT
424-
else CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
425-
)
426423

427424
def run_in_child_handler():
428425
return self.execute_item(child_context, executable)
@@ -435,7 +432,7 @@ def run_in_child_handler():
435432
serdes=self.item_serdes or self.serdes,
436433
sub_type=self.sub_type_iteration,
437434
summary_generator=self.summary_generator,
438-
checkpoint_mode=checkpoint_mode,
435+
is_virtual=self.nesting_type is NestingType.FLAT,
439436
),
440437
)
441438
child_context.state.track_replay(operation_id=operation_id)

src/aws_durable_execution_sdk_python/config.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,6 @@ class StepConfig:
235235
serdes: SerDes | None = None
236236

237237

238-
class CheckpointMode(Enum):
239-
NO_CHECKPOINT = ("NO_CHECKPOINT",)
240-
CHECKPOINT_AT_START_AND_FINISH = "CHECKPOINT_AT_START_AND_FINISH"
241-
242-
243238
@dataclass(frozen=True)
244239
class ChildConfig(Generic[T]):
245240
"""Configuration options for child context operations.
@@ -276,19 +271,18 @@ class ChildConfig(Generic[T]):
276271
Used internally by map/parallel operations to handle large BatchResult payloads.
277272
Signature: (result: T) -> str
278273
279-
checkpoint_mode: controls when checkpoints are created
280-
- CHECKPOINT_AT_START_AND_FINISH: Checkpoint at both start and completion (default)
281-
- CHECKPOINT_AT_FINISH: Only checkpoint when operation completes (not implemented)
282-
- NO_CHECKPOINT: No automatic checkpointing
274+
is_virtual: Whether the child operation is virtual (doesn't represent a real operation).
275+
Virtual contexts are used for concurrency operations and don't appear in
276+
the final execution history. Default is False.
283277
284278
See TypeScript reference: aws-durable-execution-sdk-js/src/types/index.ts
285279
"""
286280

287-
checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
288281
serdes: SerDes | None = None
289282
item_serdes: SerDes | None = None
290283
sub_type: OperationSubType | None = None
291284
summary_generator: SummaryGenerator | None = None
285+
is_virtual: bool = False
292286

293287

294288
class ItemsPerBatchUnit(Enum):

src/aws_durable_execution_sdk_python/context.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,14 @@ def __init__(
237237
lambda_context: LambdaContext | None = None,
238238
parent_id: str | None = None,
239239
logger: Logger | None = None,
240+
non_virtual_parent_id: str | None = None,
240241
) -> None:
241242
self.state: ExecutionState = state
242243
self.execution_context: ExecutionContext = execution_context
243244
self.lambda_context = lambda_context
244245
self._parent_id: str | None = parent_id
245246
self._step_counter: OrderedCounter = OrderedCounter()
247+
self._non_virtual_parent_id = non_virtual_parent_id or parent_id
246248

247249
log_info = LogInfo(
248250
execution_state=state,
@@ -269,7 +271,7 @@ def from_lambda_context(
269271
parent_id=None,
270272
)
271273

272-
def create_child_context(self, parent_id: str) -> DurableContext:
274+
def create_child_context(self, parent_id: str, non_virtual_parent_id=None) -> DurableContext:
273275
"""Create a child context from the given parent."""
274276
logger.debug("Creating child context for parent %s", parent_id)
275277
return DurableContext(
@@ -283,6 +285,7 @@ def create_child_context(self, parent_id: str) -> DurableContext:
283285
parent_id=parent_id,
284286
)
285287
),
288+
non_virtual_parent_id=non_virtual_parent_id,
286289
)
287290

288291
# endregion factories
@@ -347,7 +350,7 @@ def create_callback(
347350
executor: CallbackOperationExecutor = CallbackOperationExecutor(
348351
state=self.state,
349352
operation_identifier=OperationIdentifier(
350-
operation_id=operation_id, parent_id=self._parent_id, name=name
353+
operation_id=operation_id, parent_id=self._non_virtual_parent_id, name=name
351354
),
352355
config=config,
353356
)
@@ -388,7 +391,7 @@ def invoke(
388391
state=self.state,
389392
operation_identifier=OperationIdentifier(
390393
operation_id=operation_id,
391-
parent_id=self._parent_id,
394+
parent_id=self._non_virtual_parent_id,
392395
name=name,
393396
),
394397
config=config,
@@ -409,7 +412,7 @@ def map(
409412

410413
operation_id = self._create_step_id()
411414
operation_identifier = OperationIdentifier(
412-
operation_id=operation_id, parent_id=self._parent_id, name=map_name
415+
operation_id=operation_id, parent_id=self._non_virtual_parent_id, name=map_name
413416
)
414417
map_context = self.create_child_context(parent_id=operation_id)
415418

@@ -454,7 +457,7 @@ def parallel(
454457
operation_id = self._create_step_id()
455458
parallel_context = self.create_child_context(parent_id=operation_id)
456459
operation_identifier = OperationIdentifier(
457-
operation_id=operation_id, parent_id=self._parent_id, name=name
460+
operation_id=operation_id, parent_id=self._non_virtual_parent_id, name=name
458461
)
459462

460463
def parallel_in_child_context() -> BatchResult[T]:
@@ -515,7 +518,7 @@ def callable_with_child_context():
515518
func=callable_with_child_context,
516519
state=self.state,
517520
operation_identifier=OperationIdentifier(
518-
operation_id=operation_id, parent_id=self._parent_id, name=step_name
521+
operation_id=operation_id, parent_id=self._non_virtual_parent_id, name=step_name
519522
),
520523
config=config,
521524
)
@@ -539,7 +542,7 @@ def step(
539542
state=self.state,
540543
operation_identifier=OperationIdentifier(
541544
operation_id=operation_id,
542-
parent_id=self._parent_id,
545+
parent_id=self._non_virtual_parent_id,
543546
name=step_name,
544547
),
545548
context_logger=self.logger,
@@ -566,7 +569,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
566569
state=self.state,
567570
operation_identifier=OperationIdentifier(
568571
operation_id=operation_id,
569-
parent_id=self._parent_id,
572+
parent_id=self._non_virtual_parent_id,
570573
name=name,
571574
),
572575
)
@@ -621,7 +624,7 @@ def wait_for_condition(
621624
state=self.state,
622625
operation_identifier=OperationIdentifier(
623626
operation_id=operation_id,
624-
parent_id=self._parent_id,
627+
parent_id=self._non_virtual_parent_id,
625628
name=name,
626629
),
627630
context_logger=self.logger,

src/aws_durable_execution_sdk_python/operation/child.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
from typing import TYPE_CHECKING, TypeVar
77

8-
from aws_durable_execution_sdk_python.config import ChildConfig, CheckpointMode
8+
from aws_durable_execution_sdk_python.config import ChildConfig
99
from aws_durable_execution_sdk_python.exceptions import (
1010
InvocationError,
1111
SuspendExecution,
@@ -118,11 +118,7 @@ def check_result_status(self) -> CheckResult[T]:
118118
checkpointed_result.raise_callable_error()
119119

120120
# Create START checkpoint if not exists
121-
if (
122-
not checkpointed_result.is_existent()
123-
and self.config.checkpoint_mode
124-
== CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
125-
):
121+
if not checkpointed_result.is_existent() and not self.config.is_virtual:
126122
start_operation: OperationUpdate = OperationUpdate.create_context_start(
127123
identifier=self.operation_identifier,
128124
sub_type=self.sub_type,
@@ -161,6 +157,15 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T:
161157
try:
162158
raw_result: T = self.func()
163159

160+
161+
if self.config.is_virtual:
162+
logger.debug(
163+
"Virtual context: Exiting child context without creating another checkpoint. id: %s, name: %s",
164+
self.operation_identifier.operation_id,
165+
self.operation_identifier.name,
166+
)
167+
return raw_result
168+
164169
# If in replay_children mode, return without checkpointing
165170
if checkpointed_result.is_replay_children():
166171
logger.debug(
@@ -207,21 +212,20 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T:
207212
else ""
208213
)
209214

210-
if self.config.checkpoint_mode != CheckpointMode.NO_CHECKPOINT:
211-
# Checkpoint SUCCEED
212-
success_operation: OperationUpdate = (
213-
OperationUpdate.create_context_succeed(
214-
identifier=self.operation_identifier,
215-
payload=serialized_result,
216-
sub_type=self.sub_type,
217-
context_options=ContextOptions(replay_children=replay_children),
218-
)
215+
# Checkpoint SUCCEED
216+
success_operation: OperationUpdate = (
217+
OperationUpdate.create_context_succeed(
218+
identifier=self.operation_identifier,
219+
payload=serialized_result,
220+
sub_type=self.sub_type,
221+
context_options=ContextOptions(replay_children=replay_children),
219222
)
220-
# Checkpoint child context SUCCEED with blocking (is_sync=True, default).
221-
# Must ensure the child context result is persisted before returning to the parent.
222-
# This guarantees the result is durable and child operations won't be re-executed on replay
223-
# (unless replay_children=True for large payloads).
224-
self.state.create_checkpoint(operation_update=success_operation)
223+
)
224+
# Checkpoint child context SUCCEED with blocking (is_sync=True, default).
225+
# Must ensure the child context result is persisted before returning to the parent.
226+
# This guarantees the result is durable and child operations won't be re-executed on replay
227+
# (unless replay_children=True for large payloads).
228+
self.state.create_checkpoint(operation_update=success_operation)
225229

226230
logger.debug(
227231
"✅ Successfully completed child context for id: %s, name: %s",

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
class MapExecutor(Generic[T, R], ConcurrentExecutor[Callable, R]): # noqa: PYI059
3737
def __init__(
3838
self,
39+
operation_identifier: OperationIdentifier,
3940
executables: list[Executable[Callable]],
4041
items: Sequence[T],
4142
max_concurrency: int | None,
@@ -49,6 +50,7 @@ def __init__(
4950
nesting_type: NestingType = NestingType.NESTED,
5051
):
5152
super().__init__(
53+
operation_identifier=operation_identifier,
5254
executables=executables,
5355
max_concurrency=max_concurrency,
5456
completion_config=completion_config,
@@ -65,6 +67,7 @@ def __init__(
6567
@classmethod
6668
def from_items(
6769
cls,
70+
operation_identifier: OperationIdentifier,
6871
items: Sequence[T],
6972
func: Callable,
7073
config: MapConfig,
@@ -75,6 +78,7 @@ def from_items(
7578
]
7679

7780
return cls(
81+
operation_identifier=operation_identifier,
7882
executables=executables,
7983
items=items,
8084
max_concurrency=config.max_concurrency,
@@ -112,6 +116,7 @@ def map_handler(
112116
# See TypeScript reference: aws-durable-execution-sdk-js/src/handlers/map-handler/map-handler.ts (~line 79)
113117

114118
executor: MapExecutor[T, R] = MapExecutor.from_items(
119+
operation_identifier=operation_identifier,
115120
items=items,
116121
func=func,
117122
config=config or MapConfig(summary_generator=MapSummaryGenerator()),

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
class ParallelExecutor(ConcurrentExecutor[Callable, R]):
3030
def __init__(
3131
self,
32+
operation_identifier: OperationIdentifier,
3233
executables: list[Executable[Callable]],
3334
max_concurrency: int | None,
3435
completion_config,
@@ -41,6 +42,7 @@ def __init__(
4142
nesting_type: NestingType = NestingType.NESTED,
4243
):
4344
super().__init__(
45+
operation_identifier=operation_identifier,
4446
executables=executables,
4547
max_concurrency=max_concurrency,
4648
completion_config=completion_config,
@@ -56,6 +58,7 @@ def __init__(
5658
@classmethod
5759
def from_callables(
5860
cls,
61+
operation_identifier: OperationIdentifier,
5962
callables: Sequence[Callable],
6063
config: ParallelConfig,
6164
) -> ParallelExecutor:
@@ -64,6 +67,7 @@ def from_callables(
6467
Executable(index=i, func=func) for i, func in enumerate(callables)
6568
]
6669
return cls(
70+
operation_identifier=operation_identifier,
6771
executables=executables,
6872
max_concurrency=config.max_concurrency,
6973
completion_config=config.completion_config,
@@ -98,6 +102,7 @@ def parallel_handler(
98102
# See TypeScript reference: aws-durable-execution-sdk-js/src/handlers/parallel-handler/parallel-handler.ts (~line 112)
99103

100104
executor = ParallelExecutor.from_callables(
105+
operation_identifier,
101106
callables,
102107
config or ParallelConfig(summary_generator=ParallelSummaryGenerator()),
103108
)

tests/concurrency_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def execute_item(self, child_context, executable):
881881
serdes=None,
882882
nesting_type=NestingType.NESTED,
883883
)
884-
assert executor_nested.nesting_type == NestingType.NESTED
884+
assert executor_nested.nesting_type is NestingType.NESTED
885885

886886
# Test with FLAT
887887
executor_flat = TestExecutor(
@@ -894,7 +894,7 @@ def execute_item(self, child_context, executable):
894894
serdes=None,
895895
nesting_type=NestingType.FLAT,
896896
)
897-
assert executor_flat.nesting_type == NestingType.FLAT
897+
assert executor_flat.nesting_type is NestingType.FLAT
898898

899899

900900
def test_concurrent_executor_default_nesting_type():
@@ -916,7 +916,7 @@ def execute_item(self, child_context, executable):
916916
name_prefix="test_",
917917
serdes=None,
918918
)
919-
assert executor.nesting_type == NestingType.NESTED
919+
assert executor.nesting_type is NestingType.NESTED
920920

921921

922922
def test_concurrent_executor_full_execution_path():

tests/operation/map_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_map_executor_init():
6767

6868
assert executor.items == items
6969
assert executor.executables == executables
70-
assert executor.nesting_type == NestingType.FLAT
70+
assert executor.nesting_type is NestingType.FLAT
7171

7272

7373
def test_map_executor_from_items():
@@ -85,7 +85,7 @@ def callable_func(ctx, item, idx, items):
8585
assert executor.items == items
8686
assert all(exe.func == callable_func for exe in executor.executables)
8787
assert [exe.index for exe in executor.executables] == [0, 1, 2]
88-
assert executor.nesting_type == NestingType.FLAT
88+
assert executor.nesting_type is NestingType.FLAT
8989

9090

9191
def test_map_executor_from_items_default_config():
@@ -99,7 +99,7 @@ def callable_func(ctx, item, idx, items):
9999

100100
assert len(executor.executables) == 1
101101
assert executor.items == items
102-
assert executor.nesting_type == NestingType.NESTED
102+
assert executor.nesting_type is NestingType.NESTED
103103

104104

105105
@patch("aws_durable_execution_sdk_python.operation.map.logger")

0 commit comments

Comments
 (0)