blazefl.core.ProcessPoolClientTrainer#

class blazefl.core.ProcessPoolClientTrainer(*args, **kwargs)[source]#

Bases: BaseClientTrainer[UplinkPackage, DownlinkPackage], Protocol[UplinkPackage, DownlinkPackage, ClientConfig]

Abstract base class for parallel client training in federated learning.

This class extends BaseClientTrainer to enable parallel processing of clients, allowing multiple clients to be trained concurrently across worker processes.

num_parallels#

Number of parallel processes to use for client training.

Type:

int

share_dir#

Directory path for sharing data between processes.

Type:

Path

device#

The primary device to use for computation (e.g., “cpu”, “cuda”).

Type:

str

device_count#

The number of available CUDA devices, if device is “cuda”.

Type:

int

cache#

Cache to store uplink packages from clients.

Type:

list[UplinkPackage]

ipc_mode#

Inter-process communication mode. “storage” uses disk for data exchange, “shared_memory” uses shared memory for tensor data. Defaults to “storage”.

Type:

Literal[“storage”, “shared_memory”]

Raises:

NotImplementedError – If the abstract methods are not implemented in a subclass.
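The parallel pattern this class manages can be sketched with only the standard library: a pool of worker processes each handles one client ID, and the `ApplyResult` handles are collected into a cache, mirroring `local_process`/`worker`. All names below (`train_client`, `run_round`) are hypothetical illustrations, not part of the blazefl API.

```python
# Illustrative analogue of the ProcessPoolClientTrainer pattern using the
# standard library only. Each client ID is dispatched to a worker process
# via apply_async; the results are gathered into a cache, as local_process
# does with uplink packages.
import multiprocessing as mp


def train_client(cid: int, payload: float) -> dict:
    # Stand-in for ProcessPoolClientTrainer.worker: process one client's task.
    return {"cid": cid, "update": payload * (cid + 1)}


def run_round(payload: float, cid_list: list[int], num_parallels: int) -> list[dict]:
    cache: list[dict] = []
    with mp.Pool(processes=num_parallels) as pool:
        # One ApplyResult per client; progress_fn would wrap this iteration.
        results = [pool.apply_async(train_client, (cid, payload)) for cid in cid_list]
        for r in results:
            cache.append(r.get())
    return cache


if __name__ == "__main__":
    print(run_round(0.5, [0, 1, 2], num_parallels=2))
```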

__init__(*args, **kwargs)#

Methods

__init__(*args, **kwargs)

get_client_config(cid)

Retrieve the configuration for a given client ID.

get_client_device(cid)

Retrieve the device to use for processing a given client.

local_process(payload, cid_list)

Manage the parallel processing of clients.

progress_fn(it)

A no-op progress function that can be overridden to provide custom progress tracking.

uplink_package()

Prepare the data package to be sent from the client to the server.

worker(config, payload, device, stop_event)

Process a single client's training task.

Attributes

cache: list[UplinkPackage]#
device: str#
device_count: int#
get_client_config(cid: int) ClientConfig[source]#

Retrieve the configuration for a given client ID.

Parameters:

cid (int) – Client ID.

Returns:

The configuration for the specified client.

Return type:

ClientConfig

get_client_device(cid: int) str[source]#

Retrieve the device to use for processing a given client.

Parameters:

cid (int) – Client ID.

Returns:

The device to use for processing the client.

Return type:

str
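A plausible implementation of the device mapping this method describes is a round-robin assignment of client IDs across the available CUDA devices, using the `device` and `device_count` attributes documented above. This is an illustrative sketch, not the library's actual code.

```python
# Hypothetical round-robin mapping from client ID to a concrete device
# string: with device == "cuda", client cid is assigned "cuda:<cid mod
# device_count>"; otherwise the primary device is returned unchanged.
def get_client_device(cid: int, device: str, device_count: int) -> str:
    if device == "cuda" and device_count > 0:
        return f"cuda:{cid % device_count}"
    return device
```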

ipc_mode: Literal['storage', 'shared_memory'] = 'storage'#
local_process(payload: DownlinkPackage, cid_list: list[int]) None[source]#

Manage the parallel processing of clients.

This method distributes the processing of multiple clients across parallel processes, handling data saving, loading, and caching.

Parameters:
  • payload (DownlinkPackage) – The data package received from the server.

  • cid_list (list[int]) – A list of client IDs to process.

Returns:

None

num_parallels: int#
progress_fn(it: list[ApplyResult]) Iterable[ApplyResult][source]#

A no-op progress function that can be overridden to provide custom progress tracking.

Parameters:

it (list[ApplyResult]) – A list of ApplyResult objects.

Returns:

The original iterable.

Return type:

Iterable[ApplyResult]
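Since the default is a no-op that returns the list unchanged, a subclass can override `progress_fn` with any wrapper that yields the same items, e.g. to report progress (a `tqdm` wrapper would work the same way). A minimal sketch of such an override, with hypothetical output text:

```python
# Override sketch: yield each pending result handle while printing simple
# textual progress. Returning a generator satisfies the Iterable contract;
# the default implementation simply returns the input list unchanged.
from typing import Iterable, TypeVar

T = TypeVar("T")


def progress_fn(it: list[T]) -> Iterable[T]:
    total = len(it)
    for i, result in enumerate(it, start=1):
        print(f"waiting on client task {i}/{total}")
        yield result
```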

share_dir: Path#
stop_event: Event#
static worker(config: ClientConfig | Path, payload: DownlinkPackage | Path, device: str, stop_event: Event) UplinkPackage | Path[source]#

Process a single client’s training task.

This method is executed by each worker process in the pool. It handles loading client configuration and payload, performing the client-specific operations, and returning the result.

Parameters:
  • config (ClientConfig | Path) – The client’s configuration data, or a path to a file containing the configuration if ipc_mode is “storage”.

  • payload (DownlinkPackage | Path) – The downlink payload from the server, or a path to a file containing the payload if ipc_mode is “storage”.

  • device (str) – Device to use for processing (e.g., “cpu”, “cuda:0”).

  • stop_event (Event) – Event used to signal that the worker should stop its task early.

Returns:

The uplink package containing the client’s results, or a path to a file containing the package if ipc_mode is “storage”.

Return type:

UplinkPackage | Path
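The "storage" IPC mode described above amounts to a file round-trip: config and payload are serialized under `share_dir`, and workers receive paths instead of objects. The round-trip can be sketched with `pickle`; blazefl's actual serialization format may differ, and `save`/`load` here are hypothetical helpers.

```python
# Illustrative "storage" IPC round-trip: write an object to a file under a
# shared directory, hand the path to a worker, and read it back there.
import pickle
import tempfile
from pathlib import Path


def save(obj, path: Path) -> Path:
    path.write_bytes(pickle.dumps(obj))
    return path  # the worker receives this path, not the object itself


def load(path: Path):
    return pickle.loads(path.read_bytes())


share_dir = Path(tempfile.mkdtemp())
config_path = save({"cid": 0, "lr": 0.01}, share_dir / "config_0.pkl")
```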