|
25 | 25 | ExecutableWithState, |
26 | 26 | ExecutionCounters, |
27 | 27 | ) |
28 | | -from aws_durable_execution_sdk_python.config import CompletionConfig, MapConfig |
| 28 | +from aws_durable_execution_sdk_python.config import ( |
| 29 | + CompletionConfig, |
| 30 | + MapConfig, |
| 31 | + NestingType, |
| 32 | + CheckpointMode, |
| 33 | + ChildConfig, |
| 34 | +) |
29 | 35 | from aws_durable_execution_sdk_python.exceptions import ( |
30 | 36 | CallableRuntimeError, |
31 | 37 | InvalidStateError, |
|
34 | 40 | ) |
35 | 41 | from aws_durable_execution_sdk_python.lambda_service import ( |
36 | 42 | ErrorObject, |
| 43 | + OperationSubType, |
37 | 44 | ) |
38 | 45 | from aws_durable_execution_sdk_python.operation.map import MapExecutor |
39 | 46 |
|
@@ -853,36 +860,63 @@ def test_batch_result_failed_with_none_error(): |
853 | 860 | assert failed[0].error is not None |
854 | 861 |
|
855 | 862 |
|
856 | | -def test_concurrent_executor_properties(): |
857 | | - """Test ConcurrentExecutor basic properties.""" |
| 863 | +def test_concurrent_executor_nesting_type_parameter(): |
| 864 | + """Test ConcurrentExecutor nesting_type parameter.""" |
858 | 865 |
|
859 | 866 | class TestExecutor(ConcurrentExecutor): |
860 | 867 | def execute_item(self, child_context, executable): |
861 | 868 | return f"result_{executable.index}" |
862 | 869 |
|
863 | | - executables = [Executable(0, lambda: "test"), Executable(1, lambda: "test2")] |
864 | | - completion_config = CompletionConfig( |
865 | | - min_successful=1, |
866 | | - tolerated_failure_count=None, |
867 | | - tolerated_failure_percentage=None, |
| 870 | + executables = [Executable(0, lambda: "test")] |
| 871 | + completion_config = CompletionConfig(min_successful=1) |
| 872 | + |
| 873 | + # Test with NESTED (default) |
| 874 | + executor_nested = TestExecutor( |
| 875 | + executables=executables, |
| 876 | + max_concurrency=1, |
| 877 | + completion_config=completion_config, |
| 878 | + sub_type_top="TOP", |
| 879 | + sub_type_iteration="ITER", |
| 880 | + name_prefix="test_", |
| 881 | + serdes=None, |
| 882 | + nesting_type=NestingType.NESTED, |
868 | 883 | ) |
869 | | - executor = TestExecutor( |
| 884 | + assert executor_nested.nesting_type == NestingType.NESTED |
| 885 | + |
| 886 | + # Test with FLAT |
| 887 | + executor_flat = TestExecutor( |
870 | 888 | executables=executables, |
871 | | - max_concurrency=2, |
| 889 | + max_concurrency=1, |
872 | 890 | completion_config=completion_config, |
873 | 891 | sub_type_top="TOP", |
874 | 892 | sub_type_iteration="ITER", |
875 | 893 | name_prefix="test_", |
876 | 894 | serdes=None, |
| 895 | + nesting_type=NestingType.FLAT, |
877 | 896 | ) |
| 897 | + assert executor_flat.nesting_type == NestingType.FLAT |
878 | 898 |
|
879 | | - # Test basic properties |
880 | | - assert executor.executables == executables |
881 | | - assert executor.max_concurrency == 2 |
882 | | - assert executor.completion_config == completion_config |
883 | | - assert executor.sub_type_top == "TOP" |
884 | | - assert executor.sub_type_iteration == "ITER" |
885 | | - assert executor.name_prefix == "test_" |
| 899 | + |
| 900 | +def test_concurrent_executor_default_nesting_type(): |
| 901 | + """Test ConcurrentExecutor uses NESTED as default nesting_type.""" |
| 902 | + |
| 903 | + class TestExecutor(ConcurrentExecutor): |
| 904 | + def execute_item(self, child_context, executable): |
| 905 | + return f"result_{executable.index}" |
| 906 | + |
| 907 | + executables = [Executable(0, lambda: "test")] |
| 908 | + completion_config = CompletionConfig(min_successful=1) |
| 909 | + |
| 910 | + executor = TestExecutor( |
| 911 | + executables=executables, |
| 912 | + max_concurrency=1, |
| 913 | + completion_config=completion_config, |
| 914 | + sub_type_top="TOP", |
| 915 | + sub_type_iteration="ITER", |
| 916 | + name_prefix="test_", |
| 917 | + serdes=None, |
| 918 | + ) |
| 919 | + assert executor.nesting_type == NestingType.NESTED |
886 | 920 |
|
887 | 921 |
|
888 | 922 | def test_concurrent_executor_full_execution_path(): |
@@ -2474,8 +2508,12 @@ def execute_item(self, child_context, executable): |
2474 | 2508 | # Track operation_id -> result associations |
2475 | 2509 | captured_associations = [] |
2476 | 2510 |
|
2477 | | - def patched_child_handler(func, execution_state, operation_identifier, config): |
| 2511 | + def patched_child_handler( |
| 2512 | + func, execution_state, operation_identifier, config: ChildConfig |
| 2513 | + ): |
2478 | 2514 | """Patched child handler that captures operation_id -> result mapping.""" |
| 2515 | + assert config.checkpoint_mode == CheckpointMode.NO_CHECKPOINT |
| 2516 | + assert config.sub_type == "TEST_ITER" |
2479 | 2517 | result = func() # Execute the function |
2480 | 2518 | captured_associations.append((operation_identifier.operation_id, result)) |
2481 | 2519 | return result |
@@ -2504,6 +2542,7 @@ def patched_child_handler(func, execution_state, operation_identifier, config): |
2504 | 2542 | sub_type_iteration="TEST_ITER", |
2505 | 2543 | name_prefix="test_", |
2506 | 2544 | serdes=None, |
| 2545 | + nesting_type=NestingType.FLAT, |
2507 | 2546 | ) |
2508 | 2547 |
|
2509 | 2548 | # Create executor context mock |
|
0 commit comments