@@ -11,15 +11,23 @@ use a2a_client::jsonrpc::JsonRpcTransport;
1111use a2a_client:: rest:: RestTransport ;
1212use a2a_server:: jsonrpc:: jsonrpc_router;
1313use a2a_server:: rest:: rest_router;
14- use a2a_server:: { RequestHandler , ServiceParams , WELL_KNOWN_AGENT_CARD_PATH } ;
14+ use a2a_server:: {
15+ DefaultRequestHandler , ExecutorContext , HttpPushSender , InMemoryPushConfigStore ,
16+ InMemoryTaskStore , RequestHandler , ServiceParams ,
17+ WELL_KNOWN_AGENT_CARD_PATH ,
18+ } ;
1519use async_trait:: async_trait;
16- use axum:: http:: StatusCode ;
17- use axum:: routing:: get;
20+ use axum:: body:: Bytes ;
21+ use axum:: extract:: State ;
22+ use axum:: http:: { HeaderMap , StatusCode , header} ;
23+ use axum:: routing:: { get, post} ;
1824use axum:: { Json , Router } ;
1925use futures:: StreamExt ;
2026use futures:: stream:: { self , BoxStream } ;
2127use reqwest:: Client ;
2228use tokio:: net:: TcpListener ;
29+ use tokio:: sync:: mpsc;
30+ use tokio:: time:: timeout;
2331
2432fn sample_message ( role : Role , text : & str ) -> Message {
2533 Message {
@@ -89,6 +97,15 @@ fn sample_agent_card() -> AgentCard {
8997
9098struct TestHandler ;
9199
100+ struct PushTransportExecutor ;
101+
102+ #[ derive( Debug ) ]
103+ struct CapturedPush {
104+ authorization : Option < String > ,
105+ notification_token : Option < String > ,
106+ event : StreamResponse ,
107+ }
108+
92109#[ async_trait]
93110impl RequestHandler for TestHandler {
94111 async fn send_message (
@@ -253,6 +270,52 @@ impl RequestHandler for TestHandler {
253270 }
254271}
255272
273+ impl a2a_server:: AgentExecutor for PushTransportExecutor {
274+ fn execute ( & self , ctx : ExecutorContext ) -> BoxStream < ' static , Result < StreamResponse , A2AError > > {
275+ let working = StreamResponse :: StatusUpdate ( TaskStatusUpdateEvent {
276+ task_id : ctx. task_id . clone ( ) ,
277+ context_id : ctx. context_id . clone ( ) ,
278+ status : TaskStatus {
279+ state : TaskState :: Working ,
280+ message : None ,
281+ timestamp : None ,
282+ } ,
283+ metadata : None ,
284+ } ) ;
285+ let completed = StreamResponse :: Task ( Task {
286+ id : ctx. task_id ,
287+ context_id : ctx. context_id ,
288+ status : TaskStatus {
289+ state : TaskState :: Completed ,
290+ message : ctx. message ,
291+ timestamp : None ,
292+ } ,
293+ artifacts : None ,
294+ history : ctx. stored_task . and_then ( |task| task. history ) ,
295+ metadata : None ,
296+ } ) ;
297+
298+ Box :: pin ( stream:: iter ( vec ! [ Ok ( working) , Ok ( completed) ] ) )
299+ }
300+
301+ fn cancel ( & self , ctx : ExecutorContext ) -> BoxStream < ' static , Result < StreamResponse , A2AError > > {
302+ Box :: pin ( stream:: once ( async move {
303+ Ok ( StreamResponse :: Task ( Task {
304+ id : ctx. task_id ,
305+ context_id : ctx. context_id ,
306+ status : TaskStatus {
307+ state : TaskState :: Canceled ,
308+ message : None ,
309+ timestamp : None ,
310+ } ,
311+ artifacts : None ,
312+ history : None ,
313+ metadata : None ,
314+ } ) )
315+ } ) )
316+ }
317+ }
318+
256319async fn spawn_http_server ( ) -> ( String , tokio:: task:: JoinHandle < ( ) > ) {
257320 let handler = Arc :: new ( TestHandler ) ;
258321 let app = Router :: new ( )
@@ -280,6 +343,76 @@ async fn spawn_http_server() -> (String, tokio::task::JoinHandle<()>) {
280343 ( format ! ( "http://{addr}" ) , handle)
281344}
282345
346+ async fn spawn_push_http_server ( ) -> ( String , tokio:: task:: JoinHandle < ( ) > ) {
347+ let handler = Arc :: new (
348+ DefaultRequestHandler :: new ( PushTransportExecutor , InMemoryTaskStore :: new ( ) )
349+ . with_push_notifications (
350+ InMemoryPushConfigStore :: new ( ) ,
351+ HttpPushSender :: new ( None ) ,
352+ ) ,
353+ ) ;
354+ let app = Router :: new ( )
355+ . nest ( "/rest" , rest_router ( handler. clone ( ) ) )
356+ . nest ( "/rpc" , jsonrpc_router ( handler) )
357+ . route (
358+ WELL_KNOWN_AGENT_CARD_PATH ,
359+ get ( || async { Json ( sample_agent_card ( ) ) } ) ,
360+ ) ;
361+
362+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
363+ let addr = listener. local_addr ( ) . unwrap ( ) ;
364+ let handle = tokio:: spawn ( async move {
365+ axum:: serve ( listener, app) . await . unwrap ( ) ;
366+ } ) ;
367+ tokio:: time:: sleep ( Duration :: from_millis ( 20 ) ) . await ;
368+ ( format ! ( "http://{addr}" ) , handle)
369+ }
370+
371+ async fn capture_push (
372+ State ( sender) : State < mpsc:: UnboundedSender < CapturedPush > > ,
373+ headers : HeaderMap ,
374+ body : Bytes ,
375+ ) -> StatusCode {
376+ sender
377+ . send ( CapturedPush {
378+ authorization : headers
379+ . get ( header:: AUTHORIZATION )
380+ . and_then ( |value| value. to_str ( ) . ok ( ) )
381+ . map ( ToOwned :: to_owned) ,
382+ notification_token : headers
383+ . get ( "A2A-Notification-Token" )
384+ . and_then ( |value| value. to_str ( ) . ok ( ) )
385+ . map ( ToOwned :: to_owned) ,
386+ event : serde_json:: from_slice ( & body) . unwrap ( ) ,
387+ } )
388+ . unwrap ( ) ;
389+ StatusCode :: ACCEPTED
390+ }
391+
392+ async fn spawn_webhook_server (
393+ ) -> (
394+ String ,
395+ mpsc:: UnboundedReceiver < CapturedPush > ,
396+ tokio:: task:: JoinHandle < ( ) > ,
397+ ) {
398+ let ( sender, receiver) = mpsc:: unbounded_channel ( ) ;
399+ let app = Router :: new ( ) . route ( "/" , post ( capture_push) ) . with_state ( sender) ;
400+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
401+ let addr = listener. local_addr ( ) . unwrap ( ) ;
402+ let handle = tokio:: spawn ( async move {
403+ axum:: serve ( listener, app) . await . unwrap ( ) ;
404+ } ) ;
405+ tokio:: time:: sleep ( Duration :: from_millis ( 20 ) ) . await ;
406+ ( format ! ( "http://{addr}/" ) , receiver, handle)
407+ }
408+
409+ async fn recv_push ( receiver : & mut mpsc:: UnboundedReceiver < CapturedPush > ) -> CapturedPush {
410+ timeout ( Duration :: from_secs ( 5 ) , receiver. recv ( ) )
411+ . await
412+ . unwrap ( )
413+ . unwrap ( )
414+ }
415+
283416fn send_message_request ( ) -> SendMessageRequest {
284417 SendMessageRequest {
285418 message : sample_message ( Role :: User , "hello" ) ,
@@ -690,3 +823,134 @@ async fn jsonrpc_transport_end_to_end() {
690823
691824 handle. abort ( ) ;
692825}
826+
827+ #[ tokio:: test]
828+ async fn rest_transport_push_delivery_end_to_end ( ) {
829+ let ( base_url, server_handle) = spawn_push_http_server ( ) . await ;
830+ let ( webhook_url, mut receiver, webhook_handle) = spawn_webhook_server ( ) . await ;
831+ let transport = RestTransport :: new ( Client :: new ( ) , format ! ( "{base_url}/rest" ) ) ;
832+
833+ transport
834+ . create_push_config (
835+ & ServiceParams :: new ( ) ,
836+ & CreateTaskPushNotificationConfigRequest {
837+ task_id : "task-rest-push" . to_string ( ) ,
838+ config : PushNotificationConfig {
839+ url : webhook_url,
840+ id : Some ( "cfg-rest" . to_string ( ) ) ,
841+ token : Some ( "rest-token" . to_string ( ) ) ,
842+ authentication : Some ( AuthenticationInfo {
843+ scheme : "Basic" . to_string ( ) ,
844+ credentials : Some ( "dGVzdDpzZWNyZXQ=" . to_string ( ) ) ,
845+ } ) ,
846+ } ,
847+ tenant : None ,
848+ } ,
849+ )
850+ . await
851+ . unwrap ( ) ;
852+
853+ let mut request = send_message_request ( ) ;
854+ request. message . task_id = Some ( "task-rest-push" . to_string ( ) ) ;
855+ request. message . context_id = Some ( "ctx-rest-push" . to_string ( ) ) ;
856+
857+ let response = transport
858+ . send_message ( & ServiceParams :: new ( ) , & request)
859+ . await
860+ . unwrap ( ) ;
861+ assert ! ( matches!( response, SendMessageResponse :: Task ( _) ) ) ;
862+
863+ let first = recv_push ( & mut receiver) . await ;
864+ assert_eq ! ( first. authorization. as_deref( ) , Some ( "Basic dGVzdDpzZWNyZXQ=" ) ) ;
865+ assert_eq ! ( first. notification_token. as_deref( ) , Some ( "rest-token" ) ) ;
866+ match first. event {
867+ StreamResponse :: StatusUpdate ( update) => {
868+ assert_eq ! ( update. task_id, "task-rest-push" ) ;
869+ assert_eq ! ( update. status. state, TaskState :: Working ) ;
870+ }
871+ _ => panic ! ( "expected status update push" ) ,
872+ }
873+
874+ let second = recv_push ( & mut receiver) . await ;
875+ assert_eq ! ( second. authorization. as_deref( ) , Some ( "Basic dGVzdDpzZWNyZXQ=" ) ) ;
876+ assert_eq ! ( second. notification_token. as_deref( ) , Some ( "rest-token" ) ) ;
877+ match second. event {
878+ StreamResponse :: Task ( task) => {
879+ assert_eq ! ( task. id, "task-rest-push" ) ;
880+ assert_eq ! ( task. status. state, TaskState :: Completed ) ;
881+ }
882+ _ => panic ! ( "expected final task push" ) ,
883+ }
884+
885+ server_handle. abort ( ) ;
886+ webhook_handle. abort ( ) ;
887+ }
888+
889+ #[ tokio:: test]
890+ async fn jsonrpc_transport_push_delivery_end_to_end ( ) {
891+ let ( base_url, server_handle) = spawn_push_http_server ( ) . await ;
892+ let ( webhook_url, mut receiver, webhook_handle) = spawn_webhook_server ( ) . await ;
893+ let transport = JsonRpcTransport :: new ( Client :: new ( ) , format ! ( "{base_url}/rpc" ) ) ;
894+
895+ let mut request = send_message_request ( ) ;
896+ request. message . task_id = Some ( "task-rpc-push" . to_string ( ) ) ;
897+ request. message . context_id = Some ( "ctx-rpc-push" . to_string ( ) ) ;
898+ request. configuration = Some ( SendMessageConfiguration {
899+ accepted_output_modes : None ,
900+ push_notification_config : Some ( PushNotificationConfig {
901+ url : webhook_url. clone ( ) ,
902+ id : Some ( "cfg-rpc" . to_string ( ) ) ,
903+ token : Some ( "rpc-token" . to_string ( ) ) ,
904+ authentication : Some ( AuthenticationInfo {
905+ scheme : "Bearer" . to_string ( ) ,
906+ credentials : Some ( "rpc-secret" . to_string ( ) ) ,
907+ } ) ,
908+ } ) ,
909+ history_length : None ,
910+ return_immediately : None ,
911+ } ) ;
912+
913+ let response = transport
914+ . send_message ( & ServiceParams :: new ( ) , & request)
915+ . await
916+ . unwrap ( ) ;
917+ assert ! ( matches!( response, SendMessageResponse :: Task ( _) ) ) ;
918+
919+ let saved = transport
920+ . get_push_config (
921+ & ServiceParams :: new ( ) ,
922+ & GetTaskPushNotificationConfigRequest {
923+ task_id : "task-rpc-push" . to_string ( ) ,
924+ id : "cfg-rpc" . to_string ( ) ,
925+ tenant : None ,
926+ } ,
927+ )
928+ . await
929+ . unwrap ( ) ;
930+ assert_eq ! ( saved. config. url, webhook_url) ;
931+
932+ let first = recv_push ( & mut receiver) . await ;
933+ assert_eq ! ( first. authorization. as_deref( ) , Some ( "Bearer rpc-secret" ) ) ;
934+ assert_eq ! ( first. notification_token. as_deref( ) , Some ( "rpc-token" ) ) ;
935+ match first. event {
936+ StreamResponse :: StatusUpdate ( update) => {
937+ assert_eq ! ( update. task_id, "task-rpc-push" ) ;
938+ assert_eq ! ( update. status. state, TaskState :: Working ) ;
939+ }
940+ _ => panic ! ( "expected status update push" ) ,
941+ }
942+
943+ let second = recv_push ( & mut receiver) . await ;
944+ assert_eq ! ( second. authorization. as_deref( ) , Some ( "Bearer rpc-secret" ) ) ;
945+ assert_eq ! ( second. notification_token. as_deref( ) , Some ( "rpc-token" ) ) ;
946+ match second. event {
947+ StreamResponse :: Task ( task) => {
948+ assert_eq ! ( task. id, "task-rpc-push" ) ;
949+ assert_eq ! ( task. status. state, TaskState :: Completed ) ;
950+ }
951+ _ => panic ! ( "expected final task push" ) ,
952+ }
953+
954+ server_handle. abort ( ) ;
955+ webhook_handle. abort ( ) ;
956+ }
0 commit comments