Skip to content

Commit 9ddc534

Browse files
committed
fix parent id for operations in virtual context
1 parent 9dd8b9e commit 9ddc534

9 files changed

Lines changed: 34 additions & 34 deletions

File tree

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from aws_durable_execution_sdk_python.config import (
2424
ChildConfig,
25-
CheckpointMode,
2625
NestingType,
2726
)
2827
from aws_durable_execution_sdk_python.exceptions import (
@@ -138,6 +137,7 @@ class ConcurrentExecutor(ABC, Generic[CallableType, ResultType]):
138137

139138
def __init__(
140139
self,
140+
operation_identifier: OperationIdentifier,
141141
executables: list[Executable[CallableType]],
142142
max_concurrency: int | None,
143143
completion_config: CompletionConfig,
@@ -158,6 +158,7 @@ def __init__(
158158
handle large BatchResult payloads efficiently. Matches TypeScript behavior in
159159
run-in-child-context-handler.ts.
160160
"""
161+
self.operation_identifier = operation_identifier
161162
self.executables = executables
162163
self.max_concurrency = max_concurrency
163164
self.completion_config = completion_config
@@ -412,17 +413,13 @@ def _execute_item_in_child_context(
412413
executable.index
413414
)
414415
name = f"{self.name_prefix}{executable.index}"
415-
child_context = executor_context.create_child_context(operation_id)
416+
non_virtual_parent_id = self.operation_identifier.operation_id if self.nesting_type is NestingType.FLAT else None
417+
child_context = executor_context.create_child_context(operation_id, non_virtual_parent_id)
416418
operation_identifier = OperationIdentifier(
417419
operation_id,
418420
executor_context._parent_id, # noqa: SLF001
419421
name,
420422
)
421-
checkpoint_mode = (
422-
CheckpointMode.NO_CHECKPOINT
423-
if self.nesting_type == NestingType.FLAT
424-
else CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
425-
)
426423

427424
def run_in_child_handler():
428425
return self.execute_item(child_context, executable)
@@ -435,7 +432,7 @@ def run_in_child_handler():
435432
serdes=self.item_serdes or self.serdes,
436433
sub_type=self.sub_type_iteration,
437434
summary_generator=self.summary_generator,
438-
checkpoint_mode=checkpoint_mode,
435+
is_virtual=self.nesting_type is NestingType.FLAT,
439436
),
440437
)
441438
child_context.state.track_replay(operation_id=operation_id)

src/aws_durable_execution_sdk_python/config.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,6 @@ class StepConfig:
235235
serdes: SerDes | None = None
236236

237237

238-
class CheckpointMode(Enum):
239-
NO_CHECKPOINT = ("NO_CHECKPOINT",)
240-
CHECKPOINT_AT_START_AND_FINISH = "CHECKPOINT_AT_START_AND_FINISH"
241-
242-
243238
@dataclass(frozen=True)
244239
class ChildConfig(Generic[T]):
245240
"""Configuration options for child context operations.
@@ -276,19 +271,18 @@ class ChildConfig(Generic[T]):
276271
Used internally by map/parallel operations to handle large BatchResult payloads.
277272
Signature: (result: T) -> str
278273
279-
checkpoint_mode: controls when checkpoints are created
280-
- CHECKPOINT_AT_START_AND_FINISH: Checkpoint at both start and completion (default)
281-
- CHECKPOINT_AT_FINISH: Only checkpoint when operation completes (not implemented)
282-
- NO_CHECKPOINT: No automatic checkpointing
274+
is_virtual: Whether the child operation is virtual (doesn't represent a real operation).
275+
Virtual contexts are used for concurrency operations and don't appear in
276+
the final execution history. Default is False.
283277
284278
See TypeScript reference: aws-durable-execution-sdk-js/src/types/index.ts
285279
"""
286280

287-
checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
288281
serdes: SerDes | None = None
289282
item_serdes: SerDes | None = None
290283
sub_type: OperationSubType | None = None
291284
summary_generator: SummaryGenerator | None = None
285+
is_virtual: bool = False
292286

293287

294288
class ItemsPerBatchUnit(Enum):

src/aws_durable_execution_sdk_python/context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,14 @@ def __init__(
237237
lambda_context: LambdaContext | None = None,
238238
parent_id: str | None = None,
239239
logger: Logger | None = None,
240+
non_virtual_parent_id: str | None = None,
240241
) -> None:
241242
self.state: ExecutionState = state
242243
self.execution_context: ExecutionContext = execution_context
243244
self.lambda_context = lambda_context
244245
self._parent_id: str | None = parent_id
245246
self._step_counter: OrderedCounter = OrderedCounter()
247+
self._non_virtual_parent_id = non_virtual_parent_id or parent_id
246248

247249
log_info = LogInfo(
248250
execution_state=state,
@@ -269,7 +271,7 @@ def from_lambda_context(
269271
parent_id=None,
270272
)
271273

272-
def create_child_context(self, parent_id: str) -> DurableContext:
274+
def create_child_context(self, parent_id: str, non_virtual_parent_id=None) -> DurableContext:
273275
"""Create a child context from the given parent."""
274276
logger.debug("Creating child context for parent %s", parent_id)
275277
return DurableContext(
@@ -283,6 +285,7 @@ def create_child_context(self, parent_id: str) -> DurableContext:
283285
parent_id=parent_id,
284286
)
285287
),
288+
non_virtual_parent_id=non_virtual_parent_id,
286289
)
287290

288291
# endregion factories

src/aws_durable_execution_sdk_python/operation/child.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,7 @@ def check_result_status(self) -> CheckResult[T]:
118118
checkpointed_result.raise_callable_error()
119119

120120
# Create START checkpoint if not exists
121-
if (
122-
not checkpointed_result.is_existent()
123-
and self.config.checkpoint_mode
124-
== CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
125-
):
121+
if not self.config.is_virtual:
126122
start_operation: OperationUpdate = OperationUpdate.create_context_start(
127123
identifier=self.operation_identifier,
128124
sub_type=self.sub_type,
@@ -207,7 +203,7 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T:
207203
else ""
208204
)
209205

210-
if self.config.checkpoint_mode != CheckpointMode.NO_CHECKPOINT:
206+
if not self.config.is_virtual:
211207
# Checkpoint SUCCEED
212208
success_operation: OperationUpdate = (
213209
OperationUpdate.create_context_succeed(

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
class MapExecutor(Generic[T, R], ConcurrentExecutor[Callable, R]): # noqa: PYI059
3737
def __init__(
3838
self,
39+
operation_identifier: OperationIdentifier,
3940
executables: list[Executable[Callable]],
4041
items: Sequence[T],
4142
max_concurrency: int | None,
@@ -49,6 +50,7 @@ def __init__(
4950
nesting_type: NestingType = NestingType.NESTED,
5051
):
5152
super().__init__(
53+
operation_identifier=operation_identifier,
5254
executables=executables,
5355
max_concurrency=max_concurrency,
5456
completion_config=completion_config,
@@ -65,6 +67,7 @@ def __init__(
6567
@classmethod
6668
def from_items(
6769
cls,
70+
operation_identifier: OperationIdentifier,
6871
items: Sequence[T],
6972
func: Callable,
7073
config: MapConfig,
@@ -75,6 +78,7 @@ def from_items(
7578
]
7679

7780
return cls(
81+
operation_identifier=operation_identifier,
7882
executables=executables,
7983
items=items,
8084
max_concurrency=config.max_concurrency,
@@ -112,6 +116,7 @@ def map_handler(
112116
# See TypeScript reference: aws-durable-execution-sdk-js/src/handlers/map-handler/map-handler.ts (~line 79)
113117

114118
executor: MapExecutor[T, R] = MapExecutor.from_items(
119+
operation_identifier=operation_identifier,
115120
items=items,
116121
func=func,
117122
config=config or MapConfig(summary_generator=MapSummaryGenerator()),

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
class ParallelExecutor(ConcurrentExecutor[Callable, R]):
3030
def __init__(
3131
self,
32+
operation_identifier: OperationIdentifier,
3233
executables: list[Executable[Callable]],
3334
max_concurrency: int | None,
3435
completion_config,
@@ -41,6 +42,7 @@ def __init__(
4142
nesting_type: NestingType = NestingType.NESTED,
4243
):
4344
super().__init__(
45+
operation_identifier=operation_identifier,
4446
executables=executables,
4547
max_concurrency=max_concurrency,
4648
completion_config=completion_config,
@@ -56,6 +58,7 @@ def __init__(
5658
@classmethod
5759
def from_callables(
5860
cls,
61+
operation_identifier: OperationIdentifier,
5962
callables: Sequence[Callable],
6063
config: ParallelConfig,
6164
) -> ParallelExecutor:
@@ -64,6 +67,7 @@ def from_callables(
6467
Executable(index=i, func=func) for i, func in enumerate(callables)
6568
]
6669
return cls(
70+
operation_identifier=operation_identifier,
6771
executables=executables,
6872
max_concurrency=config.max_concurrency,
6973
completion_config=config.completion_config,
@@ -98,6 +102,7 @@ def parallel_handler(
98102
# See TypeScript reference: aws-durable-execution-sdk-js/src/handlers/parallel-handler/parallel-handler.ts (~line 112)
99103

100104
executor = ParallelExecutor.from_callables(
105+
operation_identifier,
101106
callables,
102107
config or ParallelConfig(summary_generator=ParallelSummaryGenerator()),
103108
)

tests/concurrency_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def execute_item(self, child_context, executable):
881881
serdes=None,
882882
nesting_type=NestingType.NESTED,
883883
)
884-
assert executor_nested.nesting_type == NestingType.NESTED
884+
assert executor_nested.nesting_type is NestingType.NESTED
885885

886886
# Test with FLAT
887887
executor_flat = TestExecutor(
@@ -894,7 +894,7 @@ def execute_item(self, child_context, executable):
894894
serdes=None,
895895
nesting_type=NestingType.FLAT,
896896
)
897-
assert executor_flat.nesting_type == NestingType.FLAT
897+
assert executor_flat.nesting_type is NestingType.FLAT
898898

899899

900900
def test_concurrent_executor_default_nesting_type():
@@ -916,7 +916,7 @@ def execute_item(self, child_context, executable):
916916
name_prefix="test_",
917917
serdes=None,
918918
)
919-
assert executor.nesting_type == NestingType.NESTED
919+
assert executor.nesting_type is NestingType.NESTED
920920

921921

922922
def test_concurrent_executor_full_execution_path():

tests/operation/map_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_map_executor_init():
6767

6868
assert executor.items == items
6969
assert executor.executables == executables
70-
assert executor.nesting_type == NestingType.FLAT
70+
assert executor.nesting_type is NestingType.FLAT
7171

7272

7373
def test_map_executor_from_items():
@@ -85,7 +85,7 @@ def callable_func(ctx, item, idx, items):
8585
assert executor.items == items
8686
assert all(exe.func == callable_func for exe in executor.executables)
8787
assert [exe.index for exe in executor.executables] == [0, 1, 2]
88-
assert executor.nesting_type == NestingType.FLAT
88+
assert executor.nesting_type is NestingType.FLAT
8989

9090

9191
def test_map_executor_from_items_default_config():
@@ -99,7 +99,7 @@ def callable_func(ctx, item, idx, items):
9999

100100
assert len(executor.executables) == 1
101101
assert executor.items == items
102-
assert executor.nesting_type == NestingType.NESTED
102+
assert executor.nesting_type is NestingType.NESTED
103103

104104

105105
@patch("aws_durable_execution_sdk_python.operation.map.logger")

tests/operation/parallel_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_parallel_executor_init():
7474
assert executor.sub_type_top == OperationSubType.PARALLEL
7575
assert executor.sub_type_iteration == OperationSubType.PARALLEL_BRANCH
7676
assert executor.name_prefix == "test-"
77-
assert executor.nesting_type == NestingType.FLAT
77+
assert executor.nesting_type is NestingType.FLAT
7878

7979

8080
def test_parallel_executor_from_callables():
@@ -100,7 +100,7 @@ def func2(ctx):
100100
assert executor.sub_type_top == OperationSubType.PARALLEL
101101
assert executor.sub_type_iteration == OperationSubType.PARALLEL_BRANCH
102102
assert executor.name_prefix == "parallel-branch-"
103-
assert executor.nesting_type == NestingType.FLAT
103+
assert executor.nesting_type is NestingType.FLAT
104104

105105

106106
def test_parallel_executor_from_callables_default_config():
@@ -117,7 +117,7 @@ def func1(ctx):
117117
assert len(executor.executables) == 1
118118
assert executor.max_concurrency is None
119119
assert executor.completion_config == CompletionConfig.all_successful()
120-
assert executor.nesting_type == NestingType.NESTED
120+
assert executor.nesting_type is NestingType.NESTED
121121

122122

123123
def test_parallel_executor_execute_item():

0 commit comments

Comments
 (0)