|
25 | 25 | ExecutableWithState, |
26 | 26 | ExecutionCounters, |
27 | 27 | ) |
28 | | -from aws_durable_execution_sdk_python.config import CompletionConfig, MapConfig |
| 28 | +from aws_durable_execution_sdk_python.config import CompletionConfig, MapConfig, NestingType, CheckpointMode, \ |
| 29 | + ChildConfig |
29 | 30 | from aws_durable_execution_sdk_python.exceptions import ( |
30 | 31 | CallableRuntimeError, |
31 | 32 | InvalidStateError, |
32 | 33 | SuspendExecution, |
33 | 34 | TimedSuspendExecution, |
34 | 35 | ) |
35 | 36 | from aws_durable_execution_sdk_python.lambda_service import ( |
36 | | - ErrorObject, |
| 37 | + ErrorObject, OperationSubType, |
37 | 38 | ) |
38 | 39 | from aws_durable_execution_sdk_python.operation.map import MapExecutor |
39 | 40 |
|
@@ -853,36 +854,63 @@ def test_batch_result_failed_with_none_error(): |
853 | 854 | assert failed[0].error is not None |
854 | 855 |
|
855 | 856 |
|
856 | | -def test_concurrent_executor_properties(): |
857 | | - """Test ConcurrentExecutor basic properties.""" |
| 857 | +def test_concurrent_executor_nesting_type_parameter(): |
| 858 | + """Test ConcurrentExecutor nesting_type parameter.""" |
858 | 859 |
|
859 | 860 | class TestExecutor(ConcurrentExecutor): |
860 | 861 | def execute_item(self, child_context, executable): |
861 | 862 | return f"result_{executable.index}" |
862 | 863 |
|
863 | | - executables = [Executable(0, lambda: "test"), Executable(1, lambda: "test2")] |
864 | | - completion_config = CompletionConfig( |
865 | | - min_successful=1, |
866 | | - tolerated_failure_count=None, |
867 | | - tolerated_failure_percentage=None, |
| 864 | + executables = [Executable(0, lambda: "test")] |
| 865 | + completion_config = CompletionConfig(min_successful=1) |
| 866 | + |
| 867 | + # Test with NESTED (default) |
| 868 | + executor_nested = TestExecutor( |
| 869 | + executables=executables, |
| 870 | + max_concurrency=1, |
| 871 | + completion_config=completion_config, |
| 872 | + sub_type_top="TOP", |
| 873 | + sub_type_iteration="ITER", |
| 874 | + name_prefix="test_", |
| 875 | + serdes=None, |
| 876 | + nesting_type=NestingType.NESTED, |
868 | 877 | ) |
869 | | - executor = TestExecutor( |
| 878 | + assert executor_nested.nesting_type == NestingType.NESTED |
| 879 | + |
| 880 | + # Test with FLAT |
| 881 | + executor_flat = TestExecutor( |
870 | 882 | executables=executables, |
871 | | - max_concurrency=2, |
| 883 | + max_concurrency=1, |
872 | 884 | completion_config=completion_config, |
873 | 885 | sub_type_top="TOP", |
874 | 886 | sub_type_iteration="ITER", |
875 | 887 | name_prefix="test_", |
876 | 888 | serdes=None, |
| 889 | + nesting_type=NestingType.FLAT, |
877 | 890 | ) |
| 891 | + assert executor_flat.nesting_type == NestingType.FLAT |
878 | 892 |
|
879 | | - # Test basic properties |
880 | | - assert executor.executables == executables |
881 | | - assert executor.max_concurrency == 2 |
882 | | - assert executor.completion_config == completion_config |
883 | | - assert executor.sub_type_top == "TOP" |
884 | | - assert executor.sub_type_iteration == "ITER" |
885 | | - assert executor.name_prefix == "test_" |
| 893 | + |
| 894 | +def test_concurrent_executor_default_nesting_type(): |
| 895 | + """Test ConcurrentExecutor uses NESTED as default nesting_type.""" |
| 896 | + |
| 897 | + class TestExecutor(ConcurrentExecutor): |
| 898 | + def execute_item(self, child_context, executable): |
| 899 | + return f"result_{executable.index}" |
| 900 | + |
| 901 | + executables = [Executable(0, lambda: "test")] |
| 902 | + completion_config = CompletionConfig(min_successful=1) |
| 903 | + |
| 904 | + executor = TestExecutor( |
| 905 | + executables=executables, |
| 906 | + max_concurrency=1, |
| 907 | + completion_config=completion_config, |
| 908 | + sub_type_top="TOP", |
| 909 | + sub_type_iteration="ITER", |
| 910 | + name_prefix="test_", |
| 911 | + serdes=None, |
| 912 | + ) |
| 913 | + assert executor.nesting_type == NestingType.NESTED |
886 | 914 |
|
887 | 915 |
|
888 | 916 | def test_concurrent_executor_full_execution_path(): |
@@ -2474,8 +2502,10 @@ def execute_item(self, child_context, executable): |
2474 | 2502 | # Track operation_id -> result associations |
2475 | 2503 | captured_associations = [] |
2476 | 2504 |
|
2477 | | - def patched_child_handler(func, execution_state, operation_identifier, config): |
| 2505 | + def patched_child_handler(func, execution_state, operation_identifier, config: ChildConfig): |
2478 | 2506 | """Patched child handler that captures operation_id -> result mapping.""" |
| 2507 | + assert config.checkpoint_mode == CheckpointMode.NO_CHECKPOINT |
| 2508 | + assert config.sub_type == "TEST_ITER" |
2479 | 2509 | result = func() # Execute the function |
2480 | 2510 | captured_associations.append((operation_identifier.operation_id, result)) |
2481 | 2511 | return result |
@@ -2504,6 +2534,7 @@ def patched_child_handler(func, execution_state, operation_identifier, config): |
2504 | 2534 | sub_type_iteration="TEST_ITER", |
2505 | 2535 | name_prefix="test_", |
2506 | 2536 | serdes=None, |
| 2537 | + nesting_type=NestingType.FLAT, |
2507 | 2538 | ) |
2508 | 2539 |
|
2509 | 2540 | # Create executor context mock |
|
0 commit comments