Skip to content

Commit 0c13600

Browse files
committed
Enable rest protocol between Aggregator and Collaborator
- add new AggregatorClientInterface to allow switching b/w grpc and rest - endhance existing AggregatorGRPCClient to start using AggregatorClientInterface - added new transport package for rest with AggregatorRESTClient implementing AggregatorClientInterface - added streaming api support with custom content-type - added various connection flag for streaming request - send additional header key "Sender" for better request logging at server side - aligned Rest and gRPC client for most of the init params - added AggregatorRESTServer and necesary changes in aggregator cli and federated/plan get_server method - added transport_protocol settings in defaults/network.yaml rebased 26-Mar.2 Signed-off-by: Shailesh Pant <shailesh.pant@intel.com>
1 parent 4cc355c commit 0c13600

File tree

10 files changed

+742
-28
lines changed

10 files changed

+742
-28
lines changed

openfl-workspace/workspace/plan/defaults/network.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ settings:
77
client_reconnect_interval : 5
88
require_client_auth : True
99
cert_folder : cert
10-
enable_atomic_connections : False
10+
enable_atomic_connections : False
11+
transport_protocol : rest

openfl/federated/plan/plan.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from openfl.interface.aggregation_functions import AggregationFunction, WeightedAverage
1717
from openfl.interface.cli_helper import WORKSPACE
18-
from openfl.transport import AggregatorGRPCClient, AggregatorGRPCServer
18+
from openfl.transport import AggregatorGRPCClient, AggregatorGRPCServer, AggregatorRESTServer
1919
from openfl.utilities.utils import getfqdn_env
2020

