Customization
==================


DataManager
------------------

Any custom data manager class should inherit from ``fedsim.data_manager.data_manager.DataManager`` (or one of its subclasses) and implement its abstract methods. For example:

.. code-block:: python

   from typing import Dict, Iterable, Sequence

   from fedsim.data_manager.data_manager import DataManager

   class CustomDataManager(DataManager):
       def __init__(self, root, other_arg, seed, save_path=None):
           self.other_arg = other_arg
           # call super().__init__ at the end of __init__, because the
           # base __init__ calls the abstract methods implemented below
           super().__init__(root, seed, save_path=save_path)

       def make_datasets(self, root: str) -> Iterable[Dict[str, object]]:
           """Abstract method to be implemented by child class.

           Args:
               root (str): directory to download and manipulate data.

           Raises:
               NotImplementedError: this abstract method should be
                   implemented by child classes

           Returns:
               Iterable[Dict[str, object]]: the dict of local datasets
                   ({split: dataset}) followed by the dict of global datasets.
           """
           raise NotImplementedError


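       def partition_local_data(self, datasets: Dict[str, object]) -> Dict[str, Iterable[Iterable[int]]]:
           """Partitions the local datasets among the clients.

           Args:
               datasets (Dict[str, object]): local datasets, keyed by split name.

           Returns:
               Dict[str, Iterable[Iterable[int]]]: for each split, one
                   iterable of sample indices per client.
           """
           raise NotImplementedError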


       def get_identifiers(self) -> Sequence[str]:
           """Returns identifiers to be used for saving the partition info.

           Raises:
               NotImplementedError: this abstract method should be
                   implemented by child classes

           Returns:
               Sequence[str]: a sequence of str identifying the class instance
           """
           raise NotImplementedError
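
As an illustration of the contract of ``partition_local_data``, the sketch below computes an IID split: it returns, for each split, one list of sample indices per client. It is written as a standalone function so it runs without fedsim. The function name ``iid_partition``, the ``"train"`` split key, and the ``num_clients``/``seed`` parameters are assumptions; in a real subclass these values would presumably come from attributes set in ``__init__``.

.. code-block:: python

   import random
   from typing import Dict, List, Sized


   def iid_partition(
       datasets: Dict[str, Sized], num_clients: int, seed: int = 0
   ) -> Dict[str, List[List[int]]]:
       """Hypothetical IID partitioner: shuffle the indices of every local
       split and deal them out to ``num_clients`` clients in near-equal shards."""
       rng = random.Random(seed)  # local RNG so the split is reproducible
       partitions = {}
       for split, dataset in datasets.items():
           indices = list(range(len(dataset)))
           rng.shuffle(indices)
           # client i receives every num_clients-th shuffled index, starting at position i
           partitions[split] = [indices[i::num_clients] for i in range(num_clients)]
       return partitions

For instance, ``iid_partition({"train": list(range(10))}, num_clients=3)`` yields three shards of sizes 4, 3 and 3 that together cover indices 0 to 9 exactly once.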



Integration with included cli
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To include your custom data manager in the provided CLI tool automatically, place your class in a file under ``fedsim/data_manager``. You can then select it with the ``--data-manager`` option. To pass arguments to the ``__init__`` method of your custom data manager, use options of the form ``--d-<arg-name>``, where ``<arg-name>`` is the name of the argument. For example:

.. code-block:: bash

   fedsim fed-learn --data-manager CustomDataManager --d-other_arg <other_arg_value> ...
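
The ``--d-<arg-name>`` convention can be pictured as collecting every prefixed option into keyword arguments for ``__init__``. The following is a hypothetical illustration of that mapping, not fedsim's actual parser: the helper name ``collect_prefixed_kwargs`` is made up, and values are kept as strings (the real CLI may convert types).

.. code-block:: python

   from typing import Dict, List


   def collect_prefixed_kwargs(argv: List[str], prefix: str = "--d-") -> Dict[str, str]:
       """Hypothetical helper: turn ``[prefix + name, value]`` pairs in ``argv``
       into ``{name: value}``. Other options, and a prefixed option with no
       value after it, are ignored."""
       kwargs = {}
       i = 0
       while i < len(argv):
           token = argv[i]
           if token.startswith(prefix) and i + 1 < len(argv):
               kwargs[token[len(prefix):]] = argv[i + 1]
               i += 2  # consume the option and its value
           else:
               i += 1
       return kwargs

For example, ``collect_prefixed_kwargs(["--data-manager", "CustomDataManager", "--d-other_arg", "5"])`` returns ``{"other_arg": "5"}``, which could then be forwarded as keyword arguments to ``CustomDataManager.__init__``.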




FLAlgorithm
-----------

Any custom FL algorithm class should inherit from ``fedsim.fl.fl_algorithm.FLAlgorithm`` (or one of its subclasses) and implement its abstract methods. For example:

.. code-block:: python

   from copy import deepcopy
   from typing import Any, Dict, Hashable, Iterable, Mapping, Optional, Union

   import torch
   from torch import nn
   from torch.nn.utils import parameters_to_vector
   from torch.optim import SGD

   from fedsim.fl.fl_algorithm import FLAlgorithm

   class CustomFLAlgorithm(FLAlgorithm):
       def __init__(
           self, data_manager, num_clients, sample_scheme, sample_rate, model_class, epochs, loss_fn,
           batch_size, test_batch_size, local_weight_decay, slr, clr, clr_decay, clr_decay_type,
           min_clr, clr_step_size, metric_logger, device, log_freq, other_arg, *args, **kwargs,
       ):
           self.other_arg = other_arg

           super().__init__(
               data_manager, num_clients, sample_scheme, sample_rate, model_class, epochs, loss_fn,
               batch_size, test_batch_size, local_weight_decay, slr, clr, clr_decay, clr_decay_type,
               min_clr, clr_step_size, metric_logger, device, log_freq,
           )
           # make model and optimizer
           model = self.get_model_class()().to(self.device)
           # flatten the model parameters into a single detached vector
           params = deepcopy(
               parameters_to_vector(model.parameters()).clone().detach())
           # server optimizer over the flattened parameters (server learning rate slr)
           optimizer = SGD(params=[params], lr=slr)
           # write model and optimizer to server
           self.write_server('model', model)
           self.write_server('cloud_params', params)
           self.write_server('optimizer', optimizer)
           # ... set up any additional server-side state here

       def send_to_client(self, client_id: int) -> Mapping[Hashable, Any]:
           """Returns the context to send to the client identified by client_id.
           Shared objects, such as the server model, should be deep-copied
           before they are sent.

           Args:
               client_id (int): id of the receiving client

           Raises:
               NotImplementedError: abstract class to be implemented by child

           Returns:
               Mapping[Hashable, Any]: the context to be sent in form of a Mapping
           """
           raise NotImplementedError

       def send_to_server(
           self, client_id: int, datasets: Dict[str, Iterable], epochs: int, loss_fn: nn.Module,
           batch_size: int, lr: float, weight_decay: float = 0, device: Union[int, str] = 'cuda',
           ctx: Optional[Dict[Hashable, Any]] = None, *args, **kwargs
       ) -> Mapping[str, Any]:
           """Client operation on the received information.

           Args:
               client_id (int): id of the client
               datasets (Dict[str, Iterable]): local datasets of the client, provided by the data manager
               epochs (int): number of local training epochs
               loss_fn (nn.Module): loss function used for local training, e.g., cross-entropy ('ce') or mean squared error ('mse')
               batch_size (int): training batch size
               lr (float): client learning rate
               weight_decay (float, optional): weight decay for SGD. Defaults to 0.
               device (Union[int, str], optional): device to train on. Defaults to 'cuda'.
               ctx (Optional[Dict[Hashable, Any]], optional): context received from the server. Defaults to None.

           Raises:
               NotImplementedError: abstract class to be implemented by child

           Returns:
               Mapping[str, Any]: client context to be sent to the server
           """
           raise NotImplementedError

       def receive_from_client(self, client_id: int, client_msg: Mapping[Hashable, Any], aggregator: Any):
           """ receive and aggregate info from selected clients 

           Args:
               client_id (int): id of the sender (client)
               client_msg (Mapping[Hashable, Any]): client context that is sent
               aggregator (Any): aggregator instance to collect info

           Raises:
               NotImplementedError: abstract class to be implemented by child
           """
           raise NotImplementedError

       def optimize(self, aggregator: Any) -> Mapping[Hashable, Any]:
           """ optimize server model(s) and return metrics to be reported

           Args:
               aggregator (Any): Aggregator instance

           Raises:
               NotImplementedError: abstract class to be implemented by child

           Returns:
               Mapping[Hashable, Any]: context to be reported
           """
           raise NotImplementedError

       def deploy(self) -> Optional[Mapping[Hashable, Any]]:
           """ return Mapping of name -> parameters_set to test the model

           Raises:
               NotImplementedError: abstract class to be implemented by child
           """
           raise NotImplementedError

       def report(
           self, dataloaders, metric_logger: Any, device: str, optimize_reports: Mapping[Hashable, Any],
           deployment_points: Optional[Mapping[Hashable, torch.Tensor]] = None
       ) -> None:
           """test on global data and report info

           Args:
               dataloaders (Any): dict of data loaders to test the global model(s)
               metric_logger (Any): the logging object (e.g., SummaryWriter)
               device (str): 'cuda', 'cpu' or gpu number
               optimize_reports (Mapping[Hashable, Any]): dict returned by the optimize method
               deployment_points (Mapping[Hashable, torch.Tensor], optional): output of deploy method

           Raises:
               NotImplementedError: abstract class to be implemented by child
           """
           raise NotImplementedError
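
As a rough picture of how these hooks fit together, the following framework-free sketch runs FedAvg-style rounds on a single scalar parameter. Everything in it is hypothetical: the ``MeanAggregator`` and ``ToyFedAvg`` classes, their simplified signatures (no datasets, loss function, or device), the quadratic local objective, and in particular the call order in ``run_round``, which is only inferred from the docstrings above and may not match fedsim's actual training loop.

.. code-block:: python

   from typing import Any, Dict, List


   class MeanAggregator:
       """Hypothetical aggregator: keeps a weighted running mean per key."""

       def __init__(self) -> None:
           self.sums: Dict[str, float] = {}
           self.weights: Dict[str, float] = {}

       def add(self, key: str, value: float, weight: float) -> None:
           self.sums[key] = self.sums.get(key, 0.0) + weight * value
           self.weights[key] = self.weights.get(key, 0.0) + weight

       def get(self, key: str) -> float:
           return self.sums[key] / self.weights[key]


   class ToyFedAvg:
       """Hypothetical scalar stand-in for an FLAlgorithm subclass;
       client i minimizes the local objective (w - targets[i]) ** 2."""

       def __init__(self, targets: List[float], sizes: List[int], clr: float = 0.1, epochs: int = 5):
           self.server_w = 0.0
           self.targets, self.sizes, self.clr, self.epochs = targets, sizes, clr, epochs

       def send_to_client(self, client_id: int) -> Dict[str, Any]:
           # a float is immutable, so no deepcopy is needed (unlike a torch model)
           return {"w": self.server_w}

       def send_to_server(self, client_id: int, ctx: Dict[str, Any]) -> Dict[str, Any]:
           w = ctx["w"]
           for _ in range(self.epochs):  # local gradient descent steps
               w -= self.clr * 2 * (w - self.targets[client_id])
           return {"w": w, "num_samples": self.sizes[client_id]}

       def receive_from_client(self, client_id: int, client_msg: Dict[str, Any], aggregator: MeanAggregator) -> None:
           # weight each client's model by its number of samples
           aggregator.add("w", client_msg["w"], weight=client_msg["num_samples"])

       def optimize(self, aggregator: MeanAggregator) -> Dict[str, Any]:
           self.server_w = aggregator.get("w")  # FedAvg: adopt the weighted mean
           return {"server_w": self.server_w}


   def run_round(alg: ToyFedAvg, client_ids: List[int]) -> Dict[str, Any]:
       """Hypothetical single round: broadcast, local training, aggregation, server update."""
       aggregator = MeanAggregator()
       for cid in client_ids:
           msg = alg.send_to_server(cid, alg.send_to_client(cid))
           alg.receive_from_client(cid, msg, aggregator)
       return alg.optimize(aggregator)

With ``targets=[1.0, 3.0]`` and ``sizes=[1, 3]``, repeatedly calling ``run_round(alg, [0, 1])`` drives ``server_w`` toward 2.5, the sample-weighted mean of the client targets.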

Integration with included cli
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To include your custom algorithm in the provided CLI tool automatically, place your class in a file under ``fedsim/fl/algorithms``. You can then select it with the ``--algorithm`` option. To pass arguments to the ``__init__`` method of your custom algorithm, use options of the form ``--a-<arg-name>``, where ``<arg-name>`` is the name of the argument. For example:

.. code-block:: bash

   fedsim fed-learn --algorithm CustomFLAlgorithm --a-other_arg <other_arg_value> ...
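
One plausible way for a CLI to resolve ``--algorithm CustomFLAlgorithm`` is to search the modules of a package for a class with that name. The sketch below shows such a name-based lookup using only the standard library (``pkgutil``, ``importlib``, ``inspect``); it is a hypothetical illustration, not fedsim's actual implementation, and the helper name ``find_class`` is made up.

.. code-block:: python

   import importlib
   import inspect
   import pkgutil
   from types import ModuleType
   from typing import Optional


   def find_class(package: ModuleType, class_name: str) -> Optional[type]:
       """Hypothetical lookup: return the first class called ``class_name``
       exposed by a direct submodule of ``package``, or None if none matches."""
       for module_info in pkgutil.iter_modules(package.__path__):
           if module_info.name.startswith("_"):
               continue  # skip private modules such as __main__
           module = importlib.import_module(f"{package.__name__}.{module_info.name}")
           candidate = getattr(module, class_name, None)
           if inspect.isclass(candidate):
               return candidate
       return None

Called with the imported ``fedsim.fl.algorithms`` package and ``"CustomFLAlgorithm"``, such a helper would return your class once it lives in a file in that directory.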