Skip to content

Commit 8921205

Browse files
authored
fix(interop): align push-config JSON bindings with a2a-go (#31)
Signed-off-by: Luca Muscariello <[email protected]>
1 parent 7ffecef commit 8921205

File tree

8 files changed

+258
-43
lines changed

8 files changed

+258
-43
lines changed

a2a-client/src/jsonrpc.rs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ use async_trait::async_trait;
66
use futures::stream::{self, BoxStream, StreamExt};
77
use reqwest::Client;
88

9+
use crate::push_config_compat::{
10+
deserialize_list_task_push_notification_configs_response,
11+
deserialize_task_push_notification_config,
12+
};
913
use crate::transport::{ServiceParams, Transport, TransportFactory};
1014

1115
/// JSON-RPC transport implementation.
@@ -22,15 +26,14 @@ impl JsonRpcTransport {
2226
JsonRpcTransport { client, endpoint }
2327
}
2428

25-
async fn call<Req, Resp>(
29+
async fn call_value<Req>(
2630
&self,
2731
params: &ServiceParams,
2832
method: &str,
2933
request_params: &Req,
30-
) -> Result<Resp, A2AError>
34+
) -> Result<serde_json::Value, A2AError>
3135
where
3236
Req: ProtoJsonPayload,
33-
Resp: ProtoJsonPayload,
3437
{
3538
let id = JsonRpcId::String(uuid::Uuid::now_v7().to_string());
3639
let payload = protojson_conv::to_value(request_params).map_err(|e| {
@@ -64,6 +67,21 @@ impl JsonRpcTransport {
6467
.result
6568
.ok_or_else(|| A2AError::internal("JSON-RPC response missing result"))?;
6669

70+
Ok(result)
71+
}
72+
73+
async fn call<Req, Resp>(
74+
&self,
75+
params: &ServiceParams,
76+
method: &str,
77+
request_params: &Req,
78+
) -> Result<Resp, A2AError>
79+
where
80+
Req: ProtoJsonPayload,
81+
Resp: ProtoJsonPayload,
82+
{
83+
let result = self.call_value(params, method, request_params).await?;
84+
6785
protojson_conv::from_value(result)
6886
.map_err(|e| A2AError::internal(format!("failed to deserialize result: {e}")))
6987
}
@@ -323,23 +341,26 @@ impl Transport for JsonRpcTransport {
323341
params: &ServiceParams,
324342
req: &CreateTaskPushNotificationConfigRequest,
325343
) -> Result<TaskPushNotificationConfig, A2AError> {
326-
self.call(params, methods::CREATE_PUSH_CONFIG, req).await
344+
let result = self.call_value(params, methods::CREATE_PUSH_CONFIG, req).await?;
345+
deserialize_task_push_notification_config(result)
327346
}
328347

329348
async fn get_push_config(
330349
&self,
331350
params: &ServiceParams,
332351
req: &GetTaskPushNotificationConfigRequest,
333352
) -> Result<TaskPushNotificationConfig, A2AError> {
334-
self.call(params, methods::GET_PUSH_CONFIG, req).await
353+
let result = self.call_value(params, methods::GET_PUSH_CONFIG, req).await?;
354+
deserialize_task_push_notification_config(result)
335355
}
336356

337357
async fn list_push_configs(
338358
&self,
339359
params: &ServiceParams,
340360
req: &ListTaskPushNotificationConfigsRequest,
341361
) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
342-
self.call(params, methods::LIST_PUSH_CONFIGS, req).await
362+
let result = self.call_value(params, methods::LIST_PUSH_CONFIGS, req).await?;
363+
deserialize_list_task_push_notification_configs_response(result)
343364
}
344365

345366
async fn delete_push_config(

a2a-client/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub mod client;
66
pub mod factory;
77
pub mod jsonrpc;
88
pub mod middleware;
9+
mod push_config_compat;
910
pub mod rest;
1011
pub mod transport;
1112

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Copyright AGNTCY Contributors (https://github.com/agntcy)
2+
// SPDX-License-Identifier: Apache-2.0
3+
use a2a::*;
4+
use a2a_pb::protojson_conv;
5+
use serde_json::Value;
6+
7+
pub(crate) fn deserialize_task_push_notification_config(
8+
payload: Value,
9+
) -> Result<TaskPushNotificationConfig, A2AError> {
10+
serde_json::from_value::<TaskPushNotificationConfig>(payload.clone()).or_else(|serde_error| {
11+
protojson_conv::from_value(payload).map_err(|protojson_error| {
12+
A2AError::internal(format!(
13+
"failed to deserialize push-config response: {serde_error}; ProtoJSON fallback failed: {protojson_error}"
14+
))
15+
})
16+
})
17+
}
18+
19+
pub(crate) fn deserialize_list_task_push_notification_configs_response(
20+
payload: Value,
21+
) -> Result<ListTaskPushNotificationConfigsResponse, A2AError> {
22+
serde_json::from_value::<ListTaskPushNotificationConfigsResponse>(payload.clone()).or_else(
23+
|serde_error| {
24+
serde_json::from_value::<Vec<TaskPushNotificationConfig>>(payload.clone())
25+
.map(|configs| ListTaskPushNotificationConfigsResponse {
26+
configs,
27+
next_page_token: None,
28+
})
29+
.or_else(|array_error| {
30+
protojson_conv::from_value(payload).map_err(|protojson_error| {
31+
A2AError::internal(format!(
32+
"failed to deserialize push-config list response: {serde_error}; array fallback failed: {array_error}; ProtoJSON fallback failed: {protojson_error}"
33+
))
34+
})
35+
})
36+
},
37+
)
38+
}
39+
40+
#[cfg(test)]
41+
mod tests {
42+
use super::*;
43+
44+
fn sample_task_push_config() -> TaskPushNotificationConfig {
45+
TaskPushNotificationConfig {
46+
task_id: "t1".into(),
47+
config: PushNotificationConfig {
48+
url: "https://example.invalid/webhook".into(),
49+
id: Some("cfg1".into()),
50+
token: Some("token-1".into()),
51+
authentication: Some(AuthenticationInfo {
52+
scheme: "Bearer".into(),
53+
credentials: Some("secret".into()),
54+
}),
55+
},
56+
tenant: Some("tenant-1".into()),
57+
}
58+
}
59+
60+
#[test]
61+
fn parses_nested_task_push_config_shape() {
62+
let payload = serde_json::to_value(sample_task_push_config()).unwrap();
63+
let parsed = deserialize_task_push_notification_config(payload).unwrap();
64+
assert_eq!(parsed, sample_task_push_config());
65+
}
66+
67+
#[test]
68+
fn falls_back_to_flattened_protojson_task_push_config_shape() {
69+
let payload = protojson_conv::to_value(&sample_task_push_config()).unwrap();
70+
let parsed = deserialize_task_push_notification_config(payload).unwrap();
71+
assert_eq!(parsed, sample_task_push_config());
72+
}
73+
74+
#[test]
75+
fn parses_nested_list_task_push_configs_shape() {
76+
let response = ListTaskPushNotificationConfigsResponse {
77+
configs: vec![sample_task_push_config()],
78+
next_page_token: Some("next".into()),
79+
};
80+
let payload = serde_json::to_value(response.clone()).unwrap();
81+
let parsed = deserialize_list_task_push_notification_configs_response(payload).unwrap();
82+
assert_eq!(parsed, response);
83+
}
84+
85+
#[test]
86+
fn falls_back_to_flattened_protojson_list_task_push_configs_shape() {
87+
let response = ListTaskPushNotificationConfigsResponse {
88+
configs: vec![sample_task_push_config()],
89+
next_page_token: Some("next".into()),
90+
};
91+
let payload = protojson_conv::to_value(&response).unwrap();
92+
let parsed = deserialize_list_task_push_notification_configs_response(payload).unwrap();
93+
assert_eq!(parsed, response);
94+
}
95+
96+
#[test]
97+
fn falls_back_to_raw_array_list_task_push_configs_shape() {
98+
let configs = vec![sample_task_push_config()];
99+
let payload = serde_json::to_value(configs.clone()).unwrap();
100+
let parsed = deserialize_list_task_push_notification_configs_response(payload).unwrap();
101+
assert_eq!(parsed.configs, configs);
102+
assert_eq!(parsed.next_page_token, None);
103+
}
104+
}

a2a-client/src/rest.rs

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ use serde::Deserialize;
99
use serde_json::Value;
1010
use std::collections::HashMap;
1111

12+
use crate::push_config_compat::{
13+
deserialize_list_task_push_notification_configs_response,
14+
deserialize_task_push_notification_config,
15+
};
1216
use crate::transport::{ServiceParams, Transport, TransportFactory};
1317

1418
const REST_SEND_MESSAGE_PATH: &str = "/message:send";
@@ -99,15 +103,14 @@ impl RestTransport {
99103
parse_rest_error(status, &body)
100104
}
101105

102-
async fn post_json<Req, Resp>(
106+
async fn post_value<Req>(
103107
&self,
104108
path: &str,
105109
params: &ServiceParams,
106110
body: &Req,
107-
) -> Result<Resp, A2AError>
111+
) -> Result<Value, A2AError>
108112
where
109113
Req: ProtoJsonPayload,
110-
Resp: ProtoJsonPayload,
111114
{
112115
let payload = protojson_conv::to_value(body).map_err(|e| {
113116
A2AError::internal(format!("failed to serialize request as ProtoJSON: {e}"))
@@ -127,20 +130,32 @@ impl RestTransport {
127130
.await
128131
.map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
129132

130-
protojson_conv::from_value(payload).map_err(|e| {
131-
A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
132-
})
133+
Ok(payload)
133134
}
134135

135-
async fn get_json<Resp>(
136+
async fn post_json<Req, Resp>(
136137
&self,
137138
path: &str,
138139
params: &ServiceParams,
139-
query: &[(String, String)],
140+
body: &Req,
140141
) -> Result<Resp, A2AError>
141142
where
143+
Req: ProtoJsonPayload,
142144
Resp: ProtoJsonPayload,
143145
{
146+
let payload = self.post_value(path, params, body).await?;
147+
148+
protojson_conv::from_value(payload).map_err(|e| {
149+
A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
150+
})
151+
}
152+
153+
async fn get_value(
154+
&self,
155+
path: &str,
156+
params: &ServiceParams,
157+
query: &[(String, String)],
158+
) -> Result<Value, A2AError> {
144159
let resp = self
145160
.send(self.build_request_with_query(reqwest::Method::GET, path, params, query))
146161
.await?;
@@ -153,6 +168,20 @@ impl RestTransport {
153168
.await
154169
.map_err(|e| A2AError::internal(format!("failed to parse response: {e}")))?;
155170

171+
Ok(payload)
172+
}
173+
174+
async fn get_json<Resp>(
175+
&self,
176+
path: &str,
177+
params: &ServiceParams,
178+
query: &[(String, String)],
179+
) -> Result<Resp, A2AError>
180+
where
181+
Resp: ProtoJsonPayload,
182+
{
183+
let payload = self.get_value(path, params, query).await?;
184+
156185
protojson_conv::from_value(payload).map_err(|e| {
157186
A2AError::internal(format!("failed to deserialize response as ProtoJSON: {e}"))
158187
})
@@ -362,25 +391,29 @@ impl Transport for RestTransport {
362391
params: &ServiceParams,
363392
req: &CreateTaskPushNotificationConfigRequest,
364393
) -> Result<TaskPushNotificationConfig, A2AError> {
365-
self.post_json(
394+
let payload = self
395+
.post_value(
366396
&format!("/tasks/{}/pushNotificationConfigs", req.task_id),
367397
params,
368398
&req.config,
369399
)
370-
.await
400+
.await?;
401+
deserialize_task_push_notification_config(payload)
371402
}
372403

373404
async fn get_push_config(
374405
&self,
375406
params: &ServiceParams,
376407
req: &GetTaskPushNotificationConfigRequest,
377408
) -> Result<TaskPushNotificationConfig, A2AError> {
378-
self.get_json(
409+
let payload = self
410+
.get_value(
379411
&format!("/tasks/{}/pushNotificationConfigs/{}", req.task_id, req.id),
380412
params,
381413
&[],
382414
)
383-
.await
415+
.await?;
416+
deserialize_task_push_notification_config(payload)
384417
}
385418

386419
async fn list_push_configs(
@@ -396,12 +429,14 @@ impl Transport for RestTransport {
396429
query_parts.push(("pageToken".to_string(), page_token.clone()));
397430
}
398431

399-
self.get_json(
432+
let payload = self
433+
.get_value(
400434
&format!("/tasks/{}/pushNotificationConfigs", req.task_id),
401435
params,
402436
&query_parts,
403437
)
404-
.await
438+
.await?;
439+
deserialize_list_task_push_notification_configs_response(payload)
405440
}
406441

407442
async fn delete_push_config(

0 commit comments

Comments
 (0)