Skip to content

Commit af0522b

Browse files
author
Alex Wang
committed
chore: use updated sdk model
1 parent f471ed3 commit af0522b

5 files changed

Lines changed: 95 additions & 26 deletions

File tree

src/aws_durable_execution_sdk_python/lambda_service.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -780,13 +780,18 @@ class DurableServiceClient(Protocol):
780780

781781
def checkpoint(
782782
self,
783+
durable_execution_arn: str,
783784
checkpoint_token: str,
784785
updates: list[OperationUpdate],
785786
client_token: str | None,
786787
) -> CheckpointOutput: ... # pragma: no cover
787788

788789
def get_execution_state(
789-
self, checkpoint_token: str, next_marker: str, max_items: int = 1000
790+
self,
791+
durable_execution_arn: str,
792+
checkpoint_token: str,
793+
next_marker: str,
794+
max_items: int = 1000,
790795
) -> StateOutput: ... # pragma: no cover
791796

792797
def stop(
@@ -866,12 +871,14 @@ def initialize_from_env() -> LambdaClient:
866871

867872
def checkpoint(
868873
self,
874+
durable_execution_arn: str,
869875
checkpoint_token: str,
870876
updates: list[OperationUpdate],
871877
client_token: str | None,
872878
) -> CheckpointOutput:
873879
try:
874880
params = {
881+
"DurableExecutionArn": durable_execution_arn,
875882
"CheckpointToken": checkpoint_token,
876883
"Updates": [o.to_dict() for o in updates],
877884
}
@@ -888,10 +895,17 @@ def checkpoint(
888895
raise CheckpointError(e) from e
889896

890897
def get_execution_state(
891-
self, checkpoint_token: str, next_marker: str, max_items: int = 1000
898+
self,
899+
durable_execution_arn: str,
900+
checkpoint_token: str,
901+
next_marker: str,
902+
max_items: int = 1000,
892903
) -> StateOutput:
893904
result: MutableMapping[str, Any] = self.client.get_durable_execution_state(
894-
CheckpointToken=checkpoint_token, Marker=next_marker, MaxItems=max_items
905+
DurableExecutionArn=durable_execution_arn,
906+
CheckpointToken=checkpoint_token,
907+
Marker=next_marker,
908+
MaxItems=max_items,
895909
)
896910
return StateOutput.from_dict(result)
897911

src/aws_durable_execution_sdk_python/state.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def fetch_paginated_operations(
162162
)
163163
while next_marker:
164164
output: StateOutput = self._service_client.get_execution_state(
165+
durable_execution_arn=self.durable_execution_arn,
165166
checkpoint_token=checkpoint_token,
166167
next_marker=next_marker,
167168
)
@@ -227,6 +228,7 @@ def create_checkpoint(
227228
[operation_update] if operation_update is not None else []
228229
)
229230
output: CheckpointOutput = self._service_client.checkpoint(
231+
durable_execution_arn=self.durable_execution_arn,
230232
checkpoint_token=self._current_checkpoint_token,
231233
updates=updates,
232234
client_token=None,

tests/e2e/execution_int_test.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,12 @@ def my_handler(event, context: DurableContext) -> list[str]:
6565
# Mock the checkpoint method to track calls
6666
checkpoint_calls = []
6767

68-
def mock_checkpoint(checkpoint_token, updates, client_token="token"): # noqa: S107
68+
def mock_checkpoint(
69+
durable_execution_arn,
70+
checkpoint_token,
71+
updates,
72+
client_token="token", # noqa: S107
73+
):
6974
checkpoint_calls.append(updates)
7075

7176
return CheckpointOutput(
@@ -142,7 +147,12 @@ def my_handler(event, context: DurableContext):
142147
# Mock the checkpoint method to track calls
143148
checkpoint_calls = []
144149

145-
def mock_checkpoint(checkpoint_token, updates, client_token="token"): # noqa: S107
150+
def mock_checkpoint(
151+
durable_execution_arn,
152+
checkpoint_token,
153+
updates,
154+
client_token="token", # noqa: S107
155+
):
146156
checkpoint_calls.append(updates)
147157

148158
return CheckpointOutput(
@@ -224,7 +234,12 @@ def my_handler(event, context):
224234
# Mock the checkpoint method to track calls
225235
checkpoint_calls = []
226236

227-
def mock_checkpoint(checkpoint_token, updates, client_token="token"): # noqa: S107
237+
def mock_checkpoint(
238+
durable_execution_arn,
239+
checkpoint_token,
240+
updates,
241+
client_token="token", # noqa: S107
242+
):
228243
checkpoint_calls.append(updates)
229244

230245
return CheckpointOutput(
@@ -310,7 +325,12 @@ def my_handler(event: Any, context: DurableContext):
310325
# Mock the checkpoint method to track calls
311326
checkpoint_calls = []
312327

313-
def mock_checkpoint(checkpoint_token, updates, client_token="token"): # noqa: S107
328+
def mock_checkpoint(
329+
durable_execution_arn,
330+
checkpoint_token,
331+
updates,
332+
client_token="token", # noqa: S107
333+
):
314334
checkpoint_calls.append(updates)
315335

316336
return CheckpointOutput(

tests/lambda_service_test.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -838,10 +838,12 @@ def test_lambda_client_checkpoint():
838838
action=OperationAction.START,
839839
)
840840

841-
result = lambda_client.checkpoint("token123", [update], None)
841+
result = lambda_client.checkpoint("arn123", "token123", [update], None)
842842

843843
mock_client.checkpoint_durable_execution.assert_called_once_with(
844-
CheckpointToken="token123", Updates=[update.to_dict()]
844+
DurableExecutionArn="arn123",
845+
CheckpointToken="token123",
846+
Updates=[update.to_dict()],
845847
)
846848
assert isinstance(result, CheckpointOutput)
847849
assert result.checkpoint_token == "new_token" # noqa: S105
@@ -862,9 +864,12 @@ def test_lambda_client_checkpoint_with_client_token():
862864
action=OperationAction.START,
863865
)
864866

865-
result = lambda_client.checkpoint("token123", [update], "client-token-123")
867+
result = lambda_client.checkpoint(
868+
"arn123", "token123", [update], "client-token-123"
869+
)
866870

867871
mock_client.checkpoint_durable_execution.assert_called_once_with(
872+
DurableExecutionArn="arn123",
868873
CheckpointToken="token123",
869874
Updates=[update.to_dict()],
870875
ClientToken="client-token-123",
@@ -888,10 +893,12 @@ def test_lambda_client_checkpoint_with_explicit_none_client_token():
888893
action=OperationAction.START,
889894
)
890895

891-
result = lambda_client.checkpoint("token123", [update], None)
896+
result = lambda_client.checkpoint("arn123", "token123", [update], None)
892897

893898
mock_client.checkpoint_durable_execution.assert_called_once_with(
894-
CheckpointToken="token123", Updates=[update.to_dict()]
899+
DurableExecutionArn="arn123",
900+
CheckpointToken="token123",
901+
Updates=[update.to_dict()],
895902
)
896903
assert isinstance(result, CheckpointOutput)
897904
assert result.checkpoint_token == "new_token" # noqa: S105
@@ -912,10 +919,13 @@ def test_lambda_client_checkpoint_with_empty_string_client_token():
912919
action=OperationAction.START,
913920
)
914921

915-
result = lambda_client.checkpoint("token123", [update], "")
922+
result = lambda_client.checkpoint("arn123", "token123", [update], "")
916923

917924
mock_client.checkpoint_durable_execution.assert_called_once_with(
918-
CheckpointToken="token123", Updates=[update.to_dict()], ClientToken=""
925+
DurableExecutionArn="arn123",
926+
CheckpointToken="token123",
927+
Updates=[update.to_dict()],
928+
ClientToken="",
919929
)
920930
assert isinstance(result, CheckpointOutput)
921931
assert result.checkpoint_token == "new_token" # noqa: S105
@@ -936,9 +946,10 @@ def test_lambda_client_checkpoint_with_string_value_client_token():
936946
action=OperationAction.START,
937947
)
938948

939-
result = lambda_client.checkpoint("token123", [update], "my-client-token")
949+
result = lambda_client.checkpoint("arn123", "token123", [update], "my-client-token")
940950

941951
mock_client.checkpoint_durable_execution.assert_called_once_with(
952+
DurableExecutionArn="arn123",
942953
CheckpointToken="token123",
943954
Updates=[update.to_dict()],
944955
ClientToken="my-client-token",
@@ -960,7 +971,7 @@ def test_lambda_client_checkpoint_with_exception():
960971
)
961972

962973
with pytest.raises(CheckpointError):
963-
lambda_client.checkpoint("token123", [update], None)
974+
lambda_client.checkpoint("arn123", "token123", [update], None)
964975

965976

966977
def test_lambda_client_get_execution_state():
@@ -971,10 +982,13 @@ def test_lambda_client_get_execution_state():
971982
}
972983

973984
lambda_client = LambdaClient(mock_client)
974-
result = lambda_client.get_execution_state("token123", "marker", 500)
985+
result = lambda_client.get_execution_state("arn123", "token123", "marker", 500)
975986

976987
mock_client.get_durable_execution_state.assert_called_once_with(
977-
CheckpointToken="token123", Marker="marker", MaxItems=500
988+
DurableExecutionArn="arn123",
989+
CheckpointToken="token123",
990+
Marker="marker",
991+
MaxItems=500,
978992
)
979993
assert len(result.operations) == 1
980994

@@ -1018,9 +1032,11 @@ def test_durable_service_client_protocol_checkpoint():
10181032
)
10191033
]
10201034

