Skip to content

Commit c356ab6

Browse files
committed
[feature]: add nesting_type to concurrency operations
1 parent a3a207f commit c356ab6

11 files changed

Lines changed: 489 additions & 105 deletions

File tree

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
ExecutionCounters,
2121
SuspendResult,
2222
)
23-
from aws_durable_execution_sdk_python.config import ChildConfig
23+
from aws_durable_execution_sdk_python.config import (
24+
ChildConfig,
25+
NestingType,
26+
)
2427
from aws_durable_execution_sdk_python.exceptions import (
2528
OrphanedChildException,
2629
SuspendExecution,
@@ -134,6 +137,7 @@ class ConcurrentExecutor(ABC, Generic[CallableType, ResultType]):
134137

135138
def __init__(
136139
self,
140+
operation_identifier: OperationIdentifier,
137141
executables: list[Executable[CallableType]],
138142
max_concurrency: int | None,
139143
completion_config: CompletionConfig,
@@ -143,6 +147,7 @@ def __init__(
143147
serdes: SerDes | None,
144148
item_serdes: SerDes | None = None,
145149
summary_generator: SummaryGenerator | None = None,
150+
nesting_type: NestingType = NestingType.NESTED,
146151
):
147152
"""Initialize ConcurrentExecutor.
148153
@@ -153,13 +158,15 @@ def __init__(
153158
handle large BatchResult payloads efficiently. Matches TypeScript behavior in
154159
run-in-child-context-handler.ts.
155160
"""
161+
self.operation_identifier = operation_identifier
156162
self.executables = executables
157163
self.max_concurrency = max_concurrency
158164
self.completion_config = completion_config
159165
self.sub_type_top = sub_type_top
160166
self.sub_type_iteration = sub_type_iteration
161167
self.name_prefix = name_prefix
162168
self.summary_generator = summary_generator
169+
self.nesting_type = nesting_type
163170

164171
# Event-driven state tracking for when the executor is done
165172
self._completion_event = threading.Event()
@@ -406,7 +413,14 @@ def _execute_item_in_child_context(
406413
executable.index
407414
)
408415
name = f"{self.name_prefix}{executable.index}"
409-
child_context = executor_context.create_child_context(operation_id)
416+
non_virtual_parent_id = (
417+
self.operation_identifier.operation_id
418+
if self.nesting_type is NestingType.FLAT
419+
else None
420+
)
421+
child_context = executor_context.create_child_context(
422+
operation_id, non_virtual_parent_id
423+
)
410424
operation_identifier = OperationIdentifier(
411425
operation_id,
412426
executor_context._parent_id, # noqa: SLF001
@@ -424,6 +438,7 @@ def run_in_child_handler():
424438
serdes=self.item_serdes or self.serdes,
425439
sub_type=self.sub_type_iteration,
426440
summary_generator=self.summary_generator,
441+
is_virtual=self.nesting_type is NestingType.FLAT,
427442
),
428443
)
429444
child_context.state.track_replay(operation_id=operation_id)

src/aws_durable_execution_sdk_python/config.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,18 @@ class TerminationMode(Enum):
7676
ABANDON = "ABANDON"
7777

7878

79+
class NestingType(Enum):
80+
"""
81+
How child operations should inherit context from their parent.
82+
83+
- NESTED: Each child operation runs in its own isolated context
84+
- FLAT: All child operations share the same parent context
85+
"""
86+
87+
NESTED = "NESTED"
88+
FLAT = "FLAT"
89+
90+
7991
@dataclass(frozen=True)
8092
class CompletionConfig:
8193
"""Configuration for determining when parallel/map operations complete.
@@ -187,6 +199,10 @@ class ParallelConfig:
187199
Used internally by map/parallel operations to handle large BatchResult payloads.
188200
Signature: (result: T) -> str
189201
202+
nesting_type: How child operations should inherit context from their parent.
203+
- NESTED: Each branch runs in its own isolated context (default)
204+
- FLAT: All branches share the same parent context
205+
190206
Example:
191207
# Run at most 3 branches concurrently, succeed if any one succeeds
192208
config = ParallelConfig(
@@ -202,6 +218,7 @@ class ParallelConfig:
202218
serdes: SerDes | None = None
203219
item_serdes: SerDes | None = None
204220
summary_generator: SummaryGenerator | None = None
221+
nesting_type: NestingType = NestingType.NESTED
205222

206223

207224
class StepSemantics(Enum):
@@ -218,12 +235,6 @@ class StepConfig:
218235
serdes: SerDes | None = None
219236

220237

221-
class CheckpointMode(Enum):
222-
NO_CHECKPOINT = ("NO_CHECKPOINT",)
223-
CHECKPOINT_AT_FINISH = ("CHECKPOINT_AT_FINISH",)
224-
CHECKPOINT_AT_START_AND_FINISH = "CHECKPOINT_AT_START_AND_FINISH"
225-
226-
227238
@dataclass(frozen=True)
228239
class ChildConfig(Generic[T]):
229240
"""Configuration options for child context operations.
@@ -259,21 +270,19 @@ class ChildConfig(Generic[T]):
259270
260271
Used internally by map/parallel operations to handle large BatchResult payloads.
261272
Signature: (result: T) -> str
262-
Note:
263-
checkpoint_mode field is commented out as it's not currently implemented.
264-
When implemented, it will control when checkpoints are created:
265-
- CHECKPOINT_AT_START_AND_FINISH: Checkpoint at both start and completion (default)
266-
- CHECKPOINT_AT_FINISH: Only checkpoint when operation completes
267-
- NO_CHECKPOINT: No automatic checkpointing
273+
274+
is_virtual: Whether the child operation is virtual (doesn't represent a real operation).
275+
Virtual contexts are used for concurrency operations and don't appear in
276+
the final execution history. Default is False.
268277
269278
See TypeScript reference: aws-durable-execution-sdk-js/src/types/index.ts
270279
"""
271280

272-
# checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
273281
serdes: SerDes | None = None
274282
item_serdes: SerDes | None = None
275283
sub_type: OperationSubType | None = None
276284
summary_generator: SummaryGenerator | None = None
285+
is_virtual: bool = False
277286

278287

279288
class ItemsPerBatchUnit(Enum):
@@ -361,6 +370,10 @@ class MapConfig:
361370
Used internally by map/parallel operations to handle large BatchResult payloads.
362371
Signature: (result: T) -> str
363372
373+
nesting_type: How child operations should inherit context from their parent.
374+
- NESTED: Each item runs in its own isolated context (default)
375+
- FLAT: All items share the same parent context
376+
364377
Example:
365378
# Process 5 items at a time, batch by count, require all to succeed
366379
config = MapConfig(
@@ -376,6 +389,7 @@ class MapConfig:
376389
serdes: SerDes | None = None
377390
item_serdes: SerDes | None = None
378391
summary_generator: SummaryGenerator | None = None
392+
nesting_type: NestingType = NestingType.NESTED
379393

380394

381395
@dataclass(frozen=True)

src/aws_durable_execution_sdk_python/context.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,14 @@ def __init__(
237237
lambda_context: LambdaContext | None = None,
238238
parent_id: str | None = None,
239239
logger: Logger | None = None,
240+
non_virtual_parent_id: str | None = None,
240241
) -> None:
241242
self.state: ExecutionState = state
242243
self.execution_context: ExecutionContext = execution_context
243244
self.lambda_context = lambda_context
244245
self._parent_id: str | None = parent_id
245246
self._step_counter: OrderedCounter = OrderedCounter()
247+
self._non_virtual_parent_id = non_virtual_parent_id or parent_id
246248

247249
log_info = LogInfo(
248250
execution_state=state,
@@ -269,7 +271,9 @@ def from_lambda_context(
269271
parent_id=None,
270272
)
271273

272-
def create_child_context(self, parent_id: str) -> DurableContext:
274+
def create_child_context(
275+
self, parent_id: str, non_virtual_parent_id=None
276+
) -> DurableContext:
273277
"""Create a child context from the given parent."""
274278
logger.debug("Creating child context for parent %s", parent_id)
275279
return DurableContext(
@@ -283,6 +287,7 @@ def create_child_context(self, parent_id: str) -> DurableContext:
283287
parent_id=parent_id,
284288
)
285289
),
290+
non_virtual_parent_id=non_virtual_parent_id,
286291
)
287292

288293
# endregion factories
@@ -347,7 +352,9 @@ def create_callback(
347352
executor: CallbackOperationExecutor = CallbackOperationExecutor(
348353
state=self.state,
349354
operation_identifier=OperationIdentifier(
350-
operation_id=operation_id, parent_id=self._parent_id, name=name
355+
operation_id=operation_id,
356+
parent_id=self._non_virtual_parent_id,
357+
name=name,
351358
),
352359
config=config,
353360
)
@@ -388,7 +395,7 @@ def invoke(
388395
state=self.state,
389396
operation_identifier=OperationIdentifier(
390397
operation_id=operation_id,
391-
parent_id=self._parent_id,
398+
parent_id=self._non_virtual_parent_id,
392399
name=name,
393400
),
394401
config=config,
@@ -409,7 +416,9 @@ def map(
409416

410417
operation_id = self._create_step_id()
411418
operation_identifier = OperationIdentifier(
412-
operation_id=operation_id, parent_id=self._parent_id, name=map_name
419+
operation_id=operation_id,
420+
parent_id=self._non_virtual_parent_id,
421+
name=map_name,
413422
)
414423
map_context = self.create_child_context(parent_id=operation_id)
415424

@@ -454,7 +463,7 @@ def parallel(
454463
operation_id = self._create_step_id()
455464
parallel_context = self.create_child_context(parent_id=operation_id)
456465
operation_identifier = OperationIdentifier(
457-
operation_id=operation_id, parent_id=self._parent_id, name=name
466+
operation_id=operation_id, parent_id=self._non_virtual_parent_id, name=name
458467
)
459468

460469
def parallel_in_child_context() -> BatchResult[T]:
@@ -515,7 +524,9 @@ def callable_with_child_context():
515524
func=callable_with_child_context,
516525
state=self.state,
517526
operation_identifier=OperationIdentifier(
518-
operation_id=operation_id, parent_id=self._parent_id, name=step_name
527+
operation_id=operation_id,
528+
parent_id=self._non_virtual_parent_id,
529+
name=step_name,
519530
),
520531
config=config,
521532
)
@@ -539,7 +550,7 @@ def step(
539550
state=self.state,
540551
operation_identifier=OperationIdentifier(
541552
operation_id=operation_id,
542-
parent_id=self._parent_id,
553+
parent_id=self._non_virtual_parent_id,
543554
name=step_name,
544555
),
545556
context_logger=self.logger,
@@ -566,7 +577,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
566577
state=self.state,
567578
operation_identifier=OperationIdentifier(
568579
operation_id=operation_id,
569-
parent_id=self._parent_id,
580+
parent_id=self._non_virtual_parent_id,
570581
name=name,
571582
),
572583
)
@@ -621,7 +632,7 @@ def wait_for_condition(
621632
state=self.state,
622633
operation_identifier=OperationIdentifier(
623634
operation_id=operation_id,
624-
parent_id=self._parent_id,
635+
parent_id=self._non_virtual_parent_id,
625636
name=name,
626637
),
627638
context_logger=self.logger,

src/aws_durable_execution_sdk_python/operation/child.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def check_result_status(self) -> CheckResult[T]:
118118
checkpointed_result.raise_callable_error()
119119

120120
# Create START checkpoint if not exists
121-
if not checkpointed_result.is_existent():
121+
if not checkpointed_result.is_existent() and not self.config.is_virtual:
122122
start_operation: OperationUpdate = OperationUpdate.create_context_start(
123123
identifier=self.operation_identifier,
124124
sub_type=self.sub_type,
@@ -157,6 +157,14 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T:
157157
try:
158158
raw_result: T = self.func()
159159

160+
if self.config.is_virtual:
161+
logger.debug(
162+
"Virtual context: Exiting child context without creating another checkpoint. id: %s, name: %s",
163+
self.operation_identifier.operation_id,
164+
self.operation_identifier.name,
165+
)
166+
return raw_result
167+
160168
# If in replay_children mode, return without checkpointing
161169
if checkpointed_result.is_replay_children():
162170
logger.debug(
@@ -270,8 +278,7 @@ def child_handler(
270278
Raises:
271279
May raise operation-specific errors during execution
272280
"""
273-
if not config:
274-
config = ChildConfig()
275-
276-
executor = ChildOperationExecutor(func, state, operation_identifier, config)
281+
executor = ChildOperationExecutor(
282+
func, state, operation_identifier, config or ChildConfig()
283+
)
277284
return executor.process()

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
BatchResult,
1313
Executable,
1414
)
15-
from aws_durable_execution_sdk_python.config import MapConfig
15+
from aws_durable_execution_sdk_python.config import MapConfig, NestingType
1616
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
1717

1818
if TYPE_CHECKING:
@@ -36,6 +36,7 @@
3636
class MapExecutor(Generic[T, R], ConcurrentExecutor[Callable, R]): # noqa: PYI059
3737
def __init__(
3838
self,
39+
operation_identifier: OperationIdentifier,
3940
executables: list[Executable[Callable]],
4041
items: Sequence[T],
4142
max_concurrency: int | None,
@@ -46,8 +47,10 @@ def __init__(
4647
serdes: SerDes | None,
4748
summary_generator: SummaryGenerator | None = None,
4849
item_serdes: SerDes | None = None,
50+
nesting_type: NestingType = NestingType.NESTED,
4951
):
5052
super().__init__(
53+
operation_identifier=operation_identifier,
5154
executables=executables,
5255
max_concurrency=max_concurrency,
5356
completion_config=completion_config,
@@ -57,12 +60,14 @@ def __init__(
5760
serdes=serdes,
5861
summary_generator=summary_generator,
5962
item_serdes=item_serdes,
63+
nesting_type=nesting_type,
6064
)
6165
self.items = items
6266

6367
@classmethod
6468
def from_items(
6569
cls,
70+
operation_identifier: OperationIdentifier,
6671
items: Sequence[T],
6772
func: Callable,
6873
config: MapConfig,
@@ -73,6 +78,7 @@ def from_items(
7378
]
7479

7580
return cls(
81+
operation_identifier=operation_identifier,
7682
executables=executables,
7783
items=items,
7884
max_concurrency=config.max_concurrency,
@@ -83,6 +89,7 @@ def from_items(
8389
serdes=config.serdes,
8490
summary_generator=config.summary_generator,
8591
item_serdes=config.item_serdes,
92+
nesting_type=config.nesting_type,
8693
)
8794

8895
def execute_item(self, child_context, executable: Executable[Callable]) -> R:
@@ -109,6 +116,7 @@ def map_handler(
109116
# See TypeScript reference: aws-durable-execution-sdk-js/src/handlers/map-handler/map-handler.ts (~line 79)
110117

111118
executor: MapExecutor[T, R] = MapExecutor.from_items(
119+
operation_identifier=operation_identifier,
112120
items=items,
113121
func=func,
114122
config=config or MapConfig(summary_generator=MapSummaryGenerator()),

0 commit comments

Comments
 (0)