diff --git a/openfl-workspace/workspace/plan/defaults/network.yaml b/openfl-workspace/workspace/plan/defaults/network.yaml index 654667240e..82372e822c 100644 --- a/openfl-workspace/workspace/plan/defaults/network.yaml +++ b/openfl-workspace/workspace/plan/defaults/network.yaml @@ -7,4 +7,5 @@ settings: client_reconnect_interval : 5 require_client_auth : True cert_folder : cert - enable_atomic_connections : False \ No newline at end of file + enable_atomic_connections : False + transport_protocol : grpc diff --git a/openfl/component/collaborator/collaborator.py b/openfl/component/collaborator/collaborator.py index 7f66805753..178bd8076a 100644 --- a/openfl/component/collaborator/collaborator.py +++ b/openfl/component/collaborator/collaborator.py @@ -14,7 +14,7 @@ from openfl.databases import TensorDB from openfl.pipelines import NoCompressionPipeline, TensorCodec from openfl.protocols import utils -from openfl.transport.grpc.aggregator_client import AggregatorGRPCClient +from openfl.transport.grpc.aggregator_client import AggregatorClientInterface from openfl.utilities import TensorKey logger = logging.getLogger(__name__) @@ -64,7 +64,7 @@ def __init__( collaborator_name, aggregator_uuid, federation_uuid, - client: AggregatorGRPCClient, + client: AggregatorClientInterface, task_runner, task_config, opt_treatment="RESET", diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index a324934e76..8cdb5a13e5 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -15,7 +15,12 @@ from openfl.interface.aggregation_functions import AggregationFunction, WeightedAverage from openfl.interface.cli_helper import WORKSPACE -from openfl.transport import AggregatorGRPCClient, AggregatorGRPCServer +from openfl.transport import ( + AggregatorGRPCClient, + AggregatorGRPCServer, + AggregatorRESTClient, + AggregatorRESTServer, +) from openfl.utilities.utils import getfqdn_env SETTINGS = "settings" @@ -542,8 +547,6 @@ def get_collaborator( else: defaults[SETTINGS]["client"] = self.get_client( collaborator_name, - self.aggregator_uuid, - self.federation_uuid, root_certificate, private_key, certificate, @@ -557,13 +560,11 @@ def get_collaborator( def get_client( self, collaborator_name, - aggregator_uuid, - federation_uuid, root_certificate=None, private_key=None, certificate=None, ): - """Get gRPC client for the specified collaborator. + """Get gRPC or REST client for the specified collaborator. Args: collaborator_name (str): Name of the collaborator. @@ -577,8 +578,38 @@ def get_client( Defaults to None. Returns: - AggregatorGRPCClient: gRPC client for the specified collaborator. + AggregatorGRPCClient or AggregatorRESTClient: gRPC or REST client for the collaborator. """ + client_args = self.get_client_args( + collaborator_name, + root_certificate, + private_key, + certificate, + ) + network_cfg = self.config["network"][SETTINGS] + protocol = network_cfg.get("transport_protocol", "grpc").lower() + + if self.client_ is None: + self.client_ = self._get_client(protocol, **client_args) + + return self.client_ + + def _get_client(self, protocol, **kwargs): + if protocol == "rest": + client = AggregatorRESTClient(**kwargs) + elif protocol == "grpc": + client = AggregatorGRPCClient(**kwargs) + else: + raise ValueError(f"Unsupported transport_protocol '{protocol}'") + return client + + def get_client_args( + self, + collaborator_name, + root_certificate=None, + private_key=None, + certificate=None, + ): common_name = collaborator_name if not root_certificate or not private_key or not certificate: root_certificate = "cert/cert_chain.crt" @@ -593,14 +624,10 @@ def get_client( client_args["certificate"] = certificate client_args["private_key"] = private_key - client_args["aggregator_uuid"] = aggregator_uuid - client_args["federation_uuid"] = federation_uuid + client_args["aggregator_uuid"] = self.aggregator_uuid + client_args["federation_uuid"] = self.federation_uuid client_args["collaborator_name"] = collaborator_name - - if self.client_ is None: - self.client_ = AggregatorGRPCClient(**client_args) - - return self.client_ + return client_args def get_server( self, @@ -609,7 +636,7 @@ def get_server( certificate=None, **kwargs, ): - """Get gRPC server of the aggregator instance. + """Get gRPC or REST server of the aggregator instance. Args: root_certificate (str, optional): Root certificate for the server. @@ -621,8 +648,29 @@ def get_server( **kwargs: Additional keyword arguments. Returns: - AggregatorGRPCServer: gRPC server of the aggregator instance. + Aggregator Server: returns either gRPC or REST server of the aggregator instance. """ + server_args = self.get_server_args(root_certificate, private_key, certificate, kwargs) + + server_args["aggregator"] = self.get_aggregator() + network_cfg = self.config["network"][SETTINGS] + protocol = network_cfg.get("transport_protocol", "grpc").lower() + + if self.server_ is None: + self.server_ = self._get_server(protocol, **server_args) + + return self.server_ + + def _get_server(self, protocol, **kwargs): + if protocol == "rest": + server = AggregatorRESTServer(**kwargs) + elif protocol == "grpc": + server = AggregatorGRPCServer(**kwargs) + else: + raise ValueError(f"Unsupported transport_protocol '{protocol}'") + return server + + def get_server_args(self, root_certificate, private_key, certificate, kwargs): common_name = self.config["network"][SETTINGS]["agg_addr"].lower() if not root_certificate or not private_key or not certificate: @@ -638,13 +686,7 @@ def get_server( server_args["root_certificate"] = root_certificate server_args["certificate"] = certificate server_args["private_key"] = private_key - - server_args["aggregator"] = self.get_aggregator() - - if self.server_ is None: - self.server_ = AggregatorGRPCServer(**server_args) - - return self.server_ + return server_args def save_model_to_state_file(self, tensor_dict, round_number, output_path): """Save model weights to a protobuf state file. diff --git a/openfl/interface/aggregator.py b/openfl/interface/aggregator.py index 043216d9f2..16dc48e9bf 100644 --- a/openfl/interface/aggregator.py +++ b/openfl/interface/aggregator.py @@ -92,8 +92,8 @@ def start_(plan, authorized_cols, task_group): logger.info(f"Setting aggregator to assign: {task_group} task_group") logger.info("🧿 Starting the Aggregator Service.") - - parsed_plan.get_server().serve() + server = parsed_plan.get_server() + server.serve() @aggregator.command(name="generate-cert-request") diff --git a/openfl/protocols/aggregator_client_interface.py b/openfl/protocols/aggregator_client_interface.py new file mode 100644 index 0000000000..331e1f9bf0 --- /dev/null +++ b/openfl/protocols/aggregator_client_interface.py @@ -0,0 +1,72 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""AggregatorClientInterface module.""" + +from abc import ABC, abstractmethod +from typing import Any, List, Tuple + + +class AggregatorClientInterface(ABC): + @abstractmethod + def ping(self): + """ + Ping the aggregator to check connectivity. + """ + pass + + @abstractmethod + def get_tasks(self) -> Tuple[List[Any], int, int, bool]: + """ + Retrieves tasks for the given collaborator client. + Returns a tuple: (tasks, round_number, sleep_time, time_to_quit) + """ + pass + + @abstractmethod + def get_aggregated_tensor( + self, + tensor_name: str, + round_number: int, + report: bool, + tags: List[str], + require_lossless: bool, + ) -> Any: + """ + Retrieves the aggregated tensor. + """ + pass + + @abstractmethod + def send_local_task_results( + self, + round_number: int, + task_name: str, + data_size: int, + named_tensors: List[Any], + ) -> Any: + """ + Sends local task results. + Parameters: + collaborator_name: Name of the collaborator. + round_number: The current round. + task_name: Name of the task. + data_size: Size of the data. + named_tensors: A list of tensors (or named tensor objects). + Returns a SendLocalTaskResultsResponse. + """ + pass + + @abstractmethod + def send_message_to_server(self, openfl_message: Any, collaborator_name: str) -> Any: + """ + Forwards a converted message from the local client to the OpenFL server and returns the + response. + Args: + openfl_message: The converted message to be sent to the OpenFL server (InteropMessage + proto). + collaborator_name: The name of the collaborator. + Returns: + The response from the OpenFL server (InteropMessage proto). + """ + pass diff --git a/openfl/transport/__init__.py b/openfl/transport/__init__.py index 72bc7864c8..b757223351 100644 --- a/openfl/transport/__init__.py +++ b/openfl/transport/__init__.py @@ -1,5 +1,6 @@ -# Copyright 2020-2024 Intel Corporation +# Copyright 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from openfl.transport.grpc import AggregatorGRPCClient, AggregatorGRPCServer +from openfl.transport.rest import AggregatorRESTClient, AggregatorRESTServer diff --git a/openfl/transport/grpc/aggregator_client.py b/openfl/transport/grpc/aggregator_client.py index a91110b518..1022378030 100644 --- a/openfl/transport/grpc/aggregator_client.py +++ b/openfl/transport/grpc/aggregator_client.py @@ -11,6 +11,7 @@ import grpc from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils +from openfl.protocols.aggregator_client_interface import AggregatorClientInterface from openfl.transport.grpc.common import create_header, create_insecure_channel, create_tls_channel logger = logging.getLogger(__name__) @@ -165,9 +166,11 @@ def wrapper(self, *args, **kwargs): return wrapper -class AggregatorGRPCClient: +class AggregatorGRPCClient(AggregatorClientInterface): """Collaborator-side gRPC client that talks to the aggregator. + This class implements a gRPC client for communicating with an aggregator. + Attributes: agg_addr (str): Aggregator address. agg_port (int): Aggregator port. diff --git a/openfl/transport/rest/__init__.py b/openfl/transport/rest/__init__.py new file mode 100644 index 0000000000..633f6dae84 --- /dev/null +++ b/openfl/transport/rest/__init__.py @@ -0,0 +1,6 @@ +# Copyright 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + + +from openfl.transport.rest.aggregator_client import AggregatorRESTClient +from openfl.transport.rest.aggregator_server import AggregatorRESTServer diff --git a/openfl/transport/rest/aggregator_client.py b/openfl/transport/rest/aggregator_client.py new file mode 100644 index 0000000000..b97058b2f4 --- /dev/null +++ b/openfl/transport/rest/aggregator_client.py @@ -0,0 +1,638 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""AggregatorRESTClient module.""" + +# Standard library imports +import logging +import ssl +import struct +import time +from typing import Any, List, Tuple + +# Third-party libraries +import requests +from google.protobuf import json_format +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +# Internal modules +from openfl.protocols import aggregator_pb2, base_pb2 +from openfl.protocols.aggregator_client_interface import AggregatorClientInterface + +logger = logging.getLogger(__name__) + + +class SecurityError(Exception): + """Security-related error.""" + + pass + + +class AggregatorRESTClient(AggregatorClientInterface): + def __init__( + self, + agg_addr, + agg_port, + aggregator_uuid: str, + federation_uuid: str, + collaborator_name: str, + use_tls=True, + require_client_auth=True, + root_certificate=None, + certificate=None, + private_key=None, + single_col_cert_common_name=None, + refetch_server_cert_callback=None, + **kwargs, + ): + """ + Initialize the AggregatorRESTClient with proper security settings. + + Args: + agg_addr: Aggregator address + agg_port: Aggregator port + aggregator_uuid: UUID of the aggregator + federation_uuid: UUID of the federation + collaborator_name: Name of the collaborator + use_tls: Whether to use TLS + require_client_auth: Whether to require client authentication + root_certificate: Path to root certificate + certificate: Path to client certificate + private_key: Path to client private key + single_col_cert_common_name: Common name for single collaborator certificate + refetch_server_cert_callback: Callback to refetch server certificate + """ + self.use_tls = use_tls + self.require_client_auth = require_client_auth + self.root_certificate = root_certificate + self.certificate = certificate + self.private_key = private_key + self.aggregator_uuid = aggregator_uuid + self.federation_uuid = federation_uuid + self.collaborator_name = collaborator_name + self.single_col_cert_common_name = single_col_cert_common_name + self.refetch_server_cert_callback = refetch_server_cert_callback + + # Determine scheme and TLS verification + scheme = "https" if self.use_tls else "http" + + # Configure certificate verification + self.cert_verification = self._configure_cert_verification( + self.use_tls, self.root_certificate + ) + + # Configure client certificates if required + if self.use_tls and self.require_client_auth: + if not self.certificate or not self.private_key: + raise ValueError( + "Both certificate and private key are required for mTLS " + "(client authentication). " + "Please provide both certificate and private key paths." + ) + self.cert = (self.certificate, self.private_key) + else: + self.cert = None + + # Configure session with proper settings + self.session = requests.Session() + + # Set default headers + self.session.headers.update( + { + "Connection": "keep-alive", + "Keep-Alive": "timeout=300", + "Accept": "application/json", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "1; mode=block", + } + ) + + # Configure timeouts with longer duration for large payloads + self.timeout = (30, 300) # (connect timeout, read timeout) in seconds + + # Configure retries with backoff + retry_strategy = Retry( + total=3, + backoff_factor=1, + status_forcelist=[408, 429, 500, 502, 503, 504], + allowed_methods=["GET", "POST"], + raise_on_status=True, + ) + + # Configure the adapter with the retry strategy + adapter = HTTPAdapter( + max_retries=retry_strategy, pool_connections=10, pool_maxsize=10, pool_block=False + ) + + # Mount the adapter for both HTTP and HTTPS + self.session.mount("http://", adapter) + self.session.mount("https://", adapter) + + # Build the base URL + self.base_url = f"{scheme}://{agg_addr}:{agg_port}/experimental/v1" + + # Log warning about experimental API + logger.warning( + "Initializing Aggregator REST Client (EXPERIMENTAL API - Not for production use)" + ) + + # Verify certificates if TLS is enabled + if self.use_tls: + try: + self._verify_certificates() + except Exception as e: + logger.error(f"Certificate verification failed: {e}") + raise + + @classmethod + def _configure_cert_verification( + cls, use_tls: bool, root_certificate: str = None + ) -> bool | str: + """ + Configure certificate verification settings for requests. + + Args: + use_tls: Whether TLS is enabled + root_certificate: Optional path to root certificate file + + Returns: + Union[bool, str]: Either True for system CA bundle, False for no verification, + or path to root certificate file + """ + if not use_tls: + return False + + if root_certificate: + return root_certificate + + return True # Use system's default CA bundle + + def _verify_certificates(self): + """Verify SSL certificates and configuration.""" + import socket + import ssl + + # Try to establish a test connection + try: + hostname = self.base_url.split("://")[1].split(":")[0] + port = int(self.base_url.split(":")[2].split("/")[0]) + + # Create SSL context with specific options + context = ssl.create_default_context() + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + + # Set secure cipher suites + context.set_ciphers("ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384") + + # Disable older TLS versions + context.options |= ( + ssl.OP_NO_TLSv1 + | ssl.OP_NO_TLSv1_1 + | ssl.OP_NO_TLSv1_2 + | ssl.OP_NO_COMPRESSION + | ssl.OP_NO_TICKET + ) + + if self.root_certificate: + context.load_verify_locations(cafile=self.root_certificate) + + if self.certificate and self.private_key: + context.load_cert_chain(certfile=self.certificate, keyfile=self.private_key) + + # Use context managers for proper resource cleanup + with socket.create_connection((hostname, port)) as sock: + with context.wrap_socket(sock, server_hostname=hostname) as _: + pass # Connection successful if we get here + + except ssl.SSLError as e: + if "CERTIFICATE_UNKNOWN" in str(e): + logger.error( + "Certificate unknown error - this usually means the " + "server's certificate is not trusted" + ) + logger.error("Please verify that:") + logger.error( + "1. The root certificate contains all necessary intermediate certificates" + ) + logger.error("2. The server's certificate is properly signed by a trusted CA") + logger.error("3. The hostname matches the certificate's subject") + raise + except Exception as e: + logger.error(f"Connection verification failed: {e}") + raise + + def _build_header(self) -> dict: + """Build and return a header dictionary with security headers.""" + headers = { + "Receiver": self.aggregator_uuid, + "Federation-UUID": self.federation_uuid, + "Single-Col-Cert-CN": self.single_col_cert_common_name or "", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "1; mode=block", + "Sender": self.collaborator_name, + } + if self.use_tls: + headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" + return headers + + def _make_request( + self, method, url, data=None, params=None, headers=None, stream=False, timeout=None + ): + """Make a request with proper security settings.""" + start_time = time.time() + try: + self._validate_url_scheme(url) + request_headers = self._prepare_headers(headers) + response = self._execute_request( + method, url, request_headers, data, params, stream, timeout + ) + self._validate_response(response) + logger.debug(f"Request completed in {time.time() - start_time:.2f} seconds") + return response + + except requests.exceptions.Timeout: + logger.error(f"Request timed out after {time.time() - start_time:.2f} seconds") + raise + except requests.exceptions.ConnectionError as e: + self._handle_connection_error(e) + raise + except requests.exceptions.RequestException as e: + self._handle_request_error(e) + raise + + def _validate_url_scheme(self, url): + """Validate URL scheme matches TLS setting.""" + if self.use_tls and not url.startswith("https://"): + raise ValueError("TLS required but URL is not HTTPS") + elif not url.startswith("http://") and not url.startswith("https://"): + raise ValueError("URL must use either HTTP or HTTPS scheme") + + def _prepare_headers(self, headers): + """Prepare request headers with security settings.""" + request_headers = self._build_header() + if headers: + request_headers.update(headers) + return request_headers + + def _execute_request(self, method, url, headers, data, params, stream, timeout): + """Execute the HTTP request with retry logic.""" + max_retries = 3 + for attempt in range(max_retries): + try: + session = requests.Session() + if self.use_tls: + # Extract hostname from URL for verification + hostname = url.split("://")[1].split(":")[0].split("/")[0] + + # Create a custom SSL context for this request + context = ssl.create_default_context( + cafile=self.root_certificate if self.root_certificate else None + ) + context.verify_mode = ssl.CERT_REQUIRED + + # Configure session with SSL context and hostname verification + session.verify = self.cert_verification + session.cert = self.cert + + # Configure adapter with proper SSL settings + adapter = HTTPAdapter( + pool_connections=1, + pool_maxsize=1, + max_retries=Retry( + total=3, + backoff_factor=1, + status_forcelist=[408, 429, 500, 502, 503, 504], + allowed_methods=["GET", "POST"], + ), + ) + session.mount("https://", adapter) + + # Build the complete headers with security information + base_headers = self._build_header() + if headers: + # Merge user-provided headers with base headers + base_headers.update(headers) + headers = base_headers + headers["Host"] = hostname + + # Add certificate info to request kwargs + request_kwargs = { + "method": method, + "url": url, + "headers": headers, + "data": data, + "params": params, + "stream": stream, + "verify": self.cert_verification, + "timeout": timeout or self.timeout, + } + + # Add client certificate if mTLS is enabled + if self.require_client_auth: + if not self.certificate or not self.private_key: + raise ValueError( + "Both certificate and private key are required for mTLS " + "(client authentication). " + "Please provide both certificate and private key paths." + ) + # Use proper cert format + request_kwargs["cert"] = (self.certificate, self.private_key) + + response = session.request(**request_kwargs) + else: + # For non-TLS requests, still use the security headers + base_headers = self._build_header() + if headers: + base_headers.update(headers) + + response = session.request( + method=method, + url=url, + headers=base_headers, + data=data, + params=params, + stream=stream, + timeout=timeout or self.timeout, + ) + return response + except requests.exceptions.SSLError as e: + self._handle_ssl_error(e, attempt, max_retries) + if attempt == max_retries - 1: + raise + + def _handle_ssl_error(self, e, attempt, max_retries): + """Handle SSL errors with retry logic.""" + if "CERTIFICATE_UNKNOWN" in str(e): + logger.error( + "Certificate unknown error - this usually means the " + "server's certificate is not trusted" + ) + logger.error("Please verify that:") + logger.error("1. The root certificate contains all necessary intermediate certificates") + logger.error("2. The server's certificate is properly signed by a trusted CA") + logger.error("3. The hostname matches the certificate's subject") + if attempt < max_retries - 1 and self.refetch_server_cert_callback: + logger.debug("Attempting to refetch server certificate") + self.root_certificate = self.refetch_server_cert_callback() + # Update the cert_verification with the new root certificate + self.cert_verification = self._configure_cert_verification( + self.use_tls, self.root_certificate + ) + # Re-verify certificates + try: + self._verify_certificates() + except Exception as verify_error: + logger.error(f"Certificate re-verification failed: {verify_error}") + raise + else: + raise + else: + if attempt < max_retries - 1: + logger.warning(f"SSL error (attempt {attempt + 1}/{max_retries}): {str(e)}") + time.sleep(1) + else: + raise + + def _validate_response(self, response): + """Validate response headers and security settings.""" + security_headers = { + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "1; mode=block", + } + for header, expected_value in security_headers.items(): + if header in response.headers and response.headers[header] != expected_value: + logger.warning(f"Missing or incorrect security header: {header}") + + response.raise_for_status() + + def _handle_connection_error(self, e): + """Handle connection errors.""" + logger.error(f"Connection error: {e}") + if hasattr(e, "args") and len(e.args) > 0: + logger.error(f"Connection error details: {e.args[0]}") + + def _handle_request_error(self, e): + """Handle request errors.""" + logger.error(f"Request failed: {e}") + if hasattr(e, "args") and len(e.args) > 0: + logger.error(f"Request error details: {e.args[0]}") + + def get_tasks(self) -> Tuple[List[Any], int, int, bool]: + """Get tasks from the aggregator with proper security settings.""" + headers = {"Accept": "application/json", "Sender": self.collaborator_name} + params = { + "collaborator_id": self.collaborator_name, + "federation_uuid": self.federation_uuid, + } + url = f"{self.base_url}/tasks" + response = self._make_request("GET", url, headers=headers, params=params) + response.raise_for_status() + data = response.json() + tasks_resp = aggregator_pb2.GetTasksResponse() + json_format.ParseDict(data, tasks_resp) + + logger.debug( + f"Received tasks response - Round: {tasks_resp.round_number}, " + f"Tasks: {[t.name for t in tasks_resp.tasks]}, " + f"Sleep: {tasks_resp.sleep_time}, Quit: {tasks_resp.quit}" + ) + return tasks_resp.tasks, tasks_resp.round_number, tasks_resp.sleep_time, tasks_resp.quit + + def get_aggregated_tensor( + self, + tensor_name: str, + round_number: int, + report: bool, + tags: List[str], + require_lossless: bool, + ) -> Any: + """Get aggregated tensor with proper security settings.""" + params = { + "sender": self.collaborator_name, + "receiver": self.aggregator_uuid, + "federation_uuid": self.federation_uuid, + "tensor_name": tensor_name, + "round_number": round_number, + "report": report, + "tags": tags, + "require_lossless": require_lossless, + "collaborator_id": self.collaborator_name, + } + headers = {"Accept": "application/json", "Sender": self.collaborator_name} + url = f"{self.base_url}/tensors/aggregated" + extended_timeout = (30, 600) # 30 seconds connect, 10 minutes read timeout + try: + logger.debug(f"Requesting aggregated tensor {tensor_name} for round {round_number}") + response = self._make_request( + "GET", url, params=params, headers=headers, timeout=extended_timeout + ) + data = response.json() + resp = aggregator_pb2.GetAggregatedTensorResponse() + json_format.ParseDict(data, resp, ignore_unknown_fields=True) + logger.debug(f"Successfully retrieved tensor {tensor_name} for round {round_number}") + return resp.tensor + except requests.exceptions.HTTPError as e: + if e.response.status_code == 404: + # This is expected during round 0 or when tensor hasn't been aggregated yet + logger.debug( + f"No aggregated tensor found for {tensor_name} at round {round_number}" + ) + return None + raise + + def send_local_task_results( + self, + round_number: int, + task_name: str, + data_size: int, + named_tensors: List[Any], + ) -> bool: + """Send local task results with proper security settings.""" + logger.debug(f"Sending task results for round {round_number}, task {task_name}") + + # Create the TaskResults message + task_results = aggregator_pb2.TaskResults( + header=aggregator_pb2.MessageHeader( + sender=self.collaborator_name, + receiver=self.aggregator_uuid, + federation_uuid=self.federation_uuid, + single_col_cert_common_name=self.single_col_cert_common_name or "", + ), + round_number=round_number, + task_name=task_name, + data_size=data_size, + tensors=named_tensors, + ) + + # Serialize the TaskResults first + task_results_bytes = task_results.SerializeToString() + logger.debug(f"TaskResults serialized size: {len(task_results_bytes)} bytes") + + # Create a DataStream message containing the TaskResults bytes + data_stream = base_pb2.DataStream(size=len(task_results_bytes), npbytes=task_results_bytes) + + # Create an empty DataStream to signal end of stream + end_stream = base_pb2.DataStream(size=0, npbytes=b"") + + # Serialize both messages + data_bytes = data_stream.SerializeToString() + end_bytes = end_stream.SerializeToString() + + # Create length-prefixed stream format + stream_data = ( + struct.pack(">I", len(data_bytes)) # Length prefix for first message + + data_bytes # First message + + struct.pack(">I", len(end_bytes)) # Length prefix for second message + + end_bytes # Second message (empty message signals end) + ) + + url = f"{self.base_url}/tasks/results" + request_headers = self._build_header() + request_headers["Sender"] = self.collaborator_name + request_headers["Content-Type"] = "application/x-protobuf-stream" + request_headers["Content-Length"] = str(len(stream_data)) + + try: + response = self._make_request( + "POST", + url, + data=stream_data, + headers=request_headers, + timeout=(30, 60), # Keep shorter timeout since we're sending all data at once + ) + response.raise_for_status() + logger.debug(f"Successfully sent task results for round {round_number}") + return True + except Exception as e: + logger.error(f"Failed to send task results for round {round_number}: {str(e)}") + logger.error(f"Error type: {type(e).__name__}") + logger.error(f"Request headers were: {request_headers}") + raise + + def ping(self): + """Ping the aggregator to check connectivity.""" + logger.info("Aggregator ping...") + headers = {"Accept": "application/json", "Sender": self.collaborator_name} + params = { + "collaborator_id": self.collaborator_name, + "federation_uuid": self.federation_uuid, + } + url = f"{self.base_url}/ping" + response = self._make_request("GET", url, headers=headers, params=params) + response.raise_for_status() + data = response.json() + + # Validate response header like GRPC client + header = data.get("header", {}) + assert header.get("receiver") == self.collaborator_name, ( + f"Receiver in response header does not match collaborator name. " + f"Expected: {self.collaborator_name}, Actual: {header.get('receiver')}" + ) + assert header.get("sender") == self.aggregator_uuid, ( + f"Sender in response header does not match aggregator UUID. " + f"Expected: {self.aggregator_uuid}, Actual: {header.get('sender')}" + ) + assert header.get("federationUuid") == self.federation_uuid, ( + f"Federation UUID in response header does not match. " + f"Expected: {self.federation_uuid}, Actual: {header.get('federationUuid')}" + ) + assert header.get("singleColCertCommonName", "") == ( + self.single_col_cert_common_name or "" + ), ( + f"Single collaborator certificate common name in response header does not match. " + f"Expected: {self.single_col_cert_common_name}, " + f"Actual: {header.get('singleColCertCommonName')}" + ) + + logger.info("Aggregator pong!") + + def send_message_to_server(self, openfl_message: Any, collaborator_name: str) -> Any: + """ + Forwards a converted message from the local REST client to the OpenFL server and returns + the response. + + Args: + openfl_message: The InteropMessage proto to be sent to the OpenFL server. + collaborator_name: The name of the collaborator. + + Returns: + The response from the OpenFL server (InteropMessage proto). + """ + # Set the header fields + header = aggregator_pb2.MessageHeader( + sender=collaborator_name, + receiver=self.aggregator_uuid, + federation_uuid=self.federation_uuid, + single_col_cert_common_name=self.single_col_cert_common_name or "", + ) + openfl_message.header.CopyFrom(header) + + # Serialize to JSON + json_payload = json_format.MessageToJson(openfl_message) + url = f"{self.base_url}/interop/relay" + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Sender": collaborator_name, + } + response = self._make_request( + "POST", + url, + data=json_payload, + headers=headers, + timeout=(30, 300), + ) + response.raise_for_status() + response_json = response.json() + openfl_response = aggregator_pb2.InteropMessage() + json_format.ParseDict(response_json, openfl_response, ignore_unknown_fields=True) + return openfl_response + + def __del__(self): + """Cleanup when the client is destroyed.""" + self.session.close() diff --git a/openfl/transport/rest/aggregator_server.py b/openfl/transport/rest/aggregator_server.py new file mode 100644 index 0000000000..f581676871 --- /dev/null +++ b/openfl/transport/rest/aggregator_server.py @@ -0,0 +1,887 @@ +# Copyright 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""AggregatorRESTServer module.""" + +import logging +import ssl +import threading +import time +from functools import wraps +from random import random +from time import sleep + +from flask import Flask, abort, jsonify, request +from google.protobuf import json_format +from werkzeug.serving import make_server + +from openfl.protocols import aggregator_pb2, base_pb2 + +logger = logging.getLogger(__name__) + + +def synchronized(func): + """Synchronization decorator.""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + with self._lock: + return func(self, *args, **kwargs) + + return wrapper + + +def create_header(sender, receiver, federation_uuid, single_col_cert_common_name=""): + """Create a standard message header with consistent fields.""" + return aggregator_pb2.MessageHeader( + sender=str(sender), + receiver=str(receiver), + federation_uuid=str(federation_uuid), + single_col_cert_common_name=single_col_cert_common_name or "", + ) + + +class AggregatorRESTServer: + """REST server for the aggregator.""" + + def __init__( + self, + aggregator, + agg_addr, + agg_port, + use_tls=True, + require_client_auth=True, + certificate=None, + private_key=None, + root_certificate=None, + **kwargs, + ): + """Initialize REST server with security defaults.""" + # Initialize lock for synchronized methods + self._lock = threading.Lock() + + # Set up base configuration + self.aggregator = aggregator + self.host = agg_addr + self.port = agg_port + + # Set API prefix + self.api_prefix = "experimental/v1" + + # Set security defaults + self.use_tls = use_tls + self.require_client_auth = require_client_auth + self.ssl_context = None + + # Set up server components with security focus + self._setup_server_components(certificate, private_key, root_certificate) + + # Set up routes with synchronized access + self._setup_routes() + + # Build the base URL + scheme = "https" if use_tls else "http" + self.base_url = f"{scheme}://{agg_addr}:{agg_port}/{self.api_prefix}" + + def _setup_ssl_context(self, certificate, private_key, root_certificate): + """Set up SSL context for TLS/mTLS.""" + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + + # Set secure cipher suites + ssl_context.set_ciphers("ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384") + + # Disable older TLS versions and set security options + ssl_context.options |= ( + ssl.OP_NO_TLSv1 + | ssl.OP_NO_TLSv1_1 + | ssl.OP_NO_TLSv1_2 + | ssl.OP_NO_COMPRESSION + | ssl.OP_NO_TICKET # Disable session tickets + | ssl.OP_CIPHER_SERVER_PREFERENCE # Server chooses cipher + | ssl.OP_SINGLE_DH_USE # Ensure perfect forward secrecy with DHE + | ssl.OP_SINGLE_ECDH_USE # Ensure perfect forward secrecy with ECDHE + ) + + # Set verification flags for strict certificate checking + ssl_context.verify_flags = ( + ssl.VERIFY_X509_STRICT | ssl.VERIFY_CRL_CHECK_CHAIN # Check certificate revocation + ) + + # Configure client certificate verification + if self.require_client_auth: + ssl_context.verify_mode = ssl.CERT_REQUIRED + # Load root CA for client cert verification + if root_certificate: + try: + ssl_context.load_verify_locations(cafile=root_certificate) + except Exception as e: + logger.error(f"Failed to load root CA certificate: {str(e)}") + raise + else: + logger.error("Root certificate is required when client authentication is enabled") + raise ValueError("Root certificate is required for mTLS") + else: + ssl_context.verify_mode = ssl.CERT_NONE + + # Load server certificate and key + try: + ssl_context.load_cert_chain(certfile=certificate, keyfile=private_key) + except Exception as e: + logger.error(f"Failed to load server certificate and key: {str(e)}") + raise + + # Load and trust the root CA certificate + if root_certificate: + try: + ssl_context.load_verify_locations(cafile=root_certificate) + except Exception as e: + logger.error(f"Failed to load root CA certificate: {str(e)}") + raise + # Enable post-handshake authentication for better security + if hasattr(ssl_context, "post_handshake_auth"): + ssl_context.post_handshake_auth = True + + # Set verification purpose + ssl_context.purpose = ssl.Purpose.CLIENT_AUTH + + return ssl_context + + def _setup_flask_app(self): + """Configure Flask application with proper settings for both TLS and non-TLS modes.""" + app = Flask(__name__) + + # Set session and file age defaults + app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 1800 # 30 minutes + app.config["PERMANENT_SESSION_LIFETIME"] = 1800 # 30 minutes + + # Configure logging to be minimal + import logging + + # Disable Flask's default logging + log = logging.getLogger("werkzeug") + log.setLevel(logging.ERROR) + + # Add security headers + @app.after_request + def add_security_headers(response): + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-XSS-Protection"] = "1; mode=block" + if self.use_tls: + response.headers["Strict-Transport-Security"] = ( + "max-age=31536000; includeSubDomains" + ) + return response + + return app + + def _validate_client_certificate(self, request_environ, collaborator_name): + """ + Validate client certificate when mTLS is enabled. + Args: + request_environ: The request environment containing SSL information + collaborator_name: The collaborator name from the request (from header.sender) + Returns: + bool: True if validation passes + Raises: + abort: HTTP error if validation fails + """ + if not self.use_tls: + return True + + try: + # Default to collaborator name (like gRPC) + common_name = collaborator_name + + # Get certificate information if client auth is required + if self.require_client_auth: + cert_cn = self._get_certificate_cn(request_environ, collaborator_name) + if not cert_cn: + abort(401, "Client certificate validation failed - certificate not found") + common_name = cert_cn + + # Validate collaborator identity + return self._validate_collaborator(common_name, collaborator_name) + + except Exception as e: + logger.error(f"Certificate validation failed: {str(e)}") + abort(401, str(e)) + + def _get_certificate_cn(self, request_environ, collaborator_name): + """Get certificate CN from environment or headers.""" + # Try to get certificate info from environment + peercert = request_environ.get("SSL_CLIENT_CERT") + cert_cn = request_environ.get("SSL_CLIENT_S_DN_CN") + + # Try to extract CN if we have certificate but no CN + if peercert and not cert_cn: + try: + cert_cn = self._extract_cn_from_cert(peercert) + except Exception as e: + logger.error(f"Failed to extract CN from certificate: {e}") + + # If no certificate found, try fallback methods + if not peercert: + # Try header-based fallback for experimental mode + cert_cn = self._try_header_fallback(collaborator_name) + + return cert_cn + + def _try_header_fallback(self, collaborator_name): + """Try to get CN from headers as fallback in experimental mode.""" + # FALLBACK: In experimental mode, allow using header-based auth + # This should NOT be used in production + try: + from flask import request + + # Use Sender header as fallback + if hasattr(request, "headers") and "Sender" in request.headers: + cert_cn = request.headers.get("Sender") + return cert_cn + except Exception as e: + logger.error(f"Error in header fallback: {e}") + + # THIS SHOULD BE REMOVED POST EXPERIMENTAL MODE + return collaborator_name + + def _validate_collaborator(self, common_name, collaborator_name): + """Validate collaborator identity.""" + if not self.aggregator.valid_collaborator_cn_and_id(common_name, collaborator_name): + # Add timing attack protection + sleep(5 * random()) + logger.error( + f"Invalid collaborator. CN: |{common_name}| " + f"collaborator_name: |{collaborator_name}|" + ) + abort(401, "Collaborator validation failed") + + return True + + def _extract_cn_from_cert(self, cert_pem): + """Extract CN from a PEM certificate using standard libraries.""" + import re + + pass + pass + + # Try regex approach first (most reliable with PEM format) + cn_match = re.search( + r"CN\s*=\s*([^,/\n]+)", + cert_pem.decode("utf-8") if isinstance(cert_pem, bytes) else cert_pem, + ) + if cn_match: + return cn_match.group(1).strip() + + # Try using cryptography if available + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + + # Convert PEM to certificate object + cert_data = cert_pem.encode("utf-8") if isinstance(cert_pem, str) else cert_pem + cert = x509.load_pem_x509_certificate(cert_data, default_backend()) + + # Extract CN from subject + for attribute in cert.subject: + if attribute.oid._name == "commonName": + return attribute.value + except ImportError: + pass + + # Fallback: use the collaborator name from the environment + return None + + def _setup_interop_client(self): + """Set up inter-federation connector client.""" + try: + return self.aggregator.get_interop_client() + except AttributeError: + return None + + def _is_authorized(self, collaborator_id, federation_id, cert_common_name=None): + """ + Validate collaborator identity with strict checks. + + Args: + collaborator_id (str): The collaborator's ID + federation_id (str): The federation UUID + cert_common_name (str, optional): Certificate CN if using mTLS + + Returns: + bool: True if validation passes + + Raises: + abort: HTTP error if validation fails + """ + is_valid = False + try: + # Validate collaborator identity + if not collaborator_id: + logger.error("Collaborator identity not provided") + abort(400, "Collaborator identity not provided") + + # First check if collaborator is authorized + if collaborator_id not in self.aggregator.authorized_cols: + logger.error(f"Collaborator not in authorized list. Got: {collaborator_id}") + abort(401, "Unauthorized collaborator") + + # Validate collaborator identity + common_name = cert_common_name if cert_common_name is not None else collaborator_id + if not self.aggregator.valid_collaborator_cn_and_id(common_name, collaborator_id): + logger.error( + f"Collaborator validation failed. CN: {common_name}, ID: {collaborator_id}" + ) + abort(401, "Collaborator validation failed") + + # Validate client certificate if mTLS is enabled + if self.use_tls and self.require_client_auth: + self._validate_client_certificate(request.environ, collaborator_id) + + # Verify federation UUID + if federation_id != str(self.aggregator.federation_uuid): + logger.error( + f"Federation UUID mismatch. Expected: {self.aggregator.federation_uuid}, " + f"Got: {federation_id}" + ) + abort(401, "Federation UUID mismatch") + + is_valid = True + return True + + except Exception as e: + logger.error(f"Validation failed: {str(e)}") + abort(401, str(e)) + finally: + # Add timing attack protection for all error cases + if not is_valid: + sleep(5 * random()) + + def _validate_task_headers(self, headers): + """ + Validate task submission headers with timing attack protection. + + Args: + headers (dict): Request headers + + Returns: + str: Validated collaborator name + + Raises: + abort: HTTP error if validation fails + """ + try: + # Get collaborator identity from certificate or headers + collab_name = None + if self.use_tls and self.require_client_auth: + # Try to get from certificate first + collab_name = request.environ.get("SSL_CLIENT_S_DN_CN") + logger.debug(f"Using certificate CN: {collab_name}") + + # If not from certificate, try headers + if not collab_name: + collab_name = headers.get("Sender") + if not collab_name: + sleep(5 * random()) # Add timing attack protection + logger.error("No Sender header provided") + abort(401, "No Sender header provided") + + # Get other required headers + receiver = headers.get("Receiver") + federation_id = headers.get("Federation-UUID") + cert_common_name = headers.get("Single-Col-Cert-CN", "") + + # Validate collaborator identity + if not self.aggregator.valid_collaborator_cn_and_id(collab_name, collab_name): + sleep(5 * random()) # Add timing attack protection + msg = f"CN: {collab_name}, ID: {collab_name}" + logger.error(f"Collaborator validation failed. {msg}") + abort(401, "Collaborator validation failed") + + # Verify all headers with strict validation + assert receiver == str(self.aggregator.uuid), ( + f"Header receiver mismatch. Expected: {self.aggregator.uuid}, Got: {receiver}" + ) + + assert federation_id == str(self.aggregator.federation_uuid), ( + f"Federation UUID mismatch. Expected: {self.aggregator.federation_uuid}, " + f"Got: {federation_id}" + ) + + expected_cn = self.aggregator.single_col_cert_common_name or "" + assert cert_common_name == expected_cn, ( + f"Single col cert CN mismatch. Expected: {expected_cn}, Got: {cert_common_name}" + ) + + return collab_name + except AssertionError as e: + sleep(5 * random()) # Add timing attack protection + logger.error(f"Header validation failed: {str(e)}") + abort(401, str(e)) + + def _parse_protobuf_stream(self, data): + """Parse protobuf stream data.""" + logger.debug(f"Received {len(data)} bytes of protobuf stream data") + + # First message is DataStream containing TaskResults + msg_len = int.from_bytes(data[:4], byteorder="big") + logger.debug(f"First message length: {msg_len}") + data_stream_bytes = data[4 : 4 + msg_len] + data_stream = base_pb2.DataStream() + data_stream.ParseFromString(data_stream_bytes) + logger.debug(f"Parsed DataStream with size: {data_stream.size}") + + # Extract TaskResults from DataStream + task_results = aggregator_pb2.TaskResults() + task_results.ParseFromString(data_stream.npbytes) + + # Log task details + task_info = ( + f"Task: {task_results.task_name}, " + f"Round: {task_results.round_number}, " + f"Size: {task_results.data_size}, " + f"Tensors: {len(task_results.tensors)}" + ) + logger.debug(f"Extracted TaskResults from DataStream - {task_info}") + + # Verify end message + end_msg_offset = 4 + msg_len + end_msg_len = int.from_bytes(data[end_msg_offset : end_msg_offset + 4], byteorder="big") + logger.debug(f"End message length: {end_msg_len}") + + if end_msg_len != 0: + logger.error(f"Invalid end message length: {end_msg_len}") + abort(400, "Invalid stream format - expected empty end message") + + # Verify total length + expected_total_len = 4 + msg_len + 4 + end_msg_len + if len(data) != expected_total_len: + msg = f"Got {len(data)}, expected {expected_total_len}" + logger.error(f"Data length mismatch. {msg}") + abort(400, "Invalid stream data length") + + return task_results + + def _build_tasks_response( + self, + tasks_list, + round_number, + sleep_time, + time_to_quit, + collab_id, + ): + """Build GetTasksResponse protobuf.""" + tasks_proto = [] + if tasks_list: + if isinstance(tasks_list[0], str): + # Backward compatibility: list of task names + tasks_proto = [aggregator_pb2.Task(name=t) for t in tasks_list] + else: + tasks_proto = [ + aggregator_pb2.Task( + name=getattr(t, "name", ""), + function_name=getattr(t, "function_name", ""), + task_type=getattr(t, "task_type", ""), + apply_local=getattr(t, "apply_local", False), + ) + for t in tasks_list + ] + + # Create response header + header = create_header( + sender=str(self.aggregator.uuid), + receiver=collab_id, + federation_uuid=str(self.aggregator.federation_uuid), + single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "", + ) + + return aggregator_pb2.GetTasksResponse( + header=header, + round_number=round_number, + tasks=tasks_proto, + sleep_time=sleep_time, + quit=time_to_quit, + ) + + def _setup_server_components(self, certificate=None, private_key=None, root_certificate=None): + """Set up server components including SSL, Flask app, and interop client.""" + # Set up SSL if enabled + if self.use_tls: + self.ssl_context = self._setup_ssl_context(certificate, private_key, root_certificate) + else: + self.ssl_context = None # Explicitly set to None when TLS is disabled + + # Set up Flask app + self.app = self._setup_flask_app() + + # Set up interop client + self.interop_client = self._setup_interop_client() + self.use_connector = self.interop_client is not None + + def _setup_routes(self): + """Set up Flask routes.""" + # Register the route handlers + self._setup_ping_route() + self._setup_tasks_route() + self._setup_task_results_route() + self._setup_tensor_route() + self._setup_relay_route() + # Add middleware for client certificate extraction + self._setup_certificate_middleware() + + def _setup_certificate_middleware(self): + """Set up middleware to capture SSL certificates from client connections.""" + + @self.app.before_request + def extract_client_cert(): + """Extract client certificate and add it to request environment.""" + if not (self.use_tls and self.require_client_auth): + return None + + # Get SSL connection information + try: + from flask import request + + # Try to extract certificate from the socket + cert_data = self._extract_certificate_from_socket(request.environ) + if cert_data: + # Process the certificate data + self._process_certificate_data(request.environ, cert_data) + + except Exception as e: + logger.warning(f"Failed to extract client certificate: {e}") + # Continue processing the request even if cert extraction fails + + return None + + def _extract_certificate_from_socket(self, environ): + """Extract the certificate from the socket if available.""" + # Access underlying SSL socket if possible + transport = environ.get("werkzeug.socket") + if not (transport and hasattr(transport, "getpeercert")): + return None + + # Extract certificate from socket + return transport.getpeercert(binary_form=True) + + def _process_certificate_data(self, environ, der_cert): + """Process the DER certificate data and store in environment.""" + if not der_cert: + return False + + # Convert DER to PEM format using built-in libraries + try: + # Try using cryptography if available + cn, pem_cert = self._convert_der_using_cryptography(der_cert) + if pem_cert: + environ["SSL_CLIENT_CERT"] = pem_cert + if cn: + environ["SSL_CLIENT_S_DN_CN"] = cn + logger.info(f"Extracted client certificate CN: {cn}") + return True + except ImportError: + # Fall back to regex method + return self._try_regex_cn_extraction(environ, der_cert) + except Exception as e: + logger.warning(f"Error converting certificate format: {e}") + + return False + + def _convert_der_using_cryptography(self, der_cert): + """Convert DER certificate using cryptography library.""" + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + cert = x509.load_der_x509_certificate(der_cert, default_backend()) + pem_cert = cert.public_bytes(encoding=serialization.Encoding.PEM) + + # Parse the subject to get CN + cn = None + for attribute in cert.subject: + if attribute.oid._name == "commonName": + cn = attribute.value + break + + return cn, pem_cert + + def _try_regex_cn_extraction(self, environ, der_cert): + """Try to extract CN using regex from binary certificate.""" + try: + import binascii + import re + + # Convert to hex and then look for CN + hex_data = binascii.hexlify(der_cert).decode("ascii") + # Look for common name pattern in hex + # This is a simplified approach and may not work for all certs + cn_pattern = ( + r"(?:3[0-9]|4[0-9]|5[0-9])(?:06|07|08|09|0a|0b|0c|0d|0e|0f)" + r"(?:03|04|05|06)(?:13|14|15|16)(.{2,60})(?:30|31)" + ) + cn_match = re.search(cn_pattern, hex_data) + if cn_match: + # Convert hex to ASCII + cn_hex = cn_match.group(1) + try: + cn = binascii.unhexlify(cn_hex).decode("utf-8") + environ["SSL_CLIENT_S_DN_CN"] = cn + logger.info(f"Extracted client certificate CN using regex: {cn}") + return True + except Exception as e: + logger.warning(f"Failed to decode CN: {e}") + + return False + except Exception as e: + logger.warning(f"Error in regex CN extraction: {e}") + return False + + def _setup_ping_route(self): + """Set up the /ping endpoint.""" + + @self.app.route(f"/{self.api_prefix}/ping", methods=["GET"]) + def ping(): + """Simple ping endpoint to check server connectivity.""" + try: + # Get collaborator identity from certificate or query param + collaborator_id = None + if self.require_client_auth: + collaborator_id = request.environ.get("SSL_CLIENT_S_DN_CN") + if collaborator_id is None: + collaborator_id = request.args.get("collaborator_id") + + federation_id = request.args.get("federation_uuid") + + # Use the consolidated validation method + self._is_authorized(collaborator_id, federation_id) + + # Create response header + header = create_header( + sender=str(self.aggregator.uuid), + receiver=collaborator_id, + federation_uuid=str(self.aggregator.federation_uuid), + single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "", + ) + + # Return response in same format as GRPC + return jsonify({"header": json_format.MessageToDict(header)}) + except Exception as e: + logger.error(f"Ping request failed: {str(e)}") + abort(401, str(e)) + + def _setup_tasks_route(self): + """Set up the /tasks endpoint.""" + + @self.app.route(f"/{self.api_prefix}/tasks", methods=["GET"]) + def get_tasks(): + """Endpoint for collaborators to fetch pending tasks.""" + # Get collaborator identity from certificate or query param + collaborator_id = None + if self.require_client_auth: + collaborator_id = request.environ.get("SSL_CLIENT_S_DN_CN") + if collaborator_id is None: + collaborator_id = request.args.get("collaborator_id") + + federation_id = request.args.get("federation_uuid") + + # Use the consolidated validation method + self._is_authorized(collaborator_id, federation_id) + + # Check if connector mode is enabled + if self.use_connector: + abort(501, "GetTasks not supported in connector mode") + + # Fetch tasks from Aggregator core - directly delegate to the aggregator + tasks_list, round_number, sleep_time, time_to_quit = self.aggregator.get_tasks( + collaborator_id + ) + + # Log task assignment + task_names = [getattr(t, "name", t) for t in (tasks_list or [])] + logger.debug( + f"Collaborator {collaborator_id} requested tasks. " + f"Round: {round_number}, Tasks: {task_names}, " + f"Sleep: {sleep_time}, Quit: {time_to_quit}" + ) + + # Build and return response + response_proto = self._build_tasks_response( + tasks_list, round_number, sleep_time, time_to_quit, collaborator_id + ) + return jsonify(json_format.MessageToDict(response_proto)) + + def _setup_task_results_route(self): + """Set up the /tasks/results endpoint.""" + + @self.app.route(f"/{self.api_prefix}/tasks/results", methods=["POST"]) + def post_task_results(): + """Handle task results submission.""" + try: + # Validate headers and get collaborator name + collab_name = self._validate_task_headers(request.headers) + + # Parse protobuf stream data + task_results = self._parse_protobuf_stream(request.data) + + # Direct delegation to the aggregator for task results processing + # This matches the gRPC approach of calling send_local_task_results directly + self.aggregator.send_local_task_results( + collab_name, + task_results.round_number, + task_results.task_name, + task_results.data_size, + task_results.tensors, + ) + + return jsonify({"status": "success"}) + + except Exception as e: + logger.error(f"Error processing task results: {str(e)}") + abort(400, f"Error processing task results: {str(e)}") + + def _setup_tensor_route(self): + """Set up the /tensors/aggregated endpoint.""" + + @self.app.route(f"/{self.api_prefix}/tensors/aggregated", methods=["GET"]) + def get_aggregated_tensor(): + """Endpoint for collaborators to retrieve an aggregated tensor.""" + start_time = time.time() + + # Validate that this endpoint is not used in connector mode + if self.use_connector: + abort(501, "GetAggregatedTensor not supported in connector mode") + + # Get and validate collaborator identity + collaborator_id = request.args.get("collaborator_id") + federation_id = request.args.get("federation_uuid") + + # Use the consolidated validation method + self._is_authorized(collaborator_id, federation_id) + + # Extract tensor request parameters + tensor_name = request.args.get("tensor_name") + try: + round_number = int(request.args.get("round_number", 0)) + except (TypeError, ValueError): + abort(400, "Invalid round number") + report = request.args.get("report", "").lower() == "true" + tags = request.args.getlist("tags") + require_lossless = request.args.get("require_lossless", "").lower() == "true" + + # Get the tensor from aggregator - direct delegation to the aggregator + named_tensor = self.aggregator.get_aggregated_tensor( + tensor_name, + round_number, + report=report, + tags=tuple(tags), + require_lossless=require_lossless, + requested_by=collaborator_id, + ) + + # Create response header using the standardized method + header = create_header( + sender=str(self.aggregator.uuid), + receiver=collaborator_id, + federation_uuid=str(self.aggregator.federation_uuid), + single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "", + ) + + # Create response with empty tensor if not found + response_proto = aggregator_pb2.GetAggregatedTensorResponse( + header=header, + round_number=round_number, + tensor=named_tensor + if named_tensor is not None + else aggregator_pb2.NamedTensorProto(), + ) + + logger.debug(f"Tensor retrieval completed in {time.time() - start_time:.2f} seconds") + return jsonify(json_format.MessageToDict(response_proto)) + + def _setup_relay_route(self): + """Set up the /interop/relay endpoint.""" + + @self.app.route(f"/{self.api_prefix}/interop/relay", methods=["POST"]) + def relay_message(): + """Endpoint for collaborator-to-aggregator message relay.""" + # This endpoint is optional; only enable if connector mode is configured + if not self.use_connector or self.interop_client is None: + abort(501, "Interop relay is not enabled on this aggregator") + + # Parse the incoming JSON to an InteropRelay protobuf message + try: + relay_req = json_format.Parse( + request.data.decode("utf-8"), aggregator_pb2.InteropRelay() + ) + except Exception as e: + abort(400, f"Invalid InteropRelay payload: {e}") + + # Validate the collaborator via header + collab_name = relay_req.header.sender + self._is_authorized(collab_name, relay_req.header.federation_uuid) + + if relay_req.header.receiver != str(self.aggregator.uuid): + abort(400, "Header receiver mismatch") + + # Forward the request to the configured interop connector and get response + logger.debug( + f"Relaying message from {collab_name} to external federation via connector" + ) + # Create a header for forwarding using the standardized method + forward_header = create_header( + sender=str(self.aggregator.uuid), + receiver=relay_req.header.receiver, + federation_uuid=str(self.aggregator.federation_uuid), + single_col_cert_common_name=self.aggregator.single_col_cert_common_name or "", + ) + # Use the aggregator's interop client to send and receive + response_proto = self.interop_client.send_receive(relay_req, header=forward_header) + # Return the response from the remote as JSON + return jsonify(json_format.MessageToDict(response_proto)) + + def serve(self): + """Start the REST server with proper configuration for both TLS and non-TLS modes.""" + # If connector mode is enabled, start the connector service + if self.use_connector: + try: + self.aggregator.start_connector() + except AttributeError: + pass + + # Configure server based on TLS mode + if self.use_tls and self.ssl_context: + server = make_server( + self.host, + self.port, + self.app, + ssl_context=self.ssl_context, + threaded=True, # Enable threading for better performance + ) + else: + server = make_server( + self.host, + self.port, + self.app, + threaded=True, # Enable threading for better performance + ) + + # Configure server thread + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + logger.warning( + "Starting Aggregator REST Server (EXPERIMENTAL API - Not for production use)" + ) + thread.start() + + try: + while not self.aggregator.all_quit_jobs_sent(): + sleep(5) + finally: + # Synchronized shutdown + if self.use_connector: + try: + self.aggregator.stop_connector() + except AttributeError: + pass + server.shutdown() + thread.join() + logger.info("Aggregator REST Server stopped.") diff --git a/setup.py b/setup.py index 884618e146..07f14a7100 100644 --- a/setup.py +++ b/setup.py @@ -94,6 +94,7 @@ def run(self): 'tensorboardX', 'protobuf>=4.21,<6.0.0', 'grpcio>=1.56.2,<1.66.0', + 'Flask==3.1.0', ], python_requires='>=3.10, <3.13', project_urls={ diff --git a/test-requirements.txt b/test-requirements.txt index 694fa5bf2b..cfea433d42 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,8 +1,10 @@ docker +Flask==3.1.0 lxml==5.3.1 paramiko pytest==8.3.5 pytest-asyncio==0.26.0 +pytest-cov>=2.10.0 pytest-mock==3.14.0 defusedxml==0.7.1 matplotlib==3.10.1 @@ -14,3 +16,4 @@ boto3>=1.37.19 moto==5.1.1 torchvision==0.22.0 azure-storage-blob==12.25.1 +cryptography>=3.4.0 diff --git a/tests/openfl/federated/plan/test_plan.py b/tests/openfl/federated/plan/test_plan.py index dcc1f9707a..a4d63339b8 100644 --- a/tests/openfl/federated/plan/test_plan.py +++ b/tests/openfl/federated/plan/test_plan.py @@ -10,6 +10,8 @@ from openfl.federated.plan.plan import Plan from openfl.component.assigner import RandomGroupedAssigner from openfl.component.aggregator import Aggregator +from openfl.transport.rest.aggregator_server import AggregatorRESTServer +from openfl.transport.grpc.aggregator_server import AggregatorGRPCServer @pytest.fixture @@ -47,3 +49,27 @@ def test_get_aggregator(mocker, plan): mocker.patch('openfl.protocols.utils.load_proto', mock.Mock()) Aggregator._load_initial_tensors = mock.Mock() assert isinstance(plan.get_aggregator(), Aggregator) + +def test_get_server_rest(plan,mocker): + mocker.patch('openfl.protocols.utils.load_proto', return_value=mock.Mock()) + mock_setup_ssl = mocker.patch('openfl.transport.rest.aggregator_server.AggregatorRESTServer._setup_ssl_context', return_value=mock.Mock()) + plan.config['network']['settings']['transport_protocol'] = 'rest' + server = plan.get_server() + assert isinstance(server, AggregatorRESTServer) + +def test_get_server_grpc(plan,mocker): + mocker.patch('openfl.protocols.utils.load_proto', return_value=mock.Mock()) + plan.config['network']['settings']['transport_protocol'] = 'grpc' + server = plan.get_server() + assert isinstance(server, AggregatorGRPCServer) + +def test_get_server_default_certificates(plan,mocker): + mocker.patch('openfl.protocols.utils.load_proto', return_value=mock.Mock()) + server = plan.get_server() + assert isinstance(server, AggregatorGRPCServer) # Default to gRPC + +def test_get_server_invalid_protocol(plan,mocker): + mocker.patch('openfl.protocols.utils.load_proto', return_value=mock.Mock()) + plan.config['network']['settings']['transport_protocol'] = 'invalid_protocol' + with pytest.raises(ValueError): + plan.get_server() diff --git a/tests/openfl/transport/rest/__init__.py b/tests/openfl/transport/rest/__init__.py new file mode 100644 index 0000000000..ba2df757ff --- /dev/null +++ b/tests/openfl/transport/rest/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Transport tests package.""" diff --git a/tests/openfl/transport/rest/conftest.py b/tests/openfl/transport/rest/conftest.py new file mode 100644 index 0000000000..f1572cd34a --- /dev/null +++ b/tests/openfl/transport/rest/conftest.py @@ -0,0 +1,39 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Shared test configurations for transport tests.""" + +import pytest +import logging +from pathlib import Path + + +@pytest.fixture(autouse=True) +def setup_logging(): + """Configure logging for tests.""" + logging.basicConfig(level=logging.DEBUG) + yield + + +@pytest.fixture(autouse=True) +def mock_environment(monkeypatch): + """Mock environment variables and system settings.""" + monkeypatch.setenv('PYTHONPATH', '') # Clear PYTHONPATH to avoid interference + yield + + +@pytest.fixture +def test_data_dir(): + """Get the test data directory.""" + return Path(__file__).parent / 'data' + + +@pytest.fixture(autouse=True) +def setup_test_data(test_data_dir): + """Set up test data directory.""" + test_data_dir.mkdir(exist_ok=True) + yield + # Cleanup after tests if needed + if test_data_dir.exists(): + for file in test_data_dir.glob('*'): + if file.is_file(): + file.unlink() diff --git a/tests/openfl/transport/rest/test_rest_server.py b/tests/openfl/transport/rest/test_rest_server.py new file mode 100644 index 0000000000..8c3273fe6c --- /dev/null +++ b/tests/openfl/transport/rest/test_rest_server.py @@ -0,0 +1,329 @@ +# Copyright (C) 2020-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""REST server tests module.""" + +import pytest +import ssl +from unittest import mock +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from datetime import datetime, timedelta + +from openfl.transport.rest.aggregator_server import AggregatorRESTServer +from openfl.protocols import aggregator_pb2, base_pb2 + + +def generate_test_certificates(cert_path, key_path, root_cert_path): + """Generate self-signed certificates for testing.""" + # Generate private key + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048 + ) + + # Generate self-signed certificate + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, u"test.example.com"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"Test Organization"), + ]) + + cert = x509.CertificateBuilder().subject_name( + subject + ).issuer_name( + issuer + ).public_key( + private_key.public_key() + ).serial_number( + x509.random_serial_number() + ).not_valid_before( + datetime.utcnow() + ).not_valid_after( + datetime.utcnow() + timedelta(days=1) + ).sign(private_key, hashes.SHA256()) + + # Write private key + with open(key_path, "wb") as f: + f.write(private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + )) + + # Write certificate + with open(cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + + # For testing, use the same cert as root CA + with open(root_cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + + +@pytest.fixture +def mock_aggregator(): + """Create a mock aggregator for testing.""" + aggregator = mock.Mock() + aggregator.uuid = "test-uuid" + aggregator.federation_uuid = "fed-uuid" + aggregator.authorized_cols = ["test-collaborator"] + aggregator.single_col_cert_common_name = "test-cert-cn" + aggregator.valid_collaborator_cn_and_id = mock.Mock(return_value=True) + aggregator.get_tasks = mock.Mock(return_value=(["task1", "task2"], 1, 5, False)) + aggregator.get_aggregated_tensor = mock.Mock() + aggregator.send_local_task_results = mock.Mock() + # Disable connector mode by default + aggregator.get_interop_client = mock.Mock(return_value=None) + # Add mock for task completion tracking + aggregator._collaborator_task_completed = mock.Mock(return_value=True) + # Add mock assigner + mock_assigner = mock.Mock() + mock_assigner.get_tasks_for_collaborator = mock.Mock(return_value=[]) + aggregator.assigner = mock_assigner + # Add collaborators_done list + aggregator.collaborators_done = [] + return aggregator + + +@pytest.fixture +def ssl_certs(tmp_path): + """Create temporary SSL certificate files for testing.""" + cert_path = tmp_path / "test_cert.pem" + key_path = tmp_path / "test_key.pem" + root_path = tmp_path / "test_root.pem" + + generate_test_certificates(cert_path, key_path, root_path) + + return { + 'cert': str(cert_path), + 'key': str(key_path), + 'root': str(root_path) + } + + +@pytest.fixture +def rest_server(mock_aggregator, ssl_certs): + """Create REST server instance for testing.""" + server = AggregatorRESTServer( + aggregator=mock_aggregator, + agg_addr="localhost", + agg_port=8080, + use_tls=True, + require_client_auth=True, + certificate=ssl_certs['cert'], + private_key=ssl_certs['key'], + root_certificate=ssl_certs['root'] + ) + return server + + +class TestAggregatorRESTServer: + """Test cases for AggregatorRESTServer.""" + + def test_ssl_context_setup(self, rest_server, ssl_certs): + """Test SSL context configuration.""" + with mock.patch('ssl.SSLContext') as mock_ssl_context: + mock_context = mock.Mock() + mock_ssl_context.return_value = mock_context + # Mock the options attribute to be an integer that can handle bitwise operations + mock_context.options = 0 + + rest_server._setup_ssl_context( + certificate=ssl_certs['cert'], + private_key=ssl_certs['key'], + root_certificate=ssl_certs['root'] + ) + + mock_ssl_context.assert_called_once_with(ssl.PROTOCOL_TLS_SERVER) + + # Check that load_cert_chain was called exactly once with the expected parameters + mock_context.load_cert_chain.assert_called_once_with( + certfile=ssl_certs['cert'], + keyfile=ssl_certs['key'] + ) + + # Check that load_verify_locations was called exactly twice with the same parameters + assert mock_context.load_verify_locations.call_count == 2 + assert all( + call == mock.call(cafile=ssl_certs['root']) + for call in mock_context.load_verify_locations.call_args_list + ) + + assert mock_context.verify_mode == ssl.CERT_REQUIRED + + def test_get_tasks_valid_request(self, rest_server, mock_aggregator): + """Test successful task retrieval.""" + # Mock the get_tasks method to return proper Task objects + mock_tasks = [ + aggregator_pb2.Task(name="task1", function_name="func1", task_type="train"), + aggregator_pb2.Task(name="task2", function_name="func2", task_type="validate") + ] + mock_aggregator.get_tasks.return_value = (mock_tasks, 1, 5, True) # Set quit to True to ensure it appears in JSON + + with rest_server.app.test_client() as client: + response = client.get('experimental/v1/tasks', query_string={ + "collaborator_id": "test-collaborator", + "federation_uuid": "fed-uuid" + }) + + assert response.status_code == 200 + data = response.get_json() + assert data["roundNumber"] == 1 + assert len(data["tasks"]) == 2 + assert data["sleepTime"] == 5 + assert "quit" in data # Verify quit field exists + assert data["quit"] # Should be True now + + # Test with quit=False + mock_aggregator.get_tasks.return_value = (mock_tasks, 1, 5, False) + + with rest_server.app.test_client() as client: + response = client.get('experimental/v1/tasks', query_string={ + "collaborator_id": "test-collaborator", + "federation_uuid": "fed-uuid" + }) + + assert response.status_code == 200 + data = response.get_json() + # When quit is False (default value), it might be omitted in JSON + # So we use get() with a default value + assert not data.get("quit", False) + + def test_get_tasks_unauthorized(self, rest_server): + """Test task retrieval with unauthorized collaborator.""" + with rest_server.app.test_client() as client: + response = client.get('experimental/v1/tasks', query_string={ + "collaborator_id": "unauthorized-collaborator", + "federation_uuid": "fed-uuid" + }) + assert response.status_code == 401 + + def test_post_task_results(self, rest_server, mock_aggregator): + """Test task results submission.""" + # Create mock task results + task_results = aggregator_pb2.TaskResults() + task_results.task_name = "test_task" + task_results.round_number = 1 + task_results.data_size = 100 + + # Create mock header + task_results.header.sender = "test-collaborator" + task_results.header.receiver = str(mock_aggregator.uuid) + task_results.header.federation_uuid = str(mock_aggregator.federation_uuid) + task_results.header.single_col_cert_common_name = "test-cert-cn" + + # Add a named tensor + tensor = base_pb2.NamedTensor() + tensor.name = "test_tensor" + task_results.tensors.append(tensor) + + # Create DataStream + data_stream = base_pb2.DataStream() + data_stream.npbytes = task_results.SerializeToString() + data_stream.size = len(data_stream.npbytes) + + # Prepare request data + request_data = ( + len(data_stream.SerializeToString()).to_bytes(4, byteorder='big') + + data_stream.SerializeToString() + + (0).to_bytes(4, byteorder='big') + ) + + # Configure mock assigner to return tasks + mock_aggregator.assigner.get_tasks_for_collaborator.return_value = [ + aggregator_pb2.Task(name="test_task") + ] + + with rest_server.app.test_client() as client: + response = client.post( + 'experimental/v1/tasks/results', + data=request_data, + headers={ + "Sender": "test-collaborator", + "Receiver": str(mock_aggregator.uuid), + "Federation-UUID": str(mock_aggregator.federation_uuid), + "Single-Col-Cert-CN": "test-cert-cn" + } + ) + + assert response.status_code == 200 + mock_aggregator.send_local_task_results.assert_called_once() + + def test_get_aggregated_tensor(self, rest_server, mock_aggregator): + """Test aggregated tensor retrieval.""" + # Create mock tensor response + mock_tensor = base_pb2.NamedTensor() + mock_tensor.name = "test_tensor" + mock_aggregator.get_aggregated_tensor.return_value = mock_tensor + + with rest_server.app.test_client() as client: + response = client.get('/experimental/v1/tensors/aggregated', query_string={ + "collaborator_id": "test-collaborator", + "federation_uuid": "fed-uuid", + "tensor_name": "test_tensor", + "round_number": "1" + }) + + assert response.status_code == 200 + data = response.get_json() + assert data["roundNumber"] == 1 + assert "tensor" in data + + def test_relay_message_not_enabled(self, rest_server): + """Test relay endpoint when not enabled.""" + # Create a valid relay message + relay_msg = aggregator_pb2.InteropMessage() + relay_msg.header.sender = "test-collaborator" + relay_msg.header.receiver = str(rest_server.aggregator.uuid) + relay_msg.header.federation_uuid = str(rest_server.aggregator.federation_uuid) + + with rest_server.app.test_client() as client: + response = client.post( + '/experimental/v1/interop/relay', + json={"header": {"sender": "test-collaborator"}} + ) + assert response.status_code == 501 + + def test_invalid_federation_uuid(self, rest_server): + """Test request with invalid federation UUID.""" + with rest_server.app.test_client() as client: + response = client.get('/experimental/v1/tasks', query_string={ + "collaborator_id": "test-collaborator", + "federation_uuid": "invalid-uuid" + }) + assert response.status_code == 401 + + def test_malformed_task_results(self, rest_server): + """Test submission of malformed task results.""" + with rest_server.app.test_client() as client: + response = client.post( + 'experimental/v1/tasks/results', + data=b"invalid data", + headers={ + "Sender": "test-collaborator", + "Receiver": str(rest_server.aggregator.uuid), + "Federation-UUID": str(rest_server.aggregator.federation_uuid) + } + ) + assert response.status_code == 400 + + def test_connector_mode_tasks(self, rest_server): + """Test task retrieval in connector mode.""" + rest_server.use_connector = True + with rest_server.app.test_client() as client: + response = client.get('/experimental/v1/tasks', query_string={ + "collaborator_id": "test-collaborator", + "federation_uuid": "fed-uuid" + }) + assert response.status_code == 501 + + def test_invalid_round_number(self, rest_server): + """Test tensor retrieval with invalid round number.""" + with rest_server.app.test_client() as client: + response = client.get('/experimental/v1/tensors/aggregated', query_string={ + "collaborator_id": "test-collaborator", + "federation_uuid": "fed-uuid", + "tensor_name": "test_tensor", + "round_number": "invalid" + }) + assert response.status_code == 400