Skip to content

Commit abda9c3

Browse files
committed
fix parent id for operations in virtual context
1 parent 9dd8b9e commit abda9c3

11 files changed

Lines changed: 233 additions & 116 deletions

File tree

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from aws_durable_execution_sdk_python.config import (
2424
ChildConfig,
25-
CheckpointMode,
2625
NestingType,
2726
)
2827
from aws_durable_execution_sdk_python.exceptions import (
@@ -138,6 +137,7 @@ class ConcurrentExecutor(ABC, Generic[CallableType, ResultType]):
138137

139138
def __init__(
140139
self,
140+
operation_identifier: OperationIdentifier,
141141
executables: list[Executable[CallableType]],
142142
max_concurrency: int | None,
143143
completion_config: CompletionConfig,
@@ -158,6 +158,7 @@ def __init__(
158158
handle large BatchResult payloads efficiently. Matches TypeScript behavior in
159159
run-in-child-context-handler.ts.
160160
"""
161+
self.operation_identifier = operation_identifier
161162
self.executables = executables
162163
self.max_concurrency = max_concurrency
163164
self.completion_config = completion_config
@@ -412,17 +413,19 @@ def _execute_item_in_child_context(
412413
executable.index
413414
)
414415
name = f"{self.name_prefix}{executable.index}"
415-
child_context = executor_context.create_child_context(operation_id)
416+
non_virtual_parent_id = (
417+
self.operation_identifier.operation_id
418+
if self.nesting_type is NestingType.FLAT
419+
else None
420+
)
421+
child_context = executor_context.create_child_context(
422+
operation_id, non_virtual_parent_id
423+
)
416424
operation_identifier = OperationIdentifier(
417425
operation_id,
418426
executor_context._parent_id, # noqa: SLF001
419427
name,
420428
)
421-
checkpoint_mode = (
422-
CheckpointMode.NO_CHECKPOINT
423-
if self.nesting_type == NestingType.FLAT
424-
else CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
425-
)
426429

427430
def run_in_child_handler():
428431
return self.execute_item(child_context, executable)
@@ -435,7 +438,7 @@ def run_in_child_handler():
435438
serdes=self.item_serdes or self.serdes,
436439
sub_type=self.sub_type_iteration,
437440
summary_generator=self.summary_generator,
438-
checkpoint_mode=checkpoint_mode,
441+
is_virtual=self.nesting_type is NestingType.FLAT,
439442
),
440443
)
441444
child_context.state.track_replay(operation_id=operation_id)

src/aws_durable_execution_sdk_python/config.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ class TerminationMode(Enum):
7878

7979
class NestingType(Enum):
8080
"""
81-
Defines how child operations inherit context from their parent.
81+
How child operations should inherit context from their parent.
8282
83-
NESTED: Child operations are executed in their own nested context.
84-
FLAT: Child operations are executed within the same context as their parent.
83+
- NESTED: Each child operation runs in its own isolated context
84+
- FLAT: All child operations share the same parent context
8585
"""
8686

8787
NESTED = "NESTED"
@@ -235,11 +235,6 @@ class StepConfig:
235235
serdes: SerDes | None = None
236236

237237

238-
class CheckpointMode(Enum):
239-
NO_CHECKPOINT = ("NO_CHECKPOINT",)
240-
CHECKPOINT_AT_START_AND_FINISH = "CHECKPOINT_AT_START_AND_FINISH"
241-
242-
243238
@dataclass(frozen=True)
244239
class ChildConfig(Generic[T]):
245240
"""Configuration options for child context operations.
@@ -276,19 +271,18 @@ class ChildConfig(Generic[T]):
276271
Used internally by map/parallel operations to handle large BatchResult payloads.
277272
Signature: (result: T) -> str
278273
279-
checkpoint_mode: controls when checkpoints are created
280-
- CHECKPOINT_AT_START_AND_FINISH: Checkpoint at both start and completion (default)
281-
- CHECKPOINT_AT_FINISH: Only checkpoint when operation completes (not implemented)
282-
- NO_CHECKPOINT: No automatic checkpointing
274+
is_virtual: Whether the child operation is virtual (doesn't represent a real operation).
275+
Virtual contexts are used for concurrency operations and don't appear in
276+
the final execution history. Default is False.
283277
284278
See TypeScript reference: aws-durable-execution-sdk-js/src/types/index.ts
285279
"""
286280

287-
checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
288281
serdes: SerDes | None = None
289282
item_serdes: SerDes | None = None
290283
sub_type: OperationSubType | None = None
291284
summary_generator: SummaryGenerator | None = None
285+
is_virtual: bool = False
292286

293287

294288
class ItemsPerBatchUnit(Enum):

src/aws_durable_execution_sdk_python/context.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,14 @@ def __init__(
237237
lambda_context: LambdaContext | None = None,
238238
parent_id: str | None = None,
239239
logger: Logger | None = None,
240+
non_virtual_parent_id: str | None = None,
240241
) -> None:
241242
self.state: ExecutionState = state
242243
self.execution_context: ExecutionContext = execution_context
243244
self.lambda_context = lambda_context
244245
self._parent_id: str | None = parent_id
245246
self._step_counter: OrderedCounter = OrderedCounter()
247+
self._non_virtual_parent_id = non_virtual_parent_id or parent_id
246248

247249
log_info = LogInfo(
248250
execution_state=state,
@@ -269,7 +271,9 @@ def from_lambda_context(
269271
parent_id=None,
270272
)
271273

272-
def create_child_context(self, parent_id: str) -> DurableContext:
274+
def create_child_context(
275+
self, parent_id: str, non_virtual_parent_id=None
276+
) -> DurableContext:
273277
"""Create a child context from the given parent."""
274278
logger.debug("Creating child context for parent %s", parent_id)
275279
return DurableContext(
@@ -283,6 +287,7 @@ def create_child_context(self, parent_id: str) -> DurableContext:
283287
parent_id=parent_id,
284288
)
285289
),
290+
non_virtual_parent_id=non_virtual_parent_id,
286291
)
287292

288293
# endregion factories
@@ -347,7 +352,9 @@ def create_callback(
347352
executor: CallbackOperationExecutor = CallbackOperationExecutor(
348353
state=self.state,
349354
operation_identifier=OperationIdentifier(
350-
operation_id=operation_id, parent_id=self._parent_id, name=name
355+
operation_id=operation_id,
356+
parent_id=self._non_virtual_parent_id,
357+
name=name,
351358
),
352359
config=config,
353360
)
@@ -388,7 +395,7 @@ def invoke(
388395
state=self.state,
389396
operation_identifier=OperationIdentifier(
390397
operation_id=operation_id,
391-
parent_id=self._parent_id,
398+
parent_id=self._non_virtual_parent_id,
392399
name=name,
393400
),
394401
config=config,
@@ -409,7 +416,9 @@ def map(
409416

410417
operation_id = self._create_step_id()
411418
operation_identifier = OperationIdentifier(
412-
operation_id=operation_id, parent_id=self._parent_id, name=map_name
419+
operation_id=operation_id,
420+
parent_id=self._non_virtual_parent_id,
421+
name=map_name,
413422
)
414423
map_context = self.create_child_context(parent_id=operation_id)
415424

@@ -454,7 +463,7 @@ def parallel(
454463
operation_id = self._create_step_id()
455464
parallel_context = self.create_child_context(parent_id=operation_id)
456465
operation_identifier = OperationIdentifier(
457-
operation_id=operation_id, parent_id=self._parent_id, name=name
466+
operation_id=operation_id, parent_id=self._non_virtual_parent_id, name=name
458467
)
459468

460469
def parallel_in_child_context() -> BatchResult[T]:
@@ -515,7 +524,9 @@ def callable_with_child_context():
515524
func=callable_with_child_context,
516525
state=self.state,
517526
operation_identifier=OperationIdentifier(
518-
operation_id=operation_id, parent_id=self._parent_id, name=step_name
527+
operation_id=operation_id,
528+
parent_id=self._non_virtual_parent_id,
529+
name=step_name,
519530
),
520531
config=config,
521532
)
@@ -539,7 +550,7 @@ def step(
539550
state=self.state,
540551
operation_identifier=OperationIdentifier(
541552
operation_id=operation_id,
542-
parent_id=self._parent_id,
553+
parent_id=self._non_virtual_parent_id,
543554
name=step_name,
544555
),
545556
context_logger=self.logger,
@@ -566,7 +577,7 @@ def wait(self, duration: Duration, name: str | None = None) -> None:
566577
state=self.state,
567578
operation_identifier=OperationIdentifier(
568579
operation_id=operation_id,
569-
parent_id=self._parent_id,
580+
parent_id=self._non_virtual_parent_id,
570581
name=name,
571582
),
572583
)
@@ -621,7 +632,7 @@ def wait_for_condition(
621632
state=self.state,
622633
operation_identifier=OperationIdentifier(
623634
operation_id=operation_id,
624-
parent_id=self._parent_id,
635+
parent_id=self._non_virtual_parent_id,
625636
name=name,
626637
),
627638
context_logger=self.logger,

src/aws_durable_execution_sdk_python/operation/child.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
from typing import TYPE_CHECKING, TypeVar
77

8-
from aws_durable_execution_sdk_python.config import ChildConfig, CheckpointMode
8+
from aws_durable_execution_sdk_python.config import ChildConfig
99
from aws_durable_execution_sdk_python.exceptions import (
1010
InvocationError,
1111
SuspendExecution,
@@ -118,11 +118,7 @@ def check_result_status(self) -> CheckResult[T]:
118118
checkpointed_result.raise_callable_error()
119119

120120
# Create START checkpoint if not exists
121-
if (
122-
not checkpointed_result.is_existent()
123-
and self.config.checkpoint_mode
124-
== CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
125-
):
121+
if not checkpointed_result.is_existent() and not self.config.is_virtual:
126122
start_operation: OperationUpdate = OperationUpdate.create_context_start(
127123
identifier=self.operation_identifier,
128124
sub_type=self.sub_type,
@@ -161,6 +157,14 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T:
161157
try:
162158
raw_result: T = self.func()
163159

160+
if self.config.is_virtual:
161+
logger.debug(
162+
"Virtual context: Exiting child context without creating another checkpoint. id: %s, name: %s",
163+
self.operation_identifier.operation_id,
164+
self.operation_identifier.name,
165+
)
166+
return raw_result
167+
164168
# If in replay_children mode, return without checkpointing
165169
if checkpointed_result.is_replay_children():
166170
logger.debug(
@@ -207,21 +211,18 @@ def execute(self, checkpointed_result: CheckpointedResult) -> T:
207211
else ""
208212
)
209213

210-
if self.config.checkpoint_mode != CheckpointMode.NO_CHECKPOINT:
211-
# Checkpoint SUCCEED
212-
success_operation: OperationUpdate = (
213-
OperationUpdate.create_context_succeed(
214-
identifier=self.operation_identifier,
215-
payload=serialized_result,
216-
sub_type=self.sub_type,
217-
context_options=ContextOptions(replay_children=replay_children),
218-
)
219-
)
220-
# Checkpoint child context SUCCEED with blocking (is_sync=True, default).
221-
# Must ensure the child context result is persisted before returning to the parent.
222-
# This guarantees the result is durable and child operations won't be re-executed on replay
223-
# (unless replay_children=True for large payloads).
224-
self.state.create_checkpoint(operation_update=success_operation)
214+
# Checkpoint SUCCEED
215+
success_operation: OperationUpdate = OperationUpdate.create_context_succeed(
216+
identifier=self.operation_identifier,
217+
payload=serialized_result,
218+
sub_type=self.sub_type,
219+
context_options=ContextOptions(replay_children=replay_children),
220+
)
221+
# Checkpoint child context SUCCEED with blocking (is_sync=True, default).
222+
# Must ensure the child context result is persisted before returning to the parent.
223+
# This guarantees the result is durable and child operations won't be re-executed on replay
224+
# (unless replay_children=True for large payloads).
225+
self.state.create_checkpoint(operation_update=success_operation)
225226

226227
logger.debug(
227228
"✅ Successfully completed child context for id: %s, name: %s",

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
class MapExecutor(Generic[T, R], ConcurrentExecutor[Callable, R]): # noqa: PYI059
3737
def __init__(
3838
self,
39+
operation_identifier: OperationIdentifier,
3940
executables: list[Executable[Callable]],
4041
items: Sequence[T],
4142
max_concurrency: int | None,
@@ -49,6 +50,7 @@ def __init__(
4950
nesting_type: NestingType = NestingType.NESTED,
5051
):
5152
super().__init__(
53+
operation_identifier=operation_identifier,
5254
executables=executables,
5355
max_concurrency=max_concurrency,
5456
completion_config=completion_config,
@@ -65,6 +67,7 @@ def __init__(
6567
@classmethod
6668
def from_items(
6769
cls,
70+
operation_identifier: OperationIdentifier,
6871
items: Sequence[T],
6972
func: Callable,
7073
config: MapConfig,
@@ -75,6 +78,7 @@ def from_items(
7578
]
7679

7780
return cls(
81+
operation_identifier=operation_identifier,
7882
executables=executables,
7983
items=items,
8084
max_concurrency=config.max_concurrency,
@@ -112,6 +116,7 @@ def map_handler(
112116
# See TypeScript reference: aws-durable-execution-sdk-js/src/handlers/map-handler/map-handler.ts (~line 79)
113117

114118
executor: MapExecutor[T, R] = MapExecutor.from_items(
119+
operation_identifier=operation_identifier,
115120
items=items,
116121
func=func,
117122
config=config or MapConfig(summary_generator=MapSummaryGenerator()),

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
class ParallelExecutor(ConcurrentExecutor[Callable, R]):
3030
def __init__(
3131
self,
32+
operation_identifier: OperationIdentifier,
3233
executables: list[Executable[Callable]],
3334
max_concurrency: int | None,
3435
completion_config,
@@ -41,6 +42,7 @@ def __init__(
4142
nesting_type: NestingType = NestingType.NESTED,
4243
):
4344
super().__init__(
45+
operation_identifier=operation_identifier,
4446
executables=executables,
4547
max_concurrency=max_concurrency,
4648
completion_config=completion_config,
@@ -56,6 +58,7 @@ def __init__(
5658
@classmethod
5759
def from_callables(
5860
cls,
61+
operation_identifier: OperationIdentifier,
5962
callables: Sequence[Callable],
6063
config: ParallelConfig,
6164
) -> ParallelExecutor:
@@ -64,6 +67,7 @@ def from_callables(
6467
Executable(index=i, func=func) for i, func in enumerate(callables)
6568
]
6669
return cls(
70+
operation_identifier=operation_identifier,
6771
executables=executables,
6872
max_concurrency=config.max_concurrency,
6973
completion_config=config.completion_config,
@@ -98,6 +102,7 @@ def parallel_handler(
98102
# See TypeScript reference: aws-durable-execution-sdk-js/src/handlers/parallel-handler/parallel-handler.ts (~line 112)
99103

100104
executor = ParallelExecutor.from_callables(
105+
operation_identifier,
101106
callables,
102107
config or ParallelConfig(summary_generator=ParallelSummaryGenerator()),
103108
)

0 commit comments

Comments
 (0)