Skip to content

Commit edea867

Browse files
committed
feat: add interactive callback CLI
1 parent 327a4a3 commit edea867

4 files changed

Lines changed: 464 additions & 3 deletions

File tree

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""
2+
Interactive callback handler for durable function executions.
3+
"""
4+
5+
import logging
6+
from typing import Optional
7+
8+
import click
9+
10+
from samcli.lib.clients.lambda_client import DurableFunctionsClient
11+
12+
LOG = logging.getLogger(__name__)
13+
14+
# Menu choice constants
15+
CHOICE_SUCCESS = 1
16+
CHOICE_FAILURE = 2
17+
CHOICE_HEARTBEAT = 3
18+
CHOICE_STOP = 4
19+
20+
21+
class DurableCallbackHandler:
22+
"""
23+
Handles interactive callback detection and response for durable executions.
24+
"""
25+
26+
def __init__(self, client: DurableFunctionsClient):
27+
self.client = client
28+
self._prompted_callbacks: set[str] = set() # Track which callbacks we've already prompted for
29+
30+
def check_for_pending_callbacks(self, execution_arn: str) -> Optional[str]:
31+
"""
32+
Check execution history for pending callbacks.
33+
34+
Returns:
35+
callback_id if found, None otherwise
36+
"""
37+
try:
38+
LOG.debug("Checking for pending callbacks in execution: %s", execution_arn)
39+
history = self.client.get_durable_execution_history(execution_arn)
40+
events = history.get("Events", [])
41+
42+
if events:
43+
callback_states = {}
44+
45+
for event in events:
46+
event_type = event.get("EventType")
47+
event_id = event.get("Id")
48+
49+
if event_type == "CallbackStarted":
50+
callback_id = event.get("CallbackStartedDetails", {}).get("CallbackId")
51+
callback_states[event_id] = {"callback_id": callback_id, "status": "STARTED", "event": event}
52+
elif event_type in ["CallbackCompleted", "CallbackFailed", "CallbackSucceeded"]:
53+
if event_id in callback_states:
54+
callback_states[event_id]["status"] = "COMPLETED"
55+
56+
# Find callbacks that are started but not completed
57+
for callback_id, state in callback_states.items():
58+
if state["status"] == "STARTED" and state["callback_id"]:
59+
return str(state["callback_id"])
60+
61+
except Exception as e:
62+
LOG.error("Failed to check callback history: %s", e)
63+
64+
return None
65+
66+
def prompt_callback_response(self, execution_arn: str, callback_id: str, execution_complete=None) -> bool:
67+
"""
68+
Prompt user for callback response and send it.
69+
70+
Args:
71+
execution_arn: The execution ARN for stop execution operation
72+
callback_id: The callback ID to respond to
73+
execution_complete: Optional threading.Event to check if execution finished
74+
75+
Returns:
76+
True if callback was sent, False if user chose to continue waiting
77+
"""
78+
# Only prompt once per callback ID to avoid blocking on timed-out callbacks
79+
if callback_id in self._prompted_callbacks:
80+
return False
81+
82+
self._prompted_callbacks.add(callback_id)
83+
84+
# Check if execution already completed before prompting
85+
if execution_complete and execution_complete.is_set():
86+
return False
87+
88+
click.echo(f"\n🔄 Execution is waiting for callback: {callback_id}")
89+
click.echo("Choose an action:")
90+
click.echo(" 1. Send callback success")
91+
click.echo(" 2. Send callback failure")
92+
click.echo(" 3. Send callback heartbeat")
93+
click.echo(" 4. Stop execution")
94+
95+
choice = click.prompt("Enter choice", type=click.IntRange(1, 4), default=CHOICE_SUCCESS)
96+
97+
# Check again after user makes selection in case execution completed
98+
if execution_complete and execution_complete.is_set():
99+
click.echo("⚠️ Execution already completed, callback no longer needed")
100+
return False
101+
102+
try:
103+
if choice == CHOICE_SUCCESS:
104+
result = click.prompt("Enter success result (optional)", default="", show_default=False)
105+
self.client.send_callback_success(callback_id=callback_id, result=result)
106+
click.echo("✅ Callback success sent")
107+
return True
108+
109+
elif choice == CHOICE_FAILURE:
110+
error_message = click.prompt("Enter error message", default="User cancelled")
111+
error_type = click.prompt("Enter error type (optional)", default="", show_default=False) or None
112+
113+
self.client.send_callback_failure(
114+
callback_id=callback_id, error_message=error_message, error_type=error_type
115+
)
116+
click.echo("❌ Callback failure sent")
117+
return True
118+
119+
elif choice == CHOICE_HEARTBEAT:
120+
self.client.send_callback_heartbeat(callback_id=callback_id)
121+
click.echo("💓 Callback heartbeat sent")
122+
return False # Continue waiting after heartbeat
123+
124+
else: # CHOICE_STOP
125+
error_message = click.prompt("Enter error message", default="Execution stopped by user")
126+
error_type = click.prompt("Enter error type (optional)", default="", show_default=False) or None
127+
128+
self.client.stop_durable_execution(
129+
durable_execution_arn=execution_arn, error_message=error_message, error_type=error_type
130+
)
131+
click.echo("🛑 Execution stopped")
132+
return True
133+
134+
except Exception as e:
135+
LOG.error("Failed to send callback: %s", e)
136+
click.echo(f"❌ Failed to send callback: {e}")
137+
return False

samcli/local/docker/durable_lambda_container.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import click
1010
from flask import has_request_context
1111

12+
from samcli.lib.utils.durable_callback_handler import DurableCallbackHandler
1213
from samcli.lib.utils.durable_formatters import format_execution_details, format_next_commands_after_invoke
1314
from samcli.local.docker.lambda_container import LambdaContainer
1415

@@ -146,8 +147,11 @@ def _write_execution_result_to_stdout(self, execution_details: dict, stdout):
146147
def _wait_for_execution(self, execution_arn):
147148
"""Poll the execution status until completion and return the final result."""
148149

149-
# TODO - poll until the execution timeout is hit
150+
callback_handler = DurableCallbackHandler(self.emulator_container.lambda_client)
150151
execution_details = None
152+
callback_thread = None
153+
stop_callback_prompts = threading.Event()
154+
151155
try:
152156
while True:
153157
try:
@@ -156,13 +160,36 @@ def _wait_for_execution(self, execution_arn):
156160
status = execution_details.get("Status")
157161

158162
if status != "RUNNING":
163+
stop_callback_prompts.set() # Signal callback thread to stop
164+
if callback_thread and callback_thread.is_alive():
165+
callback_thread.join(timeout=0.5) # Brief wait for thread cleanup
159166
return execution_details
160167

168+
# Check for pending callbacks (only in CLI context)
169+
if self._is_cli_context():
170+
callback_id = callback_handler.check_for_pending_callbacks(execution_arn)
171+
if callback_id:
172+
173+
def _prompt_in_thread():
174+
if not stop_callback_prompts.is_set():
175+
# give the function logs time to settle after the invocation is suspended
176+
time.sleep(0.5)
177+
callback_sent = callback_handler.prompt_callback_response(
178+
execution_arn, callback_id, stop_callback_prompts
179+
)
180+
if callback_sent:
181+
click.echo("\n" + "─" * 80)
182+
183+
# Start callback prompt in separate thread so it doesn't block polling
184+
callback_thread = threading.Thread(target=_prompt_in_thread, daemon=True)
185+
callback_thread.start()
186+
161187
time.sleep(1) # Poll every second
162188
except Exception as e:
163189
LOG.error("Error polling execution status: %s", e)
164190
break
165191
finally:
192+
stop_callback_prompts.set() # Ensure callback thread knows to stop
166193
self._cleanup_if_needed()
167194

168195
return execution_details
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
"""
2+
Unit tests for DurableCallbackHandler
3+
"""
4+
5+
import threading
6+
from unittest import TestCase
7+
from unittest.mock import Mock, patch
8+
9+
from parameterized import parameterized
10+
11+
from samcli.lib.utils.durable_callback_handler import (
12+
DurableCallbackHandler,
13+
CHOICE_SUCCESS,
14+
CHOICE_FAILURE,
15+
CHOICE_HEARTBEAT,
16+
CHOICE_STOP,
17+
)
18+
19+
20+
class TestDurableCallbackHandler(TestCase):
21+
def setUp(self):
22+
self.mock_client = Mock()
23+
self.handler = DurableCallbackHandler(self.mock_client)
24+
25+
def test_init(self):
26+
"""Test handler initializes with client and empty prompted callbacks set"""
27+
self.assertEqual(self.handler.client, self.mock_client)
28+
self.assertEqual(self.handler._prompted_callbacks, set())
29+
30+
@parameterized.expand(
31+
[
32+
(
33+
"with_pending_callback",
34+
{
35+
"Events": [
36+
{
37+
"Id": 1,
38+
"EventType": "CallbackStarted",
39+
"CallbackStartedDetails": {"CallbackId": "callback-123"},
40+
},
41+
{"Id": 2, "EventType": "StepStarted"},
42+
]
43+
},
44+
"callback-123",
45+
),
46+
(
47+
"with_completed_callback",
48+
{
49+
"Events": [
50+
{
51+
"Id": 1,
52+
"EventType": "CallbackStarted",
53+
"CallbackStartedDetails": {"CallbackId": "callback-123"},
54+
},
55+
{"Id": 1, "EventType": "CallbackCompleted"},
56+
]
57+
},
58+
None,
59+
),
60+
(
61+
"no_callbacks",
62+
{"Events": [{"Id": 1, "EventType": "StepStarted"}]},
63+
None,
64+
),
65+
]
66+
)
67+
def test_check_for_pending_callbacks(self, name, history_response, expected_callback_id):
68+
"""Test checking for pending callbacks"""
69+
self.mock_client.get_durable_execution_history.return_value = history_response
70+
71+
callback_id = self.handler.check_for_pending_callbacks("test-arn")
72+
73+
self.assertEqual(callback_id, expected_callback_id)
74+
self.mock_client.get_durable_execution_history.assert_called_once_with("test-arn")
75+
76+
def test_check_for_pending_callbacks_handles_exception(self):
77+
"""Test that exceptions during callback check are handled gracefully"""
78+
self.mock_client.get_durable_execution_history.side_effect = Exception("API error")
79+
80+
callback_id = self.handler.check_for_pending_callbacks("test-arn")
81+
82+
self.assertIsNone(callback_id)
83+
84+
@parameterized.expand(
85+
[
86+
(
87+
"success",
88+
CHOICE_SUCCESS,
89+
["test result"],
90+
"send_callback_success",
91+
{"callback_id": "callback-123", "result": "test result"},
92+
True,
93+
),
94+
(
95+
"failure",
96+
CHOICE_FAILURE,
97+
["Error occurred", "CustomError"],
98+
"send_callback_failure",
99+
{"callback_id": "callback-123", "error_message": "Error occurred", "error_type": "CustomError"},
100+
True,
101+
),
102+
(
103+
"heartbeat",
104+
CHOICE_HEARTBEAT,
105+
[],
106+
"send_callback_heartbeat",
107+
{"callback_id": "callback-123"},
108+
False,
109+
),
110+
(
111+
"stop_execution",
112+
CHOICE_STOP,
113+
["Execution stopped by user", "StopError"],
114+
"stop_durable_execution",
115+
{
116+
"durable_execution_arn": "test-arn",
117+
"error_message": "Execution stopped by user",
118+
"error_type": "StopError",
119+
},
120+
True,
121+
),
122+
]
123+
)
124+
@patch("samcli.lib.utils.durable_callback_handler.click.prompt")
125+
@patch("samcli.lib.utils.durable_callback_handler.click.echo")
126+
def test_prompt_callback_response(
127+
self,
128+
name,
129+
choice,
130+
prompt_responses,
131+
method_name,
132+
expected_call_args,
133+
expected_result,
134+
mock_echo,
135+
mock_prompt,
136+
):
137+
"""Test prompting for different callback response types"""
138+
mock_prompt.side_effect = [choice] + prompt_responses
139+
140+
result = self.handler.prompt_callback_response("test-arn", "callback-123")
141+
142+
self.assertEqual(result, expected_result)
143+
lambda_client_api_call = getattr(self.mock_client, method_name)
144+
lambda_client_api_call.assert_called_once_with(**expected_call_args)
145+
self.assertIn("callback-123", self.handler._prompted_callbacks)
146+
147+
def test_prompt_callback_response_only_prompts_once(self):
148+
"""Test that callback is only prompted once per ID"""
149+
self.handler._prompted_callbacks.add("callback-123")
150+
151+
result = self.handler.prompt_callback_response("test-arn", "callback-123")
152+
153+
self.assertFalse(result)
154+
self.mock_client.send_callback_success.assert_not_called()
155+
156+
@parameterized.expand(
157+
[
158+
("before_prompt", True, False),
159+
("after_selection", False, True),
160+
]
161+
)
162+
@patch("samcli.lib.utils.durable_callback_handler.click.prompt")
163+
@patch("samcli.lib.utils.durable_callback_handler.click.echo")
164+
def test_prompt_callback_response_checks_execution_complete(
165+
self, name, set_before_prompt, set_during_prompt, mock_echo, mock_prompt
166+
):
167+
"""Test that prompt respects execution_complete event"""
168+
execution_complete = threading.Event()
169+
170+
if set_before_prompt:
171+
execution_complete.set()
172+
elif set_during_prompt:
173+
174+
def prompt_side_effect(*args, **kwargs):
175+
execution_complete.set()
176+
return CHOICE_SUCCESS
177+
178+
mock_prompt.side_effect = prompt_side_effect
179+
180+
result = self.handler.prompt_callback_response("test-arn", "callback-123", execution_complete)
181+
182+
self.assertFalse(result)
183+
self.mock_client.send_callback_success.assert_not_called()
184+
if set_before_prompt:
185+
mock_prompt.assert_not_called()
186+
187+
@patch("samcli.lib.utils.durable_callback_handler.click.prompt")
188+
@patch("samcli.lib.utils.durable_callback_handler.click.echo")
189+
def test_prompt_callback_response_handles_exception(self, mock_echo, mock_prompt):
190+
"""Test that exceptions during callback send are handled gracefully"""
191+
mock_prompt.return_value = CHOICE_HEARTBEAT
192+
193+
self.mock_client.send_callback_heartbeat.side_effect = Exception("API error")
194+
195+
result = self.handler.prompt_callback_response("test-arn", "callback-123")
196+
197+
self.assertFalse(result)

0 commit comments

Comments
 (0)