1021-
result = mock_client.checkpoint("token", updates, "client_token")
1035+
result = mock_client.checkpoint("arn123", "token", updates, "client_token")
10221036

1023-
mock_client.checkpoint.assert_called_once_with("token", updates, "client_token")
1037+
mock_client.checkpoint.assert_called_once_with(
1038+
"arn123", "token", updates, "client_token"
1039+
)
10241040
assert result == mock_output
10251041

10261042

@@ -1030,9 +1046,11 @@ def test_durable_service_client_protocol_get_execution_state():
10301046
mock_output = StateOutput(operations=[], next_marker="marker")
10311047
mock_client.get_execution_state.return_value = mock_output
10321048

1033-
result = mock_client.get_execution_state("token", "marker", 1000)
1049+
result = mock_client.get_execution_state("arn123", "token", "marker", 1000)
10341050

1035-
mock_client.get_execution_state.assert_called_once_with("token", "marker", 1000)
1051+
mock_client.get_execution_state.assert_called_once_with(
1052+
"arn123", "token", "marker", 1000
1053+
)
10361054
assert result == mock_output
10371055

10381056

tests/state_test.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def test_create_checkpoint():
384384

385385
# Verify the checkpoint was called
386386
mock_lambda_client.checkpoint.assert_called_once_with(
387+
durable_execution_arn="test_arn",
387388
checkpoint_token="token123", # noqa: S106
388389
updates=[operation_update],
389390
client_token=None,
@@ -416,6 +417,7 @@ def test_create_checkpoint_with_none():
416417

417418
# Verify the checkpoint was called with empty updates
418419
mock_lambda_client.checkpoint.assert_called_once_with(
420+
durable_execution_arn="test_arn",
419421
checkpoint_token="token123", # noqa: S106
420422
updates=[],
421423
client_token=None,
@@ -444,6 +446,7 @@ def test_create_checkpoint_with_no_args():
444446

445447
# Verify the checkpoint was called with empty updates
446448
mock_lambda_client.checkpoint.assert_called_once_with(
449+
durable_execution_arn="test_arn",
447450
checkpoint_token="token123", # noqa: S106
448451
updates=[],
449452
client_token=None,
@@ -514,7 +517,7 @@ def test_checkpointed_result_is_timed_out_false_for_other_statuses():
514517
def test_fetch_paginated_operations_with_marker():
515518
mock_lambda_client = Mock(spec=LambdaClient)
516519

517-
def mock_get_execution_state(checkpoint_token, next_marker):
520+
def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marker):
518521
resp = {
519522
"marker1": StateOutput(
520523
operations=[
@@ -573,9 +576,21 @@ def mock_get_execution_state(checkpoint_token, next_marker):
573576
assert mock_lambda_client.get_execution_state.call_count == 3
574577
mock_lambda_client.get_execution_state.assert_has_calls(
575578
[
576-
call(checkpoint_token="test_token", next_marker="marker1"), # noqa: S106
577-
call(checkpoint_token="test_token", next_marker="marker2"), # noqa: S106
578-
call(checkpoint_token="test_token", next_marker="marker3"), # noqa: S106
579+
call(
580+
durable_execution_arn="test_arn",
581+
checkpoint_token="test_token", # noqa: S106
582+
next_marker="marker1",
583+
),
584+
call(
585+
durable_execution_arn="test_arn",
586+
checkpoint_token="test_token", # noqa: S106
587+
next_marker="marker2",
588+
),
589+
call(
590+
durable_execution_arn="test_arn",
591+
checkpoint_token="test_token", # noqa: S106
592+
next_marker="marker3",
593+
),
579594
]
580595
)
581596

0 commit comments

Comments
 (0)