Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion openfl-workspace/workspace/plan/defaults/network.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ settings:
client_reconnect_interval : 5
require_client_auth : True
cert_folder : cert
enable_atomic_connections : False
enable_atomic_connections : False
transport_protocol : grpc
Comment thread
ishaileshpant marked this conversation as resolved.
4 changes: 2 additions & 2 deletions openfl/component/collaborator/collaborator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from openfl.databases import TensorDB
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import utils
from openfl.transport.grpc.aggregator_client import AggregatorGRPCClient
from openfl.transport.grpc.aggregator_client import AggregatorClientInterface
from openfl.utilities import TensorKey

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(
collaborator_name,
aggregator_uuid,
federation_uuid,
client: AggregatorGRPCClient,
client: AggregatorClientInterface,
task_runner,
task_config,
opt_treatment="RESET",
Expand Down
88 changes: 65 additions & 23 deletions openfl/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@

from openfl.interface.aggregation_functions import AggregationFunction, WeightedAverage
from openfl.interface.cli_helper import WORKSPACE
from openfl.transport import AggregatorGRPCClient, AggregatorGRPCServer
from openfl.transport import (
AggregatorGRPCClient,
AggregatorGRPCServer,
AggregatorRESTClient,
AggregatorRESTServer,
)
from openfl.utilities.utils import getfqdn_env

SETTINGS = "settings"
Expand Down Expand Up @@ -542,8 +547,6 @@ def get_collaborator(
else:
defaults[SETTINGS]["client"] = self.get_client(
collaborator_name,
self.aggregator_uuid,
self.federation_uuid,
root_certificate,
private_key,
certificate,
Expand All @@ -557,13 +560,11 @@ def get_collaborator(
def get_client(
self,
collaborator_name,
aggregator_uuid,
federation_uuid,
root_certificate=None,
private_key=None,
certificate=None,
):
"""Get gRPC client for the specified collaborator.
"""Get gRPC or REST client for the specified collaborator.

Args:
collaborator_name (str): Name of the collaborator.
Expand All @@ -577,8 +578,38 @@ def get_client(
Defaults to None.

Returns:
AggregatorGRPCClient: gRPC client for the specified collaborator.
AggregatorGRPCClient or AggregatorRESTClient: gRPC or REST client for the collaborator.
"""
client_args = self.get_client_args(
collaborator_name,
root_certificate,
private_key,
certificate,
)
network_cfg = self.config["network"][SETTINGS]
protocol = network_cfg.get("transport_protocol", "grpc").lower()

if self.client_ is None:
Comment thread
ishaileshpant marked this conversation as resolved.
self.client_ = self._get_client(protocol, **client_args)

return self.client_

def _get_client(self, protocol, **kwargs):
if protocol == "rest":
client = AggregatorRESTClient(**kwargs)
elif protocol == "grpc":
client = AggregatorGRPCClient(**kwargs)
else:
raise ValueError(f"Unsupported transport_protocol '{protocol}'")
return client

def get_client_args(
self,
collaborator_name,
root_certificate=None,
private_key=None,
certificate=None,
):
common_name = collaborator_name
if not root_certificate or not private_key or not certificate:
root_certificate = "cert/cert_chain.crt"
Expand All @@ -593,14 +624,10 @@ def get_client(
client_args["certificate"] = certificate
client_args["private_key"] = private_key

client_args["aggregator_uuid"] = aggregator_uuid
client_args["federation_uuid"] = federation_uuid
client_args["aggregator_uuid"] = self.aggregator_uuid
client_args["federation_uuid"] = self.federation_uuid
client_args["collaborator_name"] = collaborator_name

if self.client_ is None:
self.client_ = AggregatorGRPCClient(**client_args)

return self.client_
return client_args

def get_server(
self,
Expand All @@ -609,7 +636,7 @@ def get_server(
certificate=None,
**kwargs,
):
"""Get gRPC server of the aggregator instance.
"""Get gRPC or REST server of the aggregator instance.

Args:
root_certificate (str, optional): Root certificate for the server.
Expand All @@ -621,8 +648,29 @@ def get_server(
**kwargs: Additional keyword arguments.

Returns:
AggregatorGRPCServer: gRPC server of the aggregator instance.
Aggregator Server: returns either gRPC or REST server of the aggregator instance.
"""
server_args = self.get_server_args(root_certificate, private_key, certificate, kwargs)

server_args["aggregator"] = self.get_aggregator()
network_cfg = self.config["network"][SETTINGS]
protocol = network_cfg.get("transport_protocol", "grpc").lower()

if self.server_ is None:
self.server_ = self._get_server(protocol, **server_args)

return self.server_

def _get_server(self, protocol, **kwargs):
if protocol == "rest":
server = AggregatorRESTServer(**kwargs)
elif protocol == "grpc":
server = AggregatorGRPCServer(**kwargs)
else:
raise ValueError(f"Unsupported transport_protocol '{protocol}'")
return server

def get_server_args(self, root_certificate, private_key, certificate, kwargs):
common_name = self.config["network"][SETTINGS]["agg_addr"].lower()

if not root_certificate or not private_key or not certificate:
Expand All @@ -638,13 +686,7 @@ def get_server(
server_args["root_certificate"] = root_certificate
server_args["certificate"] = certificate
server_args["private_key"] = private_key

server_args["aggregator"] = self.get_aggregator()

if self.server_ is None:
self.server_ = AggregatorGRPCServer(**server_args)

return self.server_
return server_args

def save_model_to_state_file(self, tensor_dict, round_number, output_path):
"""Save model weights to a protobuf state file.
Expand Down
4 changes: 2 additions & 2 deletions openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def start_(plan, authorized_cols, task_group):
logger.info(f"Setting aggregator to assign: {task_group} task_group")

logger.info("🧿 Starting the Aggregator Service.")

parsed_plan.get_server().serve()
server = parsed_plan.get_server()
server.serve()


@aggregator.command(name="generate-cert-request")
Expand Down
72 changes: 72 additions & 0 deletions openfl/protocols/aggregator_client_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""AggregatorClientInterface module."""

from abc import ABC, abstractmethod
from typing import Any, List, Tuple


class AggregatorClientInterface(ABC):
Comment thread
ishaileshpant marked this conversation as resolved.
Comment thread
kminhta marked this conversation as resolved.
@abstractmethod
def ping(self):
"""
Ping the aggregator to check connectivity.
"""
pass

@abstractmethod
def get_tasks(self) -> Tuple[List[Any], int, int, bool]:
"""
Retrieves tasks for the given collaborator client.
Returns a tuple: (tasks, round_number, sleep_time, time_to_quit)
"""
pass

@abstractmethod
def get_aggregated_tensor(
self,
tensor_name: str,
round_number: int,
report: bool,
tags: List[str],
require_lossless: bool,
) -> Any:
"""
Retrieves the aggregated tensor.
"""
pass

@abstractmethod
def send_local_task_results(
self,
round_number: int,
task_name: str,
data_size: int,
named_tensors: List[Any],
) -> Any:
"""
Sends local task results.
Parameters:
collaborator_name: Name of the collaborator.
round_number: The current round.
task_name: Name of the task.
data_size: Size of the data.
named_tensors: A list of tensors (or named tensor objects).
Returns a SendLocalTaskResultsResponse.
"""
pass

@abstractmethod
def send_message_to_server(self, openfl_message: Any, collaborator_name: str) -> Any:
"""
Forwards a converted message from the local client to the OpenFL server and returns the
response.
Args:
openfl_message: The converted message to be sent to the OpenFL server (InteropMessage
proto).
collaborator_name: The name of the collaborator.
Returns:
The response from the OpenFL server (InteropMessage proto).
"""
pass
3 changes: 2 additions & 1 deletion openfl/transport/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2020-2024 Intel Corporation
# Copyright 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from openfl.transport.grpc import AggregatorGRPCClient, AggregatorGRPCServer
from openfl.transport.rest import AggregatorRESTClient, AggregatorRESTServer
5 changes: 4 additions & 1 deletion openfl/transport/grpc/aggregator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import grpc

from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils
from openfl.protocols.aggregator_client_interface import AggregatorClientInterface
from openfl.transport.grpc.common import create_header, create_insecure_channel, create_tls_channel

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -165,9 +166,11 @@ def wrapper(self, *args, **kwargs):
return wrapper


class AggregatorGRPCClient:
class AggregatorGRPCClient(AggregatorClientInterface):
"""Collaborator-side gRPC client that talks to the aggregator.

This class implements a gRPC client for communicating with an aggregator.

Attributes:
agg_addr (str): Aggregator address.
agg_port (int): Aggregator port.
Expand Down
6 changes: 6 additions & 0 deletions openfl/transport/rest/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from openfl.transport.rest.aggregator_client import AggregatorRESTClient
from openfl.transport.rest.aggregator_server import AggregatorRESTServer
Loading
Loading