2121
SETTINGS = "settings"
@@ -542,8 +542,6 @@ def get_collaborator(
542542
else:
543543
defaults[SETTINGS]["client"] = self.get_client(
544544
collaborator_name,
545-
self.aggregator_uuid,
546-
self.federation_uuid,
547545
root_certificate,
548546
private_key,
549547
certificate,
@@ -557,8 +555,6 @@ def get_collaborator(
557555
def get_client(
558556
self,
559557
collaborator_name,
560-
aggregator_uuid,
561-
federation_uuid,
562558
root_certificate=None,
563559
private_key=None,
564560
certificate=None,
@@ -579,6 +575,25 @@ def get_client(
579575
Returns:
580576
AggregatorGRPCClient: gRPC client for the specified collaborator.
581577
"""
578+
client_args = self.get_client_args(
579+
collaborator_name,
580+
root_certificate,
581+
private_key,
582+
certificate,
583+
)
584+
585+
if self.client_ is None:
586+
self.client_ = AggregatorGRPCClient(**client_args)
587+
588+
return self.client_
589+
590+
def get_client_args(
591+
self,
592+
collaborator_name,
593+
root_certificate=None,
594+
private_key=None,
595+
certificate=None,
596+
):
582597
common_name = collaborator_name
583598
if not root_certificate or not private_key or not certificate:
584599
root_certificate = "cert/cert_chain.crt"
@@ -593,14 +608,10 @@ def get_client(
593608
client_args["certificate"] = certificate
594609
client_args["private_key"] = private_key
595610

596-
client_args["aggregator_uuid"] = aggregator_uuid
597-
client_args["federation_uuid"] = federation_uuid
611+
client_args["aggregator_uuid"] = self.aggregator_uuid
612+
client_args["federation_uuid"] = self.federation_uuid
598613
client_args["collaborator_name"] = collaborator_name
599-
600-
if self.client_ is None:
601-
self.client_ = AggregatorGRPCClient(**client_args)
602-
603-
return self.client_
614+
return client_args
604615

605616
def get_server(
606617
self,
@@ -609,7 +620,7 @@ def get_server(
609620
certificate=None,
610621
**kwargs,
611622
):
612-
"""Get gRPC server of the aggregator instance.
623+
"""Get gRPC or REST server of the aggregator instance.
613624
614625
Args:
615626
root_certificate (str, optional): Root certificate for the server.
@@ -621,8 +632,29 @@ def get_server(
621632
**kwargs: Additional keyword arguments.
622633
623634
Returns:
624-
AggregatorGRPCServer: gRPC server of the aggregator instance.
635+
Aggregator Server: returns either gRPC or REST server of the aggregator instance.
625636
"""
637+
server_args = self.get_server_args(root_certificate, private_key, certificate, kwargs)
638+
639+
server_args["aggregator"] = self.get_aggregator()
640+
network_cfg = self.config["network"][SETTINGS]
641+
protocol = network_cfg.get("transport_protocol", "grpc").lower()
642+
643+
if self.server_ is None:
644+
self.server_ = self._get_server(protocol, **server_args)
645+
646+
return self.server_
647+
648+
def _get_server(self, protocol, **kwargs):
649+
if protocol == "rest":
650+
server = AggregatorRESTServer(**kwargs)
651+
elif protocol == "grpc":
652+
server = AggregatorGRPCServer(**kwargs)
653+
else:
654+
raise ValueError(f"Unsupported transport_protocol '{protocol}'")
655+
return server
656+
657+
def get_server_args(self, root_certificate, private_key, certificate, kwargs):
626658
common_name = self.config["network"][SETTINGS]["agg_addr"].lower()
627659

628660
if not root_certificate or not private_key or not certificate:
@@ -638,13 +670,7 @@ def get_server(
638670
server_args["root_certificate"] = root_certificate
639671
server_args["certificate"] = certificate
640672
server_args["private_key"] = private_key
641-
642-
server_args["aggregator"] = self.get_aggregator()
643-
644-
if self.server_ is None:
645-
self.server_ = AggregatorGRPCServer(**server_args)
646-
647-
return self.server_
673+
return server_args
648674

649675
def save_model_to_state_file(self, tensor_dict, round_number, output_path):
650676
"""Save model weights to a protobuf state file.

openfl/interface/aggregator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ def start_(plan, authorized_cols, task_group):
9292
logger.info(f"Setting aggregator to assign: {task_group} task_group")
9393

9494
logger.info("🧿 Starting the Aggregator Service.")
95-
96-
parsed_plan.get_server().serve()
95+
# Instantiate either gRPC or REST server based on transport configuration
96+
server = parsed_plan.get_server()
97+
server.serve()
9798

9899

99100
@aggregator.command(name="generate-cert-request")

openfl/interface/collaborator.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from shutil import copy, copytree, ignore_patterns, make_archive, unpack_archive
1515
from tempfile import mkdtemp
1616

17+
from click import Choice as ClickChoice
1718
from click import Path as ClickPath
1819
from click import confirm, echo, group, option, pass_context, prompt, style
1920
from yaml import FullLoader, dump, load
@@ -61,7 +62,14 @@ def collaborator(context):
6162
required=True,
6263
help="The certified common name of the collaborator.",
6364
)
64-
def start_(plan, collaborator_name, data_config):
65+
@option(
66+
"-c",
67+
"--client_protocol",
68+
help="The client protocol to use for communication with the aggregator.",
69+
default="grpc",
70+
type=ClickChoice(["grpc", "rest"]),
71+
)
72+
def start_(plan, collaborator_name, data_config, client_protocol):
6573
"""Starts a collaborator service."""
6674

6775
if plan and is_directory_traversal(plan):
@@ -80,8 +88,15 @@ def start_(plan, collaborator_name, data_config):
8088

8189
echo(f"Data = {plan.cols_data_paths}")
8290
logger.info("🧿 Starting a Collaborator Service.")
83-
84-
plan.get_collaborator(collaborator_name).run()
91+
collaborator = plan.get_collaborator(collaborator_name)
92+
# Determine if additional client protocol needs to be created
93+
if client_protocol == "rest":
94+
from openfl.transport.rest.aggregator_client import AggregatorRESTClient
95+
96+
client_args = plan.get_client_args(collaborator_name=collaborator_name)
97+
aggregator_client = AggregatorRESTClient(**client_args)
98+
collaborator.client = aggregator_client
99+
collaborator.run()
85100

86101

87102
@collaborator.command(name="create")
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, List, Tuple
3+
4+
5+
class AggregatorClientInterface(ABC):
6+
@abstractmethod
7+
def get_tasks(self) -> Tuple[List[Any], int, int, bool]:
8+
"""
9+
Retrieves tasks for the given collaborator client.
10+
Returns a tuple: (tasks, round_number, sleep_time, time_to_quit)
11+
"""
12+
pass
13+
14+
@abstractmethod
15+
def get_aggregated_tensor(
16+
self,
17+
tensor_name: str,
18+
round_number: int,
19+
report: bool,
20+
tags: List[str],
21+
require_lossless: bool,
22+
) -> Any:
23+
"""
24+
Retrieves the aggregated tensor.
25+
"""
26+
pass
27+
28+
@abstractmethod
29+
def send_local_task_results(
30+
self,
31+
round_number: int,
32+
task_name: str,
33+
data_size: int,
34+
named_tensors: List[Any],
35+
) -> Any:
36+
"""
37+
Sends local task results.
38+
Parameters:
39+
collaborator_name: Name of the collaborator.
40+
round_number: The current round.
41+
task_name: Name of the task.
42+
data_size: Size of the data.
43+
named_tensors: A list of tensors (or named tensor objects).
44+
Returns a SendLocalTaskResultsResponse.
45+
"""
46+
pass

openfl/transport/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33

44

55
from openfl.transport.grpc import AggregatorGRPCClient, AggregatorGRPCServer
6+
from openfl.transport.rest import AggregatorRESTClient, AggregatorRESTServer

openfl/transport/grpc/aggregator_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import grpc
1212

1313
from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils
14+
from openfl.protocols.aggregator_client_interface import AggregatorClientInterface
1415
from openfl.transport.grpc.common import create_header, create_insecure_channel, create_tls_channel
1516

1617
logger = logging.getLogger(__name__)
@@ -165,9 +166,12 @@ def wrapper(self, *args, **kwargs):
165166
return wrapper
166167

167168

168-
class AggregatorGRPCClient:
169+
class AggregatorGRPCClient(AggregatorClientInterface):
169170
"""Collaborator-side gRPC client that talks to the aggregator.
170171
172+
This class implements a gRPC client for communicating with an aggregator
173+
over a secure (TLS) connection.
174+
171175
Attributes:
172176
agg_addr (str): Aggregator address.
173177
agg_port (int): Aggregator port.

openfl/transport/rest/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Copyright 2020-2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
from openfl.transport.rest.aggregator_client import AggregatorRESTClient
6+
from openfl.transport.rest.aggregator_server import AggregatorRESTServer

0 commit comments

Comments
 (0)