blazefl.core.ProcessPoolClientTrainer
- class blazefl.core.ProcessPoolClientTrainer(*args, **kwargs)
Bases: BaseClientTrainer[UplinkPackage, DownlinkPackage], Protocol[UplinkPackage, DownlinkPackage, ClientConfig]
Abstract base class for parallel client training using a process pool.
This class enables parallel processing of clients by distributing tasks across multiple processes.
- num_parallels
Number of parallel processes to use.
- Type:
int
- device
Primary device for computation (e.g., “cpu”, “cuda”).
- Type:
str
- device_count
Number of available CUDA devices for distribution.
- Type:
int
- cache
Cache to store results from clients.
- Type:
list[UplinkPackage]
- stop_event
Event to signal workers to stop.
- Type:
threading.Event
- Raises:
NotImplementedError – If the abstract methods are not implemented in a subclass.
- __init__(*args, **kwargs)
Methods
- __init__(*args, **kwargs)
- get_client_config(cid): Retrieve the configuration for a given client ID.
- get_client_device(cid): Retrieve the device to use for processing a given client.
- local_process(payload, cid_list): Manage the parallel processing of clients.
- progress_fn(it): A no-op progress function that can be overridden to provide custom progress tracking.
- uplink_package(): Prepare the data package to be sent from the client to the server.
- worker(config, payload, device, stop_event, *, shm_buffer=None): Process a single client's training task.
Attributes
- cache: list[UplinkPackage]
- device: str
- device_count: int
- get_client_config(cid: int) -> ClientConfig
Retrieve the configuration for a given client ID.
- Parameters:
cid (int) – Client ID.
- Returns:
The configuration for the specified client.
- Return type:
ClientConfig
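A minimal sketch of an override, assuming a hypothetical `ClientConfig` dataclass with `cid`, `seed`, and `lr` fields; the real type is whatever a subclass binds to the `ClientConfig` type parameter. Because the returned configuration is passed to `worker` inside a pool process, keeping it a small, picklable value such as a frozen dataclass is a reasonable choice.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ClientConfig:
    # Hypothetical per-client settings, not part of blazefl itself.
    cid: int
    seed: int
    lr: float


class ConfigMixin:
    """Hypothetical helper showing one way to implement get_client_config."""

    def __init__(self, base_seed: int = 42, lr: float = 0.01) -> None:
        self.base_seed = base_seed
        self.lr = lr

    def get_client_config(self, cid: int) -> ClientConfig:
        # Derive a deterministic per-client seed so reruns are reproducible.
        return ClientConfig(cid=cid, seed=self.base_seed + cid, lr=self.lr)


trainer = ConfigMixin()
print(trainer.get_client_config(3))  # ClientConfig(cid=3, seed=45, lr=0.01)
```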
- get_client_device(cid: int) -> str
Retrieve the device to use for processing a given client.
- Parameters:
cid (int) – Client ID.
- Returns:
The device to use for processing the client.
- Return type:
str
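One plausible device-assignment policy, sketched with a hypothetical `DevicePolicy` class that carries only the `device` and `device_count` attributes. When the primary device is `"cuda"`, clients are assigned round-robin to `cuda:0`, `cuda:1`, and so on; otherwise every client uses the primary device. This policy is an illustrative assumption, not necessarily the library's default.

```python
class DevicePolicy:
    """Hypothetical stand-in holding the two attributes the policy needs."""

    def __init__(self, device: str, device_count: int) -> None:
        self.device = device
        self.device_count = device_count

    def get_client_device(self, cid: int) -> str:
        # Spread clients round-robin over the visible GPUs.
        if self.device == "cuda" and self.device_count > 0:
            return f"cuda:{cid % self.device_count}"
        # Otherwise every client runs on the primary device unchanged.
        return self.device


gpu = DevicePolicy("cuda", device_count=2)
print([gpu.get_client_device(cid) for cid in range(4)])
# ['cuda:0', 'cuda:1', 'cuda:0', 'cuda:1']
print(DevicePolicy("cpu", device_count=0).get_client_device(7))
# cpu
```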
- local_process(payload: DownlinkPackage, cid_list: list[int]) -> None
Manage the parallel processing of clients.
This method distributes the processing of multiple clients across parallel processes, handling data saving, loading, and caching.
- Parameters:
payload (DownlinkPackage) – The data package received from the server.
cid_list (list[int]) – A list of client IDs to process.
- Returns:
None
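The simplified sketch below shows the dispatch pattern this description implies: one `apply_async` call per client, the pending results passed through `progress_fn`, and each result appended to `cache`. To stay self-contained it swaps `multiprocessing.Pool` for `multiprocessing.pool.ThreadPool`, which exposes the same `apply_async`/`ApplyResult` interface. The toy `worker`, the integer config, and the `SketchTrainer` class are hypothetical, and saving or loading data to disk is omitted.

```python
import threading
from multiprocessing.pool import ApplyResult, ThreadPool
from typing import Iterable


def worker(config: int, payload: float, device: str, stop_event: threading.Event) -> float:
    # Toy "training" (hypothetical): scale the downlink payload by the config.
    if stop_event.is_set():
        raise RuntimeError("stopped")
    return payload * config


class SketchTrainer:
    def __init__(self, num_parallels: int) -> None:
        self.num_parallels = num_parallels
        self.device = "cpu"
        self.cache: list[float] = []
        self.stop_event = threading.Event()

    def get_client_config(self, cid: int) -> int:
        return cid

    def get_client_device(self, cid: int) -> str:
        return self.device

    def progress_fn(self, it: list[ApplyResult]) -> Iterable[ApplyResult]:
        return it  # no-op, as documented

    def local_process(self, payload: float, cid_list: list[int]) -> None:
        with ThreadPool(processes=self.num_parallels) as pool:
            # Submit one task per client; each returns an ApplyResult at once.
            jobs = [
                pool.apply_async(
                    worker,
                    (self.get_client_config(cid), payload,
                     self.get_client_device(cid), self.stop_event),
                )
                for cid in cid_list
            ]
            try:
                # Collect results in submission order.
                for job in self.progress_fn(jobs):
                    self.cache.append(job.get())
            except KeyboardInterrupt:
                # Ask still-running workers to stop, then re-raise; leaving
                # the with-block terminates the pool.
                self.stop_event.set()
                raise


t = SketchTrainer(num_parallels=2)
t.local_process(1.5, [1, 2, 3])
print(t.cache)  # [1.5, 3.0, 4.5]
```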
- num_parallels: int
- progress_fn(it: list[ApplyResult]) -> Iterable[ApplyResult]
A no-op progress function that can be overridden to provide custom progress tracking.
- Parameters:
it (list[ApplyResult]) – A list of ApplyResult objects.
- Returns:
The original iterable.
- Return type:
Iterable[ApplyResult]
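Because `progress_fn` only has to return an iterable over the same `ApplyResult` objects, an override can wrap them in a generator that reports progress as the caller consumes each result. The `VerboseProgress` class below is a hypothetical stand-in for a subclass; wrapping the list with `tqdm(it, total=len(it))` would be a common alternative if that package is installed.

```python
import sys
from multiprocessing.pool import ApplyResult, ThreadPool
from typing import Iterable


class VerboseProgress:
    """Hypothetical subclass stand-in that overrides only progress_fn."""

    def progress_fn(self, it: list[ApplyResult]) -> Iterable[ApplyResult]:
        total = len(it)
        for done, result in enumerate(it, start=1):
            # Hand the pending result back; the caller typically calls .get().
            yield result
            # Runs when the caller requests the next item, i.e. after it has
            # finished with the current one.
            print(f"\rclients finished: {done}/{total}", end="", file=sys.stderr)
        print(file=sys.stderr)


with ThreadPool(processes=2) as pool:
    jobs = [pool.apply_async(pow, (cid, 2)) for cid in range(3)]
    results = [job.get() for job in VerboseProgress().progress_fn(jobs)]
print(results)  # [0, 1, 4]
```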
- stop_event: Event
- static worker(config: ClientConfig, payload: DownlinkPackage, device: str, stop_event: Event, *, shm_buffer: UplinkPackage | None = None) -> UplinkPackage
Process a single client’s training task.
This method is executed by each worker process in the pool. It handles loading client configuration and payload, performing the client-specific operations, and returning the result.
- Parameters:
config (ClientConfig) – The client’s configuration data.
payload (DownlinkPackage) – The downlink payload from the server.
device (str) – Device to use for processing (e.g., “cpu”, “cuda:0”).
stop_event (threading.Event) – Event to signal stopping the worker.
shm_buffer (UplinkPackage | None) – Optional shared memory buffer for the uplink package.
- Returns:
The uplink package containing the client’s results.
- Return type:
UplinkPackage
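A minimal sketch of a `worker` implementation, written as a plain function (in a subclass it would be a `@staticmethod`). The `DownlinkPackage`, `UplinkPackage`, and `ClientConfig` dataclasses and the toy update rule are hypothetical placeholders for real model weights and training; the sketch only illustrates checking `stop_event` before doing work and filling a caller-supplied `shm_buffer` in place when one is given.

```python
from __future__ import annotations

import threading
from dataclasses import dataclass


@dataclass
class DownlinkPackage:  # hypothetical: global model weights
    weights: list[float]


@dataclass
class UplinkPackage:  # hypothetical: updated weights plus sample count
    weights: list[float]
    num_samples: int = 0


@dataclass
class ClientConfig:  # hypothetical per-client settings
    cid: int
    lr: float
    num_samples: int


def worker(
    config: ClientConfig,
    payload: DownlinkPackage,
    device: str,
    stop_event: threading.Event,
    *,
    shm_buffer: UplinkPackage | None = None,
) -> UplinkPackage:
    # Bail out early if the parent has asked all workers to stop.
    if stop_event.is_set():
        raise InterruptedError(f"client {config.cid} cancelled")
    # Toy local update standing in for real training on `device`.
    new_weights = [w - config.lr * (w - config.cid) for w in payload.weights]
    if shm_buffer is not None:
        # Write into the preallocated buffer in place instead of allocating.
        shm_buffer.weights[:] = new_weights
        shm_buffer.num_samples = config.num_samples
        return shm_buffer
    return UplinkPackage(weights=new_weights, num_samples=config.num_samples)


stop = threading.Event()
cfg = ClientConfig(cid=1, lr=0.5, num_samples=10)
print(worker(cfg, DownlinkPackage([0.0, 2.0]), "cpu", stop))
# UplinkPackage(weights=[0.5, 1.5], num_samples=10)
```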