Skip to content

Commit 88c4c4f

Browse files
authored
Merge pull request #27 from agntcy/feat/push-config-and-grpc-list-interop
feat(server): support push configs and gRPC list interop
2 parents 320b26d + 73cb288 commit 88c4c4f

File tree

4 files changed

+366
-12
lines changed

4 files changed

+366
-12
lines changed

a2a-grpc/tests/e2e.rs

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use a2a::*;
88
use a2a_client::{Transport, TransportFactory};
99
use a2a_grpc::{GrpcHandler, GrpcTransport, GrpcTransportFactory};
1010
use a2a_pb::proto::a2a_service_server::A2aServiceServer;
11-
use a2a_server::{RequestHandler, ServiceParams};
11+
use a2a_server::{DefaultRequestHandler, InMemoryTaskStore, RequestHandler, ServiceParams};
1212
use async_trait::async_trait;
1313
use futures::StreamExt;
1414
use futures::stream::{self, BoxStream};
@@ -81,6 +81,8 @@ fn sample_agent_card() -> AgentCard {
8181

8282
struct TestHandler;
8383

84+
struct StoredTaskExecutor;
85+
8486
#[async_trait]
8587
impl RequestHandler for TestHandler {
8688
async fn send_message(
@@ -245,6 +247,57 @@ impl RequestHandler for TestHandler {
245247
}
246248
}
247249

250+
impl a2a_server::AgentExecutor for StoredTaskExecutor {
251+
fn execute(
252+
&self,
253+
ctx: a2a_server::ExecutorContext,
254+
) -> BoxStream<'static, Result<StreamResponse, A2AError>> {
255+
let response = StreamResponse::Task(Task {
256+
id: ctx.task_id.clone(),
257+
context_id: ctx.context_id.clone(),
258+
status: TaskStatus {
259+
state: TaskState::Completed,
260+
message: Some(Message {
261+
message_id: "stored-task-response".to_string(),
262+
context_id: Some(ctx.context_id.clone()),
263+
task_id: Some(ctx.task_id.clone()),
264+
role: Role::Agent,
265+
parts: vec![Part::text("stored-task-done")],
266+
metadata: None,
267+
extensions: None,
268+
reference_task_ids: None,
269+
}),
270+
timestamp: None,
271+
},
272+
artifacts: None,
273+
history: ctx.message.clone().map(|message| vec![message]),
274+
metadata: None,
275+
});
276+
277+
Box::pin(stream::once(async move { Ok(response) }))
278+
}
279+
280+
fn cancel(
281+
&self,
282+
ctx: a2a_server::ExecutorContext,
283+
) -> BoxStream<'static, Result<StreamResponse, A2AError>> {
284+
let response = StreamResponse::Task(Task {
285+
id: ctx.task_id.clone(),
286+
context_id: ctx.context_id.clone(),
287+
status: TaskStatus {
288+
state: TaskState::Canceled,
289+
message: None,
290+
timestamp: None,
291+
},
292+
artifacts: None,
293+
history: None,
294+
metadata: None,
295+
});
296+
297+
Box::pin(stream::once(async move { Ok(response) }))
298+
}
299+
}
300+
248301
async fn spawn_grpc_server() -> (String, tokio::task::JoinHandle<()>) {
249302
let handler = Arc::new(TestHandler);
250303
let service = A2aServiceServer::new(GrpcHandler::new(handler));
@@ -262,6 +315,26 @@ async fn spawn_grpc_server() -> (String, tokio::task::JoinHandle<()>) {
262315
(format!("http://{addr}"), handle)
263316
}
264317

318+
async fn spawn_default_handler_grpc_server() -> (String, tokio::task::JoinHandle<()>) {
319+
let handler = Arc::new(DefaultRequestHandler::new(
320+
StoredTaskExecutor,
321+
InMemoryTaskStore::new(),
322+
));
323+
let service = A2aServiceServer::new(GrpcHandler::new(handler));
324+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
325+
let addr = listener.local_addr().unwrap();
326+
let incoming = TcpListenerStream::new(listener);
327+
let handle = tokio::spawn(async move {
328+
Server::builder()
329+
.add_service(service)
330+
.serve_with_incoming(incoming)
331+
.await
332+
.unwrap();
333+
});
334+
tokio::time::sleep(Duration::from_millis(20)).await;
335+
(format!("http://{addr}"), handle)
336+
}
337+
265338
fn endpoint_without_scheme(endpoint: &str) -> String {
266339
endpoint
267340
.split_once("://")
@@ -502,3 +575,51 @@ async fn grpc_transport_accepts_bare_host_port_endpoints() {
502575
transport.destroy().await.unwrap();
503576
handle.abort();
504577
}
578+
579+
#[tokio::test]
580+
async fn grpc_transport_treats_zero_page_size_as_unset() {
581+
let (endpoint, handle) = spawn_default_handler_grpc_server().await;
582+
let transport = GrpcTransport::connect(endpoint).await.unwrap();
583+
584+
let sent = transport
585+
.send_message(
586+
&ServiceParams::new(),
587+
&SendMessageRequest {
588+
message: Message::new(Role::User, vec![Part::text("hello")]),
589+
configuration: None,
590+
metadata: None,
591+
tenant: None,
592+
},
593+
)
594+
.await
595+
.unwrap();
596+
597+
let task = match sent {
598+
SendMessageResponse::Task(task) => task,
599+
SendMessageResponse::Message(_) => panic!("expected task response"),
600+
};
601+
602+
let listed = transport
603+
.list_tasks(
604+
&ServiceParams::new(),
605+
&ListTasksRequest {
606+
context_id: Some(task.context_id.clone()),
607+
status: None,
608+
page_size: Some(0),
609+
page_token: None,
610+
history_length: None,
611+
status_timestamp_after: None,
612+
include_artifacts: Some(false),
613+
tenant: None,
614+
},
615+
)
616+
.await
617+
.unwrap();
618+
619+
assert_eq!(listed.tasks.len(), 1);
620+
assert_eq!(listed.tasks[0].id, task.id);
621+
assert_eq!(listed.page_size, 50);
622+
623+
transport.destroy().await.unwrap();
624+
handle.abort();
625+
}

a2a-pb/src/pbconv.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ fn empty_to_none(s: &str) -> Option<String> {
111111
}
112112
}
113113

114+
fn non_positive_to_none(value: Option<i32>) -> Option<i32> {
115+
value.filter(|value| *value > 0)
116+
}
117+
114118
// ---------------------------------------------------------------------------
115119
// Role
116120
// ---------------------------------------------------------------------------
@@ -519,7 +523,7 @@ pub fn from_proto_list_tasks_request(r: &proto::ListTasksRequest) -> ListTasksRe
519523
} else {
520524
Some(from_proto_task_state(r.status))
521525
},
522-
page_size: r.page_size,
526+
page_size: non_positive_to_none(r.page_size),
523527
page_token: empty_to_none(&r.page_token),
524528
history_length: r.history_length,
525529
status_timestamp_after: r
@@ -1626,6 +1630,24 @@ mod tests {
16261630
assert_eq!(req.tenant, back.tenant);
16271631
}
16281632

1633+
#[test]
1634+
fn test_from_proto_list_tasks_request_treats_zero_page_size_as_unset() {
1635+
let proto = proto::ListTasksRequest {
1636+
tenant: String::new(),
1637+
context_id: "ctx-1".to_string(),
1638+
status: 0,
1639+
page_size: Some(0),
1640+
page_token: String::new(),
1641+
history_length: None,
1642+
status_timestamp_after: None,
1643+
include_artifacts: Some(false),
1644+
};
1645+
1646+
let back = from_proto_list_tasks_request(&proto);
1647+
assert_eq!(back.context_id.as_deref(), Some("ctx-1"));
1648+
assert_eq!(back.page_size, None);
1649+
}
1650+
16291651
#[test]
16301652
fn test_list_tasks_response_roundtrip() {
16311653
let resp = ListTasksResponse {

0 commit comments

Comments
 (0)