Skip to content

Commit 9ca90ee

Browse files
committed
[feature]: add nesting_type to concurrency operations
1 parent b042739 commit 9ca90ee

10 files changed

Lines changed: 336 additions & 69 deletions

File tree

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
ExecutionCounters,
2121
SuspendResult,
2222
)
23-
from aws_durable_execution_sdk_python.config import ChildConfig
23+
from aws_durable_execution_sdk_python.config import (
24+
ChildConfig,
25+
CheckpointMode,
26+
NestingType,
27+
)
2428
from aws_durable_execution_sdk_python.exceptions import (
2529
OrphanedChildException,
2630
SuspendExecution,
@@ -143,6 +147,7 @@ def __init__(
143147
serdes: SerDes | None,
144148
item_serdes: SerDes | None = None,
145149
summary_generator: SummaryGenerator | None = None,
150+
nesting_type: NestingType = NestingType.NESTED,
146151
):
147152
"""Initialize ConcurrentExecutor.
148153
@@ -160,6 +165,7 @@ def __init__(
160165
self.sub_type_iteration = sub_type_iteration
161166
self.name_prefix = name_prefix
162167
self.summary_generator = summary_generator
168+
self.nesting_type = nesting_type
163169

164170
# Event-driven state tracking for when the executor is done
165171
self._completion_event = threading.Event()
@@ -412,6 +418,11 @@ def _execute_item_in_child_context(
412418
executor_context._parent_id, # noqa: SLF001
413419
name,
414420
)
421+
checkpoint_mode = (
422+
CheckpointMode.NO_CHECKPOINT
423+
if self.nesting_type == NestingType.FLAT
424+
else CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
425+
)
415426

416427
def run_in_child_handler():
417428
return self.execute_item(child_context, executable)
@@ -424,6 +435,7 @@ def run_in_child_handler():
424435
serdes=self.item_serdes or self.serdes,
425436
sub_type=self.sub_type_iteration,
426437
summary_generator=self.summary_generator,
438+
checkpoint_mode=checkpoint_mode,
427439
),
428440
)
429441
child_context.state.track_replay(operation_id=operation_id)

src/aws_durable_execution_sdk_python/config.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,18 @@ class TerminationMode(Enum):
7676
ABANDON = "ABANDON"
7777

7878

79+
class NestingType(Enum):
80+
"""
81+
Defines how child operations inherit context from their parent.
82+
83+
NESTED: Child operations are executed in their own nested context.
84+
FLAT: Child operations are executed within the same context as their parent.
85+
"""
86+
87+
NESTED = "NESTED"
88+
FLAT = "FLAT"
89+
90+
7991
@dataclass(frozen=True)
8092
class CompletionConfig:
8193
"""Configuration for determining when parallel/map operations complete.
@@ -187,6 +199,10 @@ class ParallelConfig:
187199
Used internally by map/parallel operations to handle large BatchResult payloads.
188200
Signature: (result: T) -> str
189201
202+
nesting_type: How child operations should inherit context from their parent.
203+
- NESTED: Each branch runs in its own isolated context (default)
204+
- FLAT: All branches share the same parent context
205+
190206
Example:
191207
# Run at most 3 branches concurrently, succeed if any one succeeds
192208
config = ParallelConfig(
@@ -202,6 +218,7 @@ class ParallelConfig:
202218
serdes: SerDes | None = None
203219
item_serdes: SerDes | None = None
204220
summary_generator: SummaryGenerator | None = None
221+
nesting_type: NestingType = NestingType.NESTED
205222

206223

207224
class StepSemantics(Enum):
@@ -220,7 +237,6 @@ class StepConfig:
220237

221238
class CheckpointMode(Enum):
222239
NO_CHECKPOINT = ("NO_CHECKPOINT",)
223-
CHECKPOINT_AT_FINISH = ("CHECKPOINT_AT_FINISH",)
224240
CHECKPOINT_AT_START_AND_FINISH = "CHECKPOINT_AT_START_AND_FINISH"
225241

226242

@@ -259,17 +275,16 @@ class ChildConfig(Generic[T]):
259275
260276
Used internally by map/parallel operations to handle large BatchResult payloads.
261277
Signature: (result: T) -> str
262-
Note:
263-
checkpoint_mode field is commented out as it's not currently implemented.
264-
When implemented, it will control when checkpoints are created:
278+
279+
checkpoint_mode: controls when checkpoints are created
265280
- CHECKPOINT_AT_START_AND_FINISH: Checkpoint at both start and completion (default)
266-
- CHECKPOINT_AT_FINISH: Only checkpoint when operation completes
281+
- CHECKPOINT_AT_FINISH: Only checkpoint when operation completes (not implemented)
267282
- NO_CHECKPOINT: No automatic checkpointing
268283
269284
See TypeScript reference: aws-durable-execution-sdk-js/src/types/index.ts
270285
"""
271286

272-
# checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
287+
checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
273288
serdes: SerDes | None = None
274289
item_serdes: SerDes | None = None
275290
sub_type: OperationSubType | None = None
@@ -361,6 +376,10 @@ class MapConfig:
361376
Used internally by map/parallel operations to handle large BatchResult payloads.
362377
Signature: (result: T) -> str
363378
379+
nesting_type: How child operations should inherit context from their parent.
380+
- NESTED: Each item runs in its own isolated context (default)
381+
- FLAT: All items share the same parent context
382+
364383
Example:
365384
# Process 5 items at a time, batch by count, require all to succeed
366385
config = MapConfig(
@@ -376,6 +395,7 @@ class MapConfig:
376395
serdes: SerDes | None = None
377396
item_serdes: SerDes | None = None
378397
summary_generator: SummaryGenerator | None = None
398+
nesting_type: NestingType = NestingType.NESTED
379399

380400

381401
@dataclass(frozen=True)

src/aws_durable_execution_sdk_python/operation/child.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
from typing import TYPE_CHECKING, TypeVar
77

8-
from aws_durable_execution_sdk_python.config import ChildConfig
8+
from aws_durable_execution_sdk_python.config import ChildConfig, CheckpointMode
99
from aws_durable_execution_sdk_python.exceptions import (
1010
InvocationError,
1111
SuspendExecution,
@@ -118,7 +118,11 @@ def check_result_status(self) -> CheckResult[T]:
118118
checkpointed_result.raise_callable_error()
119119

120120
# Create START checkpoint if not exists
121-
if not checkpointed_result.is_existent():
121+
if (
122+
not checkpointed_result.is_existent()
123+
and self.config.checkpoint_mode
124+
== CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
125+
):
122126
start_operation: OperationUpdate = OperationUpdate.create_context_start(
123127
identifier=self.operation_identifier,
124128
sub_type=self.sub_type,
@@ -203,18 +207,21 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T:
203207
else ""
204208
)
205209

206-
# Checkpoint SUCCEED
207-
success_operation: OperationUpdate = OperationUpdate.create_context_succeed(
208-
identifier=self.operation_identifier,
209-
payload=serialized_result,
210-
sub_type=self.sub_type,
211-
context_options=ContextOptions(replay_children=replay_children),
212-
)
213-
# Checkpoint child context SUCCEED with blocking (is_sync=True, default).
214-
# Must ensure the child context result is persisted before returning to the parent.
215-
# This guarantees the result is durable and child operations won't be re-executed on replay
216-
# (unless replay_children=True for large payloads).
217-
self.state.create_checkpoint(operation_update=success_operation)
210+
if self.config.checkpoint_mode != CheckpointMode.NO_CHECKPOINT:
211+
# Checkpoint SUCCEED
212+
success_operation: OperationUpdate = (
213+
OperationUpdate.create_context_succeed(
214+
identifier=self.operation_identifier,
215+
payload=serialized_result,
216+
sub_type=self.sub_type,
217+
context_options=ContextOptions(replay_children=replay_children),
218+
)
219+
)
220+
# Checkpoint child context SUCCEED with blocking (is_sync=True, default).
221+
# Must ensure the child context result is persisted before returning to the parent.
222+
# This guarantees the result is durable and child operations won't be re-executed on replay
223+
# (unless replay_children=True for large payloads).
224+
self.state.create_checkpoint(operation_update=success_operation)
218225

219226
logger.debug(
220227
"✅ Successfully completed child context for id: %s, name: %s",
@@ -270,8 +277,7 @@ def child_handler(
270277
Raises:
271278
May raise operation-specific errors during execution
272279
"""
273-
if not config:
274-
config = ChildConfig()
275-
276-
executor = ChildOperationExecutor(func, state, operation_identifier, config)
280+
executor = ChildOperationExecutor(
281+
func, state, operation_identifier, config or ChildConfig()
282+
)
277283
return executor.process()

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
BatchResult,
1313
Executable,
1414
)
15-
from aws_durable_execution_sdk_python.config import MapConfig
15+
from aws_durable_execution_sdk_python.config import MapConfig, NestingType
1616
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
1717

1818
if TYPE_CHECKING:
@@ -46,6 +46,7 @@ def __init__(
4646
serdes: SerDes | None,
4747
summary_generator: SummaryGenerator | None = None,
4848
item_serdes: SerDes | None = None,
49+
nesting_type: NestingType = NestingType.NESTED,
4950
):
5051
super().__init__(
5152
executables=executables,
@@ -57,6 +58,7 @@ def __init__(
5758
serdes=serdes,
5859
summary_generator=summary_generator,
5960
item_serdes=item_serdes,
61+
nesting_type=nesting_type,
6062
)
6163
self.items = items
6264

@@ -83,6 +85,7 @@ def from_items(
8385
serdes=config.serdes,
8486
summary_generator=config.summary_generator,
8587
item_serdes=config.item_serdes,
88+
nesting_type=config.nesting_type,
8689
)
8790

8891
def execute_item(self, child_context, executable: Executable[Callable]) -> R:

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from aws_durable_execution_sdk_python.concurrency.executor import ConcurrentExecutor
1111
from aws_durable_execution_sdk_python.concurrency.models import Executable
12-
from aws_durable_execution_sdk_python.config import ParallelConfig
12+
from aws_durable_execution_sdk_python.config import ParallelConfig, NestingType
1313
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
1414

1515
if TYPE_CHECKING:
@@ -38,6 +38,7 @@ def __init__(
3838
serdes: SerDes | None,
3939
summary_generator: SummaryGenerator | None = None,
4040
item_serdes: SerDes | None = None,
41+
nesting_type: NestingType = NestingType.NESTED,
4142
):
4243
super().__init__(
4344
executables=executables,
@@ -49,6 +50,7 @@ def __init__(
4950
serdes=serdes,
5051
summary_generator=summary_generator,
5152
item_serdes=item_serdes,
53+
nesting_type=nesting_type,
5254
)
5355

5456
@classmethod
@@ -71,6 +73,7 @@ def from_callables(
7173
serdes=config.serdes,
7274
summary_generator=config.summary_generator,
7375
item_serdes=config.item_serdes,
76+
nesting_type=config.nesting_type,
7477
)
7578

7679
def execute_item(self, child_context, executable: Executable[Callable]) -> R: # noqa: PLR6301

tests/concurrency_test.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@
2525
ExecutableWithState,
2626
ExecutionCounters,
2727
)
28-
from aws_durable_execution_sdk_python.config import CompletionConfig, MapConfig
28+
from aws_durable_execution_sdk_python.config import (
29+
CompletionConfig,
30+
MapConfig,
31+
NestingType,
32+
CheckpointMode,
33+
ChildConfig,
34+
)
2935
from aws_durable_execution_sdk_python.exceptions import (
3036
CallableRuntimeError,
3137
InvalidStateError,
@@ -34,6 +40,7 @@
3440
)
3541
from aws_durable_execution_sdk_python.lambda_service import (
3642
ErrorObject,
43+
OperationSubType,
3744
)
3845
from aws_durable_execution_sdk_python.operation.map import MapExecutor
3946

@@ -853,36 +860,63 @@ def test_batch_result_failed_with_none_error():
853860
assert failed[0].error is not None
854861

855862

856-
def test_concurrent_executor_properties():
857-
"""Test ConcurrentExecutor basic properties."""
863+
def test_concurrent_executor_nesting_type_parameter():
864+
"""Test ConcurrentExecutor nesting_type parameter."""
858865

859866
class TestExecutor(ConcurrentExecutor):
860867
def execute_item(self, child_context, executable):
861868
return f"result_{executable.index}"
862869

863-
executables = [Executable(0, lambda: "test"), Executable(1, lambda: "test2")]
864-
completion_config = CompletionConfig(
865-
min_successful=1,
866-
tolerated_failure_count=None,
867-
tolerated_failure_percentage=None,
870+
executables = [Executable(0, lambda: "test")]
871+
completion_config = CompletionConfig(min_successful=1)
872+
873+
# Test with NESTED (default)
874+
executor_nested = TestExecutor(
875+
executables=executables,
876+
max_concurrency=1,
877+
completion_config=completion_config,
878+
sub_type_top="TOP",
879+
sub_type_iteration="ITER",
880+
name_prefix="test_",
881+
serdes=None,
882+
nesting_type=NestingType.NESTED,
868883
)
869-
executor = TestExecutor(
884+
assert executor_nested.nesting_type == NestingType.NESTED
885+
886+
# Test with FLAT
887+
executor_flat = TestExecutor(
870888
executables=executables,
871-
max_concurrency=2,
889+
max_concurrency=1,
872890
completion_config=completion_config,
873891
sub_type_top="TOP",
874892
sub_type_iteration="ITER",
875893
name_prefix="test_",
876894
serdes=None,
895+
nesting_type=NestingType.FLAT,
877896
)
897+
assert executor_flat.nesting_type == NestingType.FLAT
878898

879-
# Test basic properties
880-
assert executor.executables == executables
881-
assert executor.max_concurrency == 2
882-
assert executor.completion_config == completion_config
883-
assert executor.sub_type_top == "TOP"
884-
assert executor.sub_type_iteration == "ITER"
885-
assert executor.name_prefix == "test_"
899+
900+
def test_concurrent_executor_default_nesting_type():
901+
"""Test ConcurrentExecutor uses NESTED as default nesting_type."""
902+
903+
class TestExecutor(ConcurrentExecutor):
904+
def execute_item(self, child_context, executable):
905+
return f"result_{executable.index}"
906+
907+
executables = [Executable(0, lambda: "test")]
908+
completion_config = CompletionConfig(min_successful=1)
909+
910+
executor = TestExecutor(
911+
executables=executables,
912+
max_concurrency=1,
913+
completion_config=completion_config,
914+
sub_type_top="TOP",
915+
sub_type_iteration="ITER",
916+
name_prefix="test_",
917+
serdes=None,
918+
)
919+
assert executor.nesting_type == NestingType.NESTED
886920

887921

888922
def test_concurrent_executor_full_execution_path():
@@ -2474,8 +2508,12 @@ def execute_item(self, child_context, executable):
24742508
# Track operation_id -> result associations
24752509
captured_associations = []
24762510

2477-
def patched_child_handler(func, execution_state, operation_identifier, config):
2511+
def patched_child_handler(
2512+
func, execution_state, operation_identifier, config: ChildConfig
2513+
):
24782514
"""Patched child handler that captures operation_id -> result mapping."""
2515+
assert config.checkpoint_mode == CheckpointMode.NO_CHECKPOINT
2516+
assert config.sub_type == "TEST_ITER"
24792517
result = func() # Execute the function
24802518
captured_associations.append((operation_identifier.operation_id, result))
24812519
return result
@@ -2504,6 +2542,7 @@ def patched_child_handler(func, execution_state, operation_identifier, config):
25042542
sub_type_iteration="TEST_ITER",
25052543
name_prefix="test_",
25062544
serdes=None,
2545+
nesting_type=NestingType.FLAT,
25072546
)
25082547

25092548
# Create executor context mock

0 commit comments

Comments
 (0)