Skip to content

Commit 103f056

Browse files
committed
test(helloworld): cover transport push delivery
Signed-off-by: Luca Muscariello <muscariello@ieee.org>
1 parent 5e57039 commit 103f056

File tree

3 files changed

+269
-3
lines changed

3 files changed

+269
-3
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/helloworld/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ chrono = { workspace = true }
1919
a2a-client = { workspace = true }
2020
async-trait = { workspace = true }
2121
reqwest = { workspace = true }
22+
serde_json = { workspace = true }

examples/helloworld/tests/transports_e2e.rs

Lines changed: 267 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,23 @@ use a2a_client::jsonrpc::JsonRpcTransport;
1111
use a2a_client::rest::RestTransport;
1212
use a2a_server::jsonrpc::jsonrpc_router;
1313
use a2a_server::rest::rest_router;
14-
use a2a_server::{RequestHandler, ServiceParams, WELL_KNOWN_AGENT_CARD_PATH};
14+
use a2a_server::{
15+
DefaultRequestHandler, ExecutorContext, HttpPushSender, InMemoryPushConfigStore,
16+
InMemoryTaskStore, RequestHandler, ServiceParams,
17+
WELL_KNOWN_AGENT_CARD_PATH,
18+
};
1519
use async_trait::async_trait;
16-
use axum::http::StatusCode;
17-
use axum::routing::get;
20+
use axum::body::Bytes;
21+
use axum::extract::State;
22+
use axum::http::{HeaderMap, StatusCode, header};
23+
use axum::routing::{get, post};
1824
use axum::{Json, Router};
1925
use futures::StreamExt;
2026
use futures::stream::{self, BoxStream};
2127
use reqwest::Client;
2228
use tokio::net::TcpListener;
29+
use tokio::sync::mpsc;
30+
use tokio::time::timeout;
2331

