Skip to content

Commit 541fe87

Browse files
committed
Enable rest protocol between Aggregator and Collaborator
- add new AggregatorClientInterface to allow switching b/w grpc and rest - endhance existing AggregatorGRPCClient to start using AggregatorClientInterface - added new transport package for rest with AggregatorRESTClient implementing AggregatorClientInterface - added streaming api support with custom content-type - added various connection flag for streaming request - send additional header key "Sender" for better request logging at server side - aligned Rest and gRPC client for most of the init params rebased 26-Mar.2 Signed-off-by: Shailesh Pant <shailesh.pant@intel.com>
1 parent 4cc355c commit 541fe87

File tree

8 files changed

+387
-23
lines changed

8 files changed

+387
-23
lines changed

openfl/federated/plan/plan.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,6 @@ def get_collaborator(
542542
else:
543543
defaults[SETTINGS]["client"] = self.get_client(
544544
collaborator_name,
545-
self.aggregator_uuid,
546-
self.federation_uuid,
547545
root_certificate,
548546
private_key,
549547
certificate,
@@ -557,8 +555,6 @@ def get_collaborator(
557555
def get_client(
558556
self,
559557
collaborator_name,
560-
aggregator_uuid,
561-
federation_uuid,
562558
root_certificate=None,
563559
private_key=None,
564560
certificate=None,
@@ -579,6 +575,25 @@ def get_client(
579575
Returns:
580576
AggregatorGRPCClient: gRPC client for the specified collaborator.
581577
"""
578+
client_args = self.get_client_args(
579+
collaborator_name,
580+
root_certificate,
581+
private_key,
582+
certificate,
583+
)
584+
585+
if self.client_ is None:
586+
self.client_ = AggregatorGRPCClient(**client_args)
587+
588+
return self.client_
589+
590+
def get_client_args(
591+
self,
592+
collaborator_name,
593+
root_certificate=None,
594+
private_key=None,
595+
certificate=None,
596+
):
582597
common_name = collaborator_name
583598
if not root_certificate or not private_key or not certificate:
584599
root_certificate = "cert/cert_chain.crt"
@@ -593,14 +608,10 @@ def get_client(
593608
client_args["certificate"] = certificate
594609
client_args["private_key"] = private_key
595610

596-
client_args["aggregator_uuid"] = aggregator_uuid
597-
client_args["federation_uuid"] = federation_uuid
611+
client_args["aggregator_uuid"] = self.aggregator_uuid
612+
client_args["federation_uuid"] = self.federation_uuid
598613
client_args["collaborator_name"] = collaborator_name
599-
600-
if self.client_ is None:
601-
self.client_ = AggregatorGRPCClient(**client_args)
602-
603-
return self.client_
614+
return client_args
604615

605616
def get_server(
606617
self,
@@ -623,6 +634,16 @@ def get_server(
623634
Returns:
624635
AggregatorGRPCServer: gRPC server of the aggregator instance.
625636
"""
637+
server_args = self.get_server_args(root_certificate, private_key, certificate, kwargs)
638+
639+
server_args["aggregator"] = self.get_aggregator()
640+
641+
if self.server_ is None:
642+
self.server_ = AggregatorGRPCServer(**server_args)
643+
644+
return self.server_
645+
646+
def get_server_args(self, root_certificate, private_key, certificate, kwargs):
626647
common_name = self.config["network"][SETTINGS]["agg_addr"].lower()
627648

628649
if not root_certificate or not private_key or not certificate:
@@ -638,13 +659,7 @@ def get_server(
638659
server_args["root_certificate"] = root_certificate
639660
server_args["certificate"] = certificate
640661
server_args["private_key"] = private_key
641-
642-
server_args["aggregator"] = self.get_aggregator()
643-
644-
if self.server_ is None:
645-
self.server_ = AggregatorGRPCServer(**server_args)
646-
647-
return self.server_
662+
return server_args
648663

649664
def save_model_to_state_file(self, tensor_dict, round_number, output_path):
650665
"""Save model weights to a protobuf state file.

openfl/interface/collaborator.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from shutil import copy, copytree, ignore_patterns, make_archive, unpack_archive
1515
from tempfile import mkdtemp
1616

17+
from click import Choice as ClickChoice
1718
from click import Path as ClickPath
1819
from click import confirm, echo, group, option, pass_context, prompt, style
1920
from yaml import FullLoader, dump, load
@@ -61,7 +62,14 @@ def collaborator(context):
6162
required=True,
6263
help="The certified common name of the collaborator.",
6364
)
64-
def start_(plan, collaborator_name, data_config):
65+
@option(
66+
"-c",
67+
"--client_protocol",
68+
help="The client protocol to use for communication with the aggregator.",
69+
default="grpc",
70+
type=ClickChoice(["grpc", "rest"]),
71+
)
72+
def start_(plan, collaborator_name, data_config, client_protocol):
6573
"""Starts a collaborator service."""
6674

6775
if plan and is_directory_traversal(plan):
@@ -80,8 +88,15 @@ def start_(plan, collaborator_name, data_config):
8088

8189
echo(f"Data = {plan.cols_data_paths}")
8290
logger.info("🧿 Starting a Collaborator Service.")
83-
84-
plan.get_collaborator(collaborator_name).run()
91+
collaborator = plan.get_collaborator(collaborator_name)
92+
# Determine if additional client protocol needs to be created
93+
if client_protocol == "rest":
94+
from openfl.transport.rest.aggregator_client import AggregatorRESTClient
95+
96+
client_args = plan.get_client_args(collaborator_name=collaborator_name)
97+
aggregator_client = AggregatorRESTClient(**client_args)
98+
collaborator.client = aggregator_client
99+
collaborator.run()
85100

86101

87102
@collaborator.command(name="create")

openfl/protocols/aggregator.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import "openfl/protocols/base.proto";
1111
service Aggregator {
1212
rpc GetTasks(GetTasksRequest) returns (GetTasksResponse) {}
1313
rpc GetAggregatedTensor(GetAggregatedTensorRequest) returns (GetAggregatedTensorResponse) {}
14-
rpc SendLocalTaskResults(stream DataStream) returns (SendLocalTaskResultsResponse) {}
1514
rpc InteropRelay(InteropMessage) returns (InteropMessage) {}
15+
rpc SendLocalTaskResults(stream DataStream) returns (SendLocalTaskResultsResponse) {}
1616
}
1717

1818
message MessageHeader {
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, List, Tuple
3+
4+
5+
class AggregatorClientInterface(ABC):
6+
@abstractmethod
7+
def get_tasks(self) -> Tuple[List[Any], int, int, bool]:
8+
"""
9+
Retrieves tasks for the given collaborator client.
10+
Returns a tuple: (tasks, round_number, sleep_time, time_to_quit)
11+
"""
12+
pass
13+
14+
@abstractmethod
15+
def get_aggregated_tensor(
16+
self,
17+
tensor_name: str,
18+
round_number: int,
19+
report: bool,
20+
tags: List[str],
21+
require_lossless: bool,
22+
) -> Any:
23+
"""
24+
Retrieves the aggregated tensor.
25+
"""
26+
pass
27+
28+
@abstractmethod
29+
def send_local_task_results(
30+
self,
31+
round_number: int,
32+
task_name: str,
33+
data_size: int,
34+
named_tensors: List[Any],
35+
) -> Any:
36+
"""
37+
Sends local task results.
38+
Parameters:
39+
collaborator_name: Name of the collaborator.
40+
round_number: The current round.
41+
task_name: Name of the task.
42+
data_size: Size of the data.
43+
named_tensors: A list of tensors (or named tensor objects).
44+
Returns a SendLocalTaskResultsResponse.
45+
"""
46+
pass

openfl/transport/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44

55
from openfl.transport.grpc import AggregatorGRPCClient, AggregatorGRPCServer
6+
from openfl.transport.rest import AggregatorRESTClient

openfl/transport/grpc/aggregator_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import grpc
1212

1313
from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils
14+
from openfl.protocols.aggregator_client_interface import AggregatorClientInterface
1415
from openfl.transport.grpc.common import create_header, create_insecure_channel, create_tls_channel
1516

1617
logger = logging.getLogger(__name__)
@@ -165,9 +166,12 @@ def wrapper(self, *args, **kwargs):
165166
return wrapper
166167

167168

168-
class AggregatorGRPCClient:
169+
class AggregatorGRPCClient(AggregatorClientInterface):
169170
"""Collaborator-side gRPC client that talks to the aggregator.
170171
172+
This class implements a gRPC client for communicating with an aggregator
173+
over a secure (TLS) connection.
174+
171175
Attributes:
172176
agg_addr (str): Aggregator address.
173177
agg_port (int): Aggregator port.

openfl/transport/rest/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright 2020-2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
from openfl.transport.rest.aggregator_client import AggregatorRESTClient

0 commit comments

Comments
 (0)