blazefl.core.ProcessPoolClientTrainer#
- class blazefl.core.ProcessPoolClientTrainer(*args, **kwargs)[source]#
Bases: BaseClientTrainer[UplinkPackage, DownlinkPackage], Protocol[UplinkPackage, DownlinkPackage, ClientConfig]
Abstract base class for parallel client training in federated learning.
This class extends BaseClientTrainer to enable parallel processing of clients, allowing multiple clients to be trained concurrently in separate processes (see the subclass sketch after the attribute descriptions below).
- num_parallels#
Number of parallel processes to use for client training.
- Type:
int
- share_dir#
Directory path for sharing data between processes.
- Type:
Path
- device#
The primary device to use for computation (e.g., “cpu”, “cuda”).
- Type:
str
- device_count#
The number of available CUDA devices, if device is “cuda”.
- Type:
int
- cache#
Cache to store uplink packages from clients.
- Type:
list[UplinkPackage]
- ipc_mode#
Inter-process communication mode. “storage” uses disk for data exchange, “shared_memory” uses shared memory for tensor data. Defaults to “storage”.
- Type:
Literal[“storage”, “shared_memory”]
- Raises:
NotImplementedError – If the abstract methods are not implemented in a subclass.
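The sketch below shows how a concrete trainer might wire up the required attributes. The dataclasses (MyDownlinkPackage, MyUplinkPackage, MyClientConfig), the MyClientTrainer name, and the defaults are hypothetical placeholders; only the attribute names and method signatures come from this class.

```python
# Hypothetical subclass sketch; the dataclasses and defaults are
# illustrative and not part of blazefl itself.
from dataclasses import dataclass
from pathlib import Path

import torch

from blazefl.core import ProcessPoolClientTrainer


@dataclass
class MyDownlinkPackage:
    model_parameters: torch.Tensor  # broadcast from the server


@dataclass
class MyUplinkPackage:
    cid: int
    model_parameters: torch.Tensor  # sent back to the server


@dataclass
class MyClientConfig:
    cid: int
    lr: float


class MyClientTrainer(
    ProcessPoolClientTrainer[MyUplinkPackage, MyDownlinkPackage, MyClientConfig]
):
    def __init__(self, num_parallels: int, share_dir: Path, device: str) -> None:
        # Attributes required by ProcessPoolClientTrainer.
        self.num_parallels = num_parallels
        self.share_dir = share_dir
        self.share_dir.mkdir(parents=True, exist_ok=True)
        self.device = device
        self.device_count = torch.cuda.device_count() if device == "cuda" else 0
        self.cache: list[MyUplinkPackage] = []
        self.ipc_mode = "storage"  # or "shared_memory" for tensor data
        # A stop_event (multiprocessing Event) is also expected; its
        # construction is omitted in this sketch.

    def uplink_package(self) -> list[MyUplinkPackage]:
        # Drain the cache of results accumulated by local_process().
        package, self.cache = self.cache, []
        return package
```

The remaining abstract members (get_client_config, get_client_device, worker) are sketched alongside their method entries below.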
- __init__(*args, **kwargs)#
Methods
- __init__(*args, **kwargs)
- get_client_config(cid) – Retrieve the configuration for a given client ID.
- get_client_device(cid) – Retrieve the device to use for processing a given client.
- local_process(payload, cid_list) – Manage the parallel processing of clients.
- progress_fn(it) – A no-op progress function that can be overridden to provide custom progress tracking.
- uplink_package() – Prepare the data package to be sent from the client to the server.
- worker(config, payload, device, stop_event) – Process a single client's training task.
Attributes
- cache: list[UplinkPackage]#
- device: str#
- device_count: int#
- get_client_config(cid: int) ClientConfig [source]#
Retrieve the configuration for a given client ID.
- Parameters:
cid (int) – Client ID.
- Returns:
The configuration for the specified client.
- Return type:
ClientConfig
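As an illustration only, a subclass might build the per-client configuration from the client ID; MyClientConfig and the fixed learning rate are hypothetical (see the subclass sketch above).

```python
# Inside the hypothetical MyClientTrainer from the sketch above.
def get_client_config(self, cid: int) -> MyClientConfig:
    # Map a client ID to its training configuration.
    return MyClientConfig(cid=cid, lr=0.01)
```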
- get_client_device(cid: int) str [source]#
Retrieve the device to use for processing a given client.
- Parameters:
cid (int) – Client ID.
- Returns:
The device to use for processing the client.
- Return type:
str
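One plausible implementation (an assumption for illustration, not necessarily the library's default behavior) spreads clients across the available CUDA devices in round-robin fashion using device_count:

```python
# Inside the hypothetical MyClientTrainer from the sketch above.
def get_client_device(self, cid: int) -> str:
    # Round-robin placement across CUDA devices; fall back to self.device.
    if self.device == "cuda" and self.device_count > 0:
        return f"cuda:{cid % self.device_count}"
    return self.device
```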
- ipc_mode: Literal['storage', 'shared_memory'] = 'storage'#
- local_process(payload: DownlinkPackage, cid_list: list[int]) None [source]#
Manage the parallel processing of clients.
This method distributes the processing of multiple clients across parallel processes, handling data saving, loading, and caching.
- Parameters:
payload (DownlinkPackage) – The data package received from the server.
cid_list (list[int]) – A list of client IDs to process.
- Returns:
None
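In a typical round, server-side code would broadcast a downlink package, run local_process over the selected client IDs, and then collect the cached results via uplink_package. The names trainer, server_parameters, and MyDownlinkPackage below are illustrative:

```python
# Hypothetical round driver.
selected_cids = list(range(10))
downlink = MyDownlinkPackage(model_parameters=server_parameters)

trainer.local_process(downlink, selected_cids)  # trains clients in parallel
uplink_packages = trainer.uplink_package()      # collect cached results
```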
- num_parallels: int#
- progress_fn(it: list[ApplyResult]) Iterable[ApplyResult] [source]#
A no-op progress function that can be overridden to provide custom progress tracking.
- Parameters:
it (list[ApplyResult]) – A list of ApplyResult objects.
- Returns:
The original iterable.
- Return type:
Iterable[ApplyResult]
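Since the default is a no-op, a subclass can override it to report progress; wrapping the pending results with tqdm is one option (tqdm is used here only as an example and is not implied to be a blazefl dependency).

```python
from collections.abc import Iterable
from multiprocessing.pool import ApplyResult

from tqdm import tqdm


# Inside the hypothetical MyClientTrainer from the sketch above.
def progress_fn(self, it: list[ApplyResult]) -> Iterable[ApplyResult]:
    # Wrap the pending worker results with a progress bar.
    return tqdm(it, desc="clients")
```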
- share_dir: Path#
- stop_event: Event#
- static worker(config: ClientConfig | Path, payload: DownlinkPackage | Path, device: str, stop_event: Event) UplinkPackage | Path [source]#
Process a single client’s training task.
This method is executed by each worker process in the pool. It handles loading client configuration and payload, performing the client-specific operations, and returning the result.
- Parameters:
config (ClientConfig | Path) – The client’s configuration data, or a path to a file containing the configuration if ipc_mode is “storage”.
payload (DownlinkPackage | Path) – The downlink payload from the server, or a path to a file containing the payload if ipc_mode is “storage”.
device (str) – Device to use for processing (e.g., “cpu”, “cuda:0”).
stop_event (Event) – Event used to signal that the worker should stop early.
- Returns:
The uplink package containing the client’s results, or a path to a file containing the package if ipc_mode is “storage”.
- Return type:
UplinkPackage | Path
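The sketch below assumes the “storage” IPC mode, in which config and payload arrive as file paths and the result is written back to disk; train_one_client, the .uplink.pt suffix, and the use of torch.load/torch.save for (de)serialization are hypothetical choices for illustration.

```python
# Inside the hypothetical MyClientTrainer from the sketch above.
@staticmethod
def worker(config, payload, device, stop_event):
    # In "storage" mode the arguments are paths; load them from disk first.
    config_path = config if isinstance(config, Path) else None
    if isinstance(config, Path):
        config = torch.load(config)
    if isinstance(payload, Path):
        payload = torch.load(payload)

    # Hypothetical per-client training step; it should check stop_event
    # periodically so the pool can shut down gracefully.
    package = train_one_client(config, payload, device, stop_event)

    if config_path is not None:
        # "storage" mode: persist the result and return its path.
        out_path = config_path.with_suffix(".uplink.pt")
        torch.save(package, out_path)
        return out_path
    return package
```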