2432
fn sample_message(role: Role, text: &str) -> Message {
2533
Message {
@@ -89,6 +97,15 @@ fn sample_agent_card() -> AgentCard {
8997

9098
struct TestHandler;
9199

100+
struct PushTransportExecutor;
101+
102+
#[derive(Debug)]
103+
struct CapturedPush {
104+
authorization: Option<String>,
105+
notification_token: Option<String>,
106+
event: StreamResponse,
107+
}
108+
92109
#[async_trait]
93110
impl RequestHandler for TestHandler {
94111
async fn send_message(
@@ -253,6 +270,52 @@ impl RequestHandler for TestHandler {
253270
}
254271
}
255272

273+
impl a2a_server::AgentExecutor for PushTransportExecutor {
274+
fn execute(&self, ctx: ExecutorContext) -> BoxStream<'static, Result<StreamResponse, A2AError>> {
275+
let working = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
276+
task_id: ctx.task_id.clone(),
277+
context_id: ctx.context_id.clone(),
278+
status: TaskStatus {
279+
state: TaskState::Working,
280+
message: None,
281+
timestamp: None,
282+
},
283+
metadata: None,
284+
});
285+
let completed = StreamResponse::Task(Task {
286+
id: ctx.task_id,
287+
context_id: ctx.context_id,
288+
status: TaskStatus {
289+
state: TaskState::Completed,
290+
message: ctx.message,
291+
timestamp: None,
292+
},
293+
artifacts: None,
294+
history: ctx.stored_task.and_then(|task| task.history),
295+
metadata: None,
296+
});
297+
298+
Box::pin(stream::iter(vec![Ok(working), Ok(completed)]))
299+
}
300+
301+
fn cancel(&self, ctx: ExecutorContext) -> BoxStream<'static, Result<StreamResponse, A2AError>> {
302+
Box::pin(stream::once(async move {
303+
Ok(StreamResponse::Task(Task {
304+
id: ctx.task_id,
305+
context_id: ctx.context_id,
306+
status: TaskStatus {
307+
state: TaskState::Canceled,
308+
message: None,
309+
timestamp: None,
310+
},
311+
artifacts: None,
312+
history: None,
313+
metadata: None,
314+
}))
315+
}))
316+
}
317+
}
318+
256319
async fn spawn_http_server() -> (String, tokio::task::JoinHandle<()>) {
257320
let handler = Arc::new(TestHandler);
258321
let app = Router::new()
@@ -280,6 +343,76 @@ async fn spawn_http_server() -> (String, tokio::task::JoinHandle<()>) {
280343
(format!("http://{addr}"), handle)
281344
}
282345

346+
async fn spawn_push_http_server() -> (String, tokio::task::JoinHandle<()>) {
347+
let handler = Arc::new(
348+
DefaultRequestHandler::new(PushTransportExecutor, InMemoryTaskStore::new())
349+
.with_push_notifications(
350+
InMemoryPushConfigStore::new(),
351+
HttpPushSender::new(None),
352+
),
353+
);
354+
let app = Router::new()
355+
.nest("/rest", rest_router(handler.clone()))
356+
.nest("/rpc", jsonrpc_router(handler))
357+
.route(
358+
WELL_KNOWN_AGENT_CARD_PATH,
359+
get(|| async { Json(sample_agent_card()) }),
360+
);
361+
362+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
363+
let addr = listener.local_addr().unwrap();
364+
let handle = tokio::spawn(async move {
365+
axum::serve(listener, app).await.unwrap();
366+
});
367+
tokio::time::sleep(Duration::from_millis(20)).await;
368+
(format!("http://{addr}"), handle)
369+
}
370+
371+
async fn capture_push(
372+
State(sender): State<mpsc::UnboundedSender<CapturedPush>>,
373+
headers: HeaderMap,
374+
body: Bytes,
375+
) -> StatusCode {
376+
sender
377+
.send(CapturedPush {
378+
authorization: headers
379+
.get(header::AUTHORIZATION)
380+
.and_then(|value| value.to_str().ok())
381+
.map(ToOwned::to_owned),
382+
notification_token: headers
383+
.get("A2A-Notification-Token")
384+
.and_then(|value| value.to_str().ok())
385+
.map(ToOwned::to_owned),
386+
event: serde_json::from_slice(&body).unwrap(),
387+
})
388+
.unwrap();
389+
StatusCode::ACCEPTED
390+
}
391+
392+
async fn spawn_webhook_server(
393+
) -> (
394+
String,
395+
mpsc::UnboundedReceiver<CapturedPush>,
396+
tokio::task::JoinHandle<()>,
397+
) {
398+
let (sender, receiver) = mpsc::unbounded_channel();
399+
let app = Router::new().route("/", post(capture_push)).with_state(sender);
400+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
401+
let addr = listener.local_addr().unwrap();
402+
let handle = tokio::spawn(async move {
403+
axum::serve(listener, app).await.unwrap();
404+
});
405+
tokio::time::sleep(Duration::from_millis(20)).await;
406+
(format!("http://{addr}/"), receiver, handle)
407+
}
408+
409+
async fn recv_push(receiver: &mut mpsc::UnboundedReceiver<CapturedPush>) -> CapturedPush {
410+
timeout(Duration::from_secs(5), receiver.recv())
411+
.await
412+
.unwrap()
413+
.unwrap()
414+
}
415+
283416
fn send_message_request() -> SendMessageRequest {
284417
SendMessageRequest {
285418
message: sample_message(Role::User, "hello"),
@@ -690,3 +823,134 @@ async fn jsonrpc_transport_end_to_end() {
690823

691824
handle.abort();
692825
}
826+
827+
#[tokio::test]
828+
async fn rest_transport_push_delivery_end_to_end() {
829+
let (base_url, server_handle) = spawn_push_http_server().await;
830+
let (webhook_url, mut receiver, webhook_handle) = spawn_webhook_server().await;
831+
let transport = RestTransport::new(Client::new(), format!("{base_url}/rest"));
832+
833+
transport
834+
.create_push_config(
835+
&ServiceParams::new(),
836+
&CreateTaskPushNotificationConfigRequest {
837+
task_id: "task-rest-push".to_string(),
838+
config: PushNotificationConfig {
839+
url: webhook_url,
840+
id: Some("cfg-rest".to_string()),
841+
token: Some("rest-token".to_string()),
842+
authentication: Some(AuthenticationInfo {
843+
scheme: "Basic".to_string(),
844+
credentials: Some("dGVzdDpzZWNyZXQ=".to_string()),
845+
}),
846+
},
847+
tenant: None,
848+
},
849+
)
850+
.await
851+
.unwrap();
852+
853+
let mut request = send_message_request();
854+
request.message.task_id = Some("task-rest-push".to_string());
855+
request.message.context_id = Some("ctx-rest-push".to_string());
856+
857+
let response = transport
858+
.send_message(&ServiceParams::new(), &request)
859+
.await
860+
.unwrap();
861+
assert!(matches!(response, SendMessageResponse::Task(_)));
862+
863+
let first = recv_push(&mut receiver).await;
864+
assert_eq!(first.authorization.as_deref(), Some("Basic dGVzdDpzZWNyZXQ="));
865+
assert_eq!(first.notification_token.as_deref(), Some("rest-token"));
866+
match first.event {
867+
StreamResponse::StatusUpdate(update) => {
868+
assert_eq!(update.task_id, "task-rest-push");
869+
assert_eq!(update.status.state, TaskState::Working);
870+
}
871+
_ => panic!("expected status update push"),
872+
}
873+
874+
let second = recv_push(&mut receiver).await;
875+
assert_eq!(second.authorization.as_deref(), Some("Basic dGVzdDpzZWNyZXQ="));
876+
assert_eq!(second.notification_token.as_deref(), Some("rest-token"));
877+
match second.event {
878+
StreamResponse::Task(task) => {
879+
assert_eq!(task.id, "task-rest-push");
880+
assert_eq!(task.status.state, TaskState::Completed);
881+
}
882+
_ => panic!("expected final task push"),
883+
}
884+
885+
server_handle.abort();
886+
webhook_handle.abort();
887+
}
888+
889+
#[tokio::test]
890+
async fn jsonrpc_transport_push_delivery_end_to_end() {
891+
let (base_url, server_handle) = spawn_push_http_server().await;
892+
let (webhook_url, mut receiver, webhook_handle) = spawn_webhook_server().await;
893+
let transport = JsonRpcTransport::new(Client::new(), format!("{base_url}/rpc"));
894+
895+
let mut request = send_message_request();
896+
request.message.task_id = Some("task-rpc-push".to_string());
897+
request.message.context_id = Some("ctx-rpc-push".to_string());
898+
request.configuration = Some(SendMessageConfiguration {
899+
accepted_output_modes: None,
900+
push_notification_config: Some(PushNotificationConfig {
901+
url: webhook_url.clone(),
902+
id: Some("cfg-rpc".to_string()),
903+
token: Some("rpc-token".to_string()),
904+
authentication: Some(AuthenticationInfo {
905+
scheme: "Bearer".to_string(),
906+
credentials: Some("rpc-secret".to_string()),
907+
}),
908+
}),
909+
history_length: None,
910+
return_immediately: None,
911+
});
912+
913+
let response = transport
914+
.send_message(&ServiceParams::new(), &request)
915+
.await
916+
.unwrap();
917+
assert!(matches!(response, SendMessageResponse::Task(_)));
918+
919+
let saved = transport
920+
.get_push_config(
921+
&ServiceParams::new(),
922+
&GetTaskPushNotificationConfigRequest {
923+
task_id: "task-rpc-push".to_string(),
924+
id: "cfg-rpc".to_string(),
925+
tenant: None,
926+
},
927+
)
928+
.await
929+
.unwrap();
930+
assert_eq!(saved.config.url, webhook_url);
931+
932+
let first = recv_push(&mut receiver).await;
933+
assert_eq!(first.authorization.as_deref(), Some("Bearer rpc-secret"));
934+
assert_eq!(first.notification_token.as_deref(), Some("rpc-token"));
935+
match first.event {
936+
StreamResponse::StatusUpdate(update) => {
937+
assert_eq!(update.task_id, "task-rpc-push");
938+
assert_eq!(update.status.state, TaskState::Working);
939+
}
940+
_ => panic!("expected status update push"),
941+
}
942+
943+
let second = recv_push(&mut receiver).await;
944+
assert_eq!(second.authorization.as_deref(), Some("Bearer rpc-secret"));
945+
assert_eq!(second.notification_token.as_deref(), Some("rpc-token"));
946+
match second.event {
947+
StreamResponse::Task(task) => {
948+
assert_eq!(task.id, "task-rpc-push");
949+
assert_eq!(task.status.state, TaskState::Completed);
950+
}
951+
_ => panic!("expected final task push"),
952+
}
953+
954+
server_handle.abort();
955+
webhook_handle.abort();
956+
}

0 commit comments

Comments
 (0)