Skip to content

Commit 0c61d98

Browse files
author
Alex Wang
committed
Adding replay
1 parent 44785df commit 0c61d98

5 files changed

Lines changed: 139 additions & 4 deletions

File tree

src/aws_durable_execution_sdk_python/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class ChildConfig:
110110
# checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
111111
serdes: SerDes | None = None
112112
sub_type: OperationSubType | None = None
113+
summary_generator: Callable[[T], str] | None = None
113114

114115

115116
class ItemsPerBatchUnit(Enum):

src/aws_durable_execution_sdk_python/lambda_service.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,11 @@ def create_context_start(
344344

345345
@classmethod
346346
def create_context_succeed(
347-
cls, identifier: OperationIdentifier, payload: str, sub_type: OperationSubType
347+
cls,
348+
identifier: OperationIdentifier,
349+
payload: str,
350+
sub_type: OperationSubType,
351+
context_options: ContextOptions | None = None,
348352
) -> OperationUpdate:
349353
"""Create an instance of OperationUpdate for type: CONTEXT, action: SUCCEED."""
350354
return cls(
@@ -355,6 +359,7 @@ def create_context_succeed(
355359
action=OperationAction.SUCCEED,
356360
name=identifier.name,
357361
payload=payload,
362+
context_options=context_options,
358363
)
359364

360365
@classmethod

src/aws_durable_execution_sdk_python/operation/child.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aws_durable_execution_sdk_python.config import ChildConfig
99
from aws_durable_execution_sdk_python.exceptions import FatalError, SuspendExecution
1010
from aws_durable_execution_sdk_python.lambda_service import (
11+
ContextOptions,
1112
ErrorObject,
1213
OperationSubType,
1314
OperationUpdate,
@@ -24,6 +25,9 @@
2425

2526
T = TypeVar("T")
2627

28+
# Checkpoint size limit (256 KiB). NOTE: compared against len() of the serialized str, i.e. characters, not UTF-8 encoded bytes.
29+
CHECKPOINT_SIZE_LIMIT = 256 * 1024
30+
2731

2832
def child_handler(
2933
func: Callable[[], T],
@@ -40,9 +44,11 @@ def child_handler(
4044
if not config:
4145
config = ChildConfig()
4246

43-
# TODO: ReplayChildren
4447
checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id)
45-
if checkpointed_result.is_succeeded():
48+
if (
49+
checkpointed_result.is_succeeded()
50+
and not checkpointed_result.is_replay_children()
51+
):
4652
logger.debug(
4753
"Child context already completed, skipping execution for id: %s, name: %s",
4854
operation_identifier.operation_id,
@@ -71,17 +77,37 @@ def child_handler(
7177

7278
try:
7379
raw_result: T = func()
80+
if checkpointed_result.is_replay_children():
81+
logger.debug(
82+
"ReplayChildren mode: Re-executing child context due to large payload: id: %s, name: %s",
83+
operation_identifier.operation_id,
84+
operation_identifier.name,
85+
)
86+
return raw_result
7487
serialized_result: str = serialize(
7588
serdes=config.serdes,
7689
value=raw_result,
7790
operation_id=operation_identifier.operation_id,
7891
durable_execution_arn=state.durable_execution_arn,
7992
)
93+
payload_to_checkpoint = serialized_result
94+
replay_children = False
95+
if len(serialized_result) > CHECKPOINT_SIZE_LIMIT:
96+
logger.debug(
97+
"Large payload detected, using ReplayChildren mode: id: %s, name: %s",
98+
operation_identifier.operation_id,
99+
operation_identifier.name,
100+
)
101+
replay_children = True
102+
payload_to_checkpoint = (
103+
config.summary_generator(raw_result) if config.summary_generator else ""
104+
)
80105

81106
success_operation = OperationUpdate.create_context_succeed(
82107
identifier=operation_identifier,
83-
payload=serialized_result,
108+
payload=payload_to_checkpoint,
84109
sub_type=sub_type,
110+
context_options=ContextOptions(replay_children=replay_children),
85111
)
86112
state.create_checkpoint(operation_update=success_operation)
87113

src/aws_durable_execution_sdk_python/state.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,15 @@ def is_timed_out(self) -> bool:
113113
return False
114114
return op.status is OperationStatus.TIMED_OUT
115115

116+
def is_replay_children(self) -> bool:
117+
op = self.operation
118+
if not op:
119+
return False
120+
context_details = op.context_details
121+
if not context_details:
122+
return False
123+
return context_details.replay_children
124+
116125
def raise_callable_error(self) -> None:
117126
if self.error is None:
118127
msg: str = "Attempted to throw exception, but no ErrorObject exists on the Checkpoint Operation."

tests/operation/child_test.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def test_child_handler_not_started(
4141
mock_result.is_succeeded.return_value = False
4242
mock_result.is_failed.return_value = False
4343
mock_result.is_started.return_value = False
44+
mock_result.is_replay_children.return_value = False
4445
mock_state.get_checkpoint_result.return_value = mock_result
4546
mock_callable = Mock(return_value="fresh_result")
4647

@@ -80,6 +81,7 @@ def test_child_handler_already_succeeded():
8081
mock_state.durable_execution_arn = "test_arn"
8182
mock_result = Mock()
8283
mock_result.is_succeeded.return_value = True
84+
mock_result.is_replay_children.return_value = False
8385
mock_result.result = json.dumps("cached_result")
8486
mock_state.get_checkpoint_result.return_value = mock_result
8587
mock_callable = Mock()
@@ -99,6 +101,7 @@ def test_child_handler_already_succeeded_none_result():
99101
mock_state.durable_execution_arn = "test_arn"
100102
mock_result = Mock()
101103
mock_result.is_succeeded.return_value = True
104+
mock_result.is_replay_children.return_value = False
102105
mock_result.result = None
103106
mock_state.get_checkpoint_result.return_value = mock_result
104107
mock_callable = Mock()
@@ -155,6 +158,7 @@ def test_child_handler_already_started(
155158
mock_result.is_succeeded.return_value = False
156159
mock_result.is_failed.return_value = False
157160
mock_result.is_started.return_value = True
161+
mock_result.is_replay_children.return_value = False
158162
mock_state.get_checkpoint_result.return_value = mock_result
159163
mock_callable = Mock(return_value="started_result")
160164

@@ -281,6 +285,7 @@ def test_child_handler_default_serialization():
281285
mock_result.is_succeeded.return_value = False
282286
mock_result.is_failed.return_value = False
283287
mock_result.is_started.return_value = False
288+
mock_result.is_replay_children.return_value = False
284289
mock_state.get_checkpoint_result.return_value = mock_result
285290
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
286291
mock_callable = Mock(return_value=complex_result)
@@ -306,6 +311,7 @@ def test_child_handler_custom_serdes_not_start():
306311
mock_result.is_succeeded.return_value = False
307312
mock_result.is_failed.return_value = False
308313
mock_result.is_started.return_value = False
314+
mock_result.is_replay_children.return_value = False
309315
mock_state.get_checkpoint_result.return_value = mock_result
310316
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
311317
mock_callable = Mock(return_value=complex_result)
@@ -334,6 +340,7 @@ def test_child_handler_custom_serdes_already_succeeded():
334340
mock_result.is_succeeded.return_value = True
335341
mock_result.is_failed.return_value = False
336342
mock_result.is_started.return_value = False
343+
mock_result.is_replay_children.return_value = False
337344
mock_result.result = '{"key": "VALUE", "number": "84", "list": [1, 2, 3]}'
338345
mock_state.get_checkpoint_result.return_value = mock_result
339346
mock_callable = Mock()
@@ -352,3 +359,90 @@ def test_child_handler_custom_serdes_already_succeeded():
352359

353360

354361
# endregion child_handler
362+
363+
364+
# large payload with summary generator
365+
def test_child_handler_large_payload_with_summary_generator():
366+
"""Test child_handler with large payload and summary generator."""
367+
mock_state = Mock(spec=ExecutionState)
368+
mock_state.durable_execution_arn = "test_arn"
369+
mock_result = Mock()
370+
mock_result.is_succeeded.return_value = False
371+
mock_result.is_failed.return_value = False
372+
mock_result.is_started.return_value = False
373+
mock_result.is_replay_children.return_value = False
374+
mock_state.get_checkpoint_result.return_value = mock_result
375+
large_result = "large" * 256 * 1024
376+
mock_callable = Mock(return_value=large_result)
377+
child_config: ChildConfig = ChildConfig(summary_generator=lambda x: "summary")
378+
379+
actual_result = child_handler(
380+
mock_callable,
381+
mock_state,
382+
OperationIdentifier("op9", None, "test_name"),
383+
child_config,
384+
)
385+
386+
assert large_result == actual_result
387+
success_call = mock_state.create_checkpoint.call_args_list[1]
388+
success_operation = success_call[1]["operation_update"]
389+
assert success_operation.context_options.replay_children
390+
expected_checkpoointed_result = "summary"
391+
assert success_operation.payload == expected_checkpoointed_result
392+
393+
394+
# large payload without summary generator
395+
def test_child_handler_large_payload_without_summary_generator():
396+
"""Test child_handler with large payload and no summary generator."""
397+
mock_state = Mock(spec=ExecutionState)
398+
mock_state.durable_execution_arn = "test_arn"
399+
mock_result = Mock()
400+
mock_result.is_succeeded.return_value = False
401+
mock_result.is_failed.return_value = False
402+
mock_result.is_started.return_value = False
403+
mock_result.is_replay_children.return_value = False
404+
mock_state.get_checkpoint_result.return_value = mock_result
405+
large_result = "large" * 256 * 1024
406+
mock_callable = Mock(return_value=large_result)
407+
child_config: ChildConfig = ChildConfig()
408+
409+
actual_result = child_handler(
410+
mock_callable,
411+
mock_state,
412+
OperationIdentifier("op9", None, "test_name"),
413+
child_config,
414+
)
415+
416+
assert large_result == actual_result
417+
success_call = mock_state.create_checkpoint.call_args_list[1]
418+
success_operation = success_call[1]["operation_update"]
419+
assert success_operation.context_options.replay_children
420+
expected_checkpoointed_result = ""
421+
assert success_operation.payload == expected_checkpoointed_result
422+
423+
424+
# in ReplayChildren mode the child function is executed again and no new checkpoint is created
425+
def test_child_handler_replay_children_mode():
426+
"""Test child_handler in ReplayChildren mode."""
427+
mock_state = Mock(spec=ExecutionState)
428+
mock_state.durable_execution_arn = "test_arn"
429+
mock_result = Mock()
430+
mock_result.is_succeeded.return_value = True
431+
mock_result.is_failed.return_value = False
432+
mock_result.is_started.return_value = True
433+
mock_result.is_replay_children.return_value = True
434+
mock_state.get_checkpoint_result.return_value = mock_result
435+
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
436+
mock_callable = Mock(return_value=complex_result)
437+
child_config: ChildConfig = ChildConfig()
438+
439+
actual_result = child_handler(
440+
mock_callable,
441+
mock_state,
442+
OperationIdentifier("op9", None, "test_name"),
443+
child_config,
444+
)
445+
446+
assert actual_result == complex_result
447+
448+
mock_state.create_checkpoint.assert_not_called()

0 commit comments

Comments
 (0)