blazefl.core.ProcessPoolClientTrainer#

class blazefl.core.ProcessPoolClientTrainer(*args, **kwargs)[source]#

Bases: BaseClientTrainer[UplinkPackage, DownlinkPackage], Protocol[UplinkPackage, DownlinkPackage, ClientConfig]

Abstract base class for parallel client training using a process pool.

This class enables parallel processing of clients by distributing tasks across multiple processes.

num_parallels#

Number of parallel processes to use.

Type:

int

device#

Primary device for computation (e.g., “cpu”, “cuda”).

Type:

str

device_count#

Number of available CUDA devices for distribution.

Type:

int

cache#

Cache to store results from clients.

Type:

list[UplinkPackage]

stop_event#

Event to signal workers to stop.

Type:

threading.Event

Raises:

NotImplementedError – If the abstract methods are not implemented in a subclass.
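Example: a minimal, hypothetical subclass sketch. The package and config types (MyDownlink, MyUplink, MyConfig), the constructor, and the assumption that get_client_config, worker, and uplink_package are the members a concrete trainer supplies (and that uplink_package returns the cached list) are illustrative only; the exact set of abstract methods may differ in your BlazeFL version.

```python
from dataclasses import dataclass, field
from threading import Event

import torch

from blazefl.core import ProcessPoolClientTrainer


# Illustrative package and config types; a real trainer would carry model
# parameters, metrics, and hyperparameters for the task at hand.
@dataclass
class MyDownlink:
    model_parameters: torch.Tensor


@dataclass
class MyUplink:
    model_parameters: torch.Tensor
    metadata: dict = field(default_factory=dict)


@dataclass
class MyConfig:
    cid: int
    lr: float


class MyClientTrainer(ProcessPoolClientTrainer[MyUplink, MyDownlink, MyConfig]):
    def __init__(self, num_parallels: int, device: str) -> None:
        self.num_parallels = num_parallels
        self.device = device
        self.device_count = torch.cuda.device_count()
        self.cache: list[MyUplink] = []
        self.stop_event = Event()

    def get_client_config(self, cid: int) -> MyConfig:
        # Per-client hyperparameters; here every client uses the same learning rate.
        return MyConfig(cid=cid, lr=0.01)

    @staticmethod
    def worker(config: MyConfig, payload: MyDownlink, device: str,
               stop_event: Event, *, shm_buffer: MyUplink | None = None) -> MyUplink:
        # Placeholder "training": move the received parameters to the target
        # device and echo them back (see the worker() entry below for a fuller sketch).
        params = payload.model_parameters.to(device)
        return MyUplink(model_parameters=params.cpu(), metadata={"cid": config.cid})

    def uplink_package(self) -> list[MyUplink]:
        # Assumption: hand the cached results to the server and reset the cache.
        package, self.cache = self.cache, []
        return package
```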

__init__(*args, **kwargs)#

Methods

__init__(*args, **kwargs)

get_client_config(cid)

Retrieve the configuration for a given client ID.

get_client_device(cid)

Retrieve the device to use for processing a given client.

local_process(payload, cid_list)

Manage the parallel processing of clients.

prepare_uplink_package_buffer()

progress_fn(it)

A no-op progress function that can be overridden to provide custom progress tracking.

uplink_package()

Prepare the data package to be sent from the client to the server.

worker(config, payload, device, stop_event, *[, shm_buffer])

Process a single client's training task.

Attributes

cache: list[UplinkPackage]#
device: str#
device_count: int#
get_client_config(cid: int) → ClientConfig[source]#

Retrieve the configuration for a given client ID.

Parameters:

cid (int) – Client ID.

Returns:

The configuration for the specified client.

Return type:

ClientConfig

get_client_device(cid: int) → str[source]#

Retrieve the device to use for processing a given client.

Parameters:

cid (int) – Client ID.

Returns:

The device to use for processing the client.

Return type:

str
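The device-selection policy is supplied by the library; purely as an illustrative assumption, a round-robin mapping over the visible CUDA devices could look like the following method body (the actual BlazeFL implementation may differ).

```python
# Hypothetical round-robin policy, shown only to illustrate how `device`
# and `device_count` interact.
def get_client_device(self, cid: int) -> str:
    if self.device == "cuda" and self.device_count > 0:
        # Spread clients evenly across the visible CUDA devices.
        return f"cuda:{cid % self.device_count}"
    return self.device
```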

local_process(payload: DownlinkPackage, cid_list: list[int]) → None[source]#

Manage the parallel processing of clients.

This method distributes the processing of multiple clients across parallel processes, handling data saving, loading, and caching.

Parameters:
  • payload (DownlinkPackage) – The data package received from the server.

  • cid_list (list[int]) – A list of client IDs to process.

Returns:

None
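In a typical round, server-side code calls local_process once with the current payload and the sampled client IDs, then collects the cached results. The names trainer and server_payload, and the assumption that uplink_package() returns the cached list, are hypothetical:

```python
# Hypothetical one-round driver; `trainer` is an instance of a concrete
# ProcessPoolClientTrainer subclass and `server_payload` is a DownlinkPackage.
sampled_clients = [0, 1, 2, 3]
trainer.local_process(server_payload, sampled_clients)  # clients run in parallel processes
uplinks = trainer.uplink_package()                       # results gathered from trainer.cache
```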

num_parallels: int#
progress_fn(it: list[ApplyResult]) → Iterable[ApplyResult][source]#

A no-op progress function that can be overridden to provide custom progress tracking.

Parameters:

it (list[ApplyResult]) – A list of ApplyResult objects.

Returns:

The original iterable.

Return type:

Iterable[ApplyResult]
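Because the default implementation simply returns its argument, progress reporting can be added by overriding this method, for example with tqdm (MyClientTrainer is the hypothetical subclass sketched above):

```python
from collections.abc import Iterable
from multiprocessing.pool import ApplyResult

from tqdm import tqdm


class ProgressClientTrainer(MyClientTrainer):
    def progress_fn(self, it: list[ApplyResult]) -> Iterable[ApplyResult]:
        # Wrap the pending worker results in a progress bar; the items and
        # their order are unchanged, so local_process behaves exactly as before.
        return tqdm(it, desc="clients")
```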

stop_event: Event#
static worker(config: ClientConfig, payload: DownlinkPackage, device: str, stop_event: Event, *, shm_buffer: UplinkPackage | None = None) → UplinkPackage[source]#

Process a single client’s training task.

This method is executed by each worker process in the pool. It handles loading client configuration and payload, performing the client-specific operations, and returning the result.

Parameters:
  • config (ClientConfig) – The client’s configuration data.

  • payload (DownlinkPackage) – The downlink payload from the server.

  • device (str) – Device to use for processing (e.g., “cpu”, “cuda:0”).

  • stop_event (threading.Event) – Event to signal stopping the worker.

  • shm_buffer (UplinkPackage | None) – Optional shared memory buffer for the uplink package.

Returns:

The uplink package containing the client’s results.

Return type:

UplinkPackage
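The following sketch continues the hypothetical MyClientTrainer, MyConfig, MyDownlink, and MyUplink types from the class-level example to show how a concrete worker might honour stop_event and use shm_buffer. The toy objective, the fixed number of steps, and the in-place copy into the shared-memory buffer are assumptions, not the library's prescribed contract:

```python
from threading import Event

import torch


# Continuing the hypothetical MyClientTrainer sketch from above.
@staticmethod
def worker(config: MyConfig, payload: MyDownlink, device: str,
           stop_event: Event, *, shm_buffer: MyUplink | None = None) -> MyUplink:
    # Toy local objective: fit the received weight vector to synthetic data.
    w = payload.model_parameters.clone().to(device).requires_grad_(True)
    x = torch.randn(32, w.numel(), device=device)
    y = x.sum(dim=1, keepdim=True)
    opt = torch.optim.SGD([w], lr=config.lr)
    for _ in range(10):
        if stop_event.is_set():
            break  # cooperative shutdown requested by the parent process
        loss = ((x @ w.view(-1, 1) - y) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    result = MyUplink(model_parameters=w.detach().cpu(),
                      metadata={"cid": config.cid})
    if shm_buffer is not None:
        # Assumption: the buffer's tensor lives in shared memory, so copying
        # into it lets the parent read the result without pickling the package.
        shm_buffer.model_parameters.copy_(result.model_parameters)
    return result
```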