{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
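    "# Tutorial: FedAvg on CIFAR-100 with FedSim\n",
    "\n",
    "This notebook sets up CIFAR-100 for 500 simulated clients with `BasicDataManager`, trains the `cnn_cifar100` model with the `FedAvg` algorithm for 5 communication rounds (uniform client sampling with `sample_rate=0.01`), and then shows the logged metrics in TensorBoard."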
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:04<00:00,  1.07it/s]\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "from fedsim.data_manager.basic_data_manager import BasicDataManager\n",
    "from fedsim.fl.algorithms.fedavg import FedAvg\n",
    "from fedsim.models.mcmahan_nets import cnn_cifar100\n",
    "\n",
    "# --- configuration: everything a reader might want to tune ---\n",
    "SEED = 0\n",
    "DATA_ROOT = './data'\n",
    "N_CLIENTS = 500\n",
    "SAMPLE_RATE = 0.01  # passed to FedAvg with sample_scheme='uniform'; presumably the fraction of clients per round\n",
    "EPOCHS = 5          # passed as FedAvg's `epochs`; presumably local epochs per selected client\n",
    "BATCH_SIZE = 32\n",
    "N_ROUNDS = 5\n",
    "# Fall back to CPU so the tutorial also runs on machines without a CUDA GPU.\n",
    "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Seed the common RNGs for repeatability. NOTE(review): fedsim may also use its own RNGs internally -- not verified here.\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "dm = BasicDataManager(DATA_ROOT, 'cifar100', N_CLIENTS)\n",
    "# With no log_dir, SummaryWriter writes under ./runs/<datetime>_<hostname>, which the TensorBoard cell reads.\n",
    "sw = SummaryWriter()\n",
    "\n",
    "alg = FedAvg(\n",
    "    data_manager=dm,\n",
    "    num_clients=N_CLIENTS,\n",
    "    sample_scheme='uniform',\n",
    "    sample_rate=SAMPLE_RATE,\n",
    "    model_class=cnn_cifar100,\n",
    "    epochs=EPOCHS,\n",
    "    loss_fn='ce',\n",
    "    batch_size=BATCH_SIZE,\n",
    "    metric_logger=sw,\n",
    "    device=DEVICE,\n",
    ")\n",
    "\n",
    "alg.train(rounds=N_ROUNDS)\n",
    "# Flush buffered events to disk so TensorBoard shows the latest metrics right away.\n",
    "sw.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The tensorboard extension is already loaded. To reload it, use:\n",
      "  %reload_ext tensorboard\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Reusing TensorBoard on port 6006 (pid 440127), started 0:00:14 ago. (Use '!kill 440127' to kill it.)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "      <iframe id=\"tensorboard-frame-dc3383836b9f15c4\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
       "      </iframe>\n",
       "      <script>\n",
       "        (function() {\n",
       "          const frame = document.getElementById(\"tensorboard-frame-dc3383836b9f15c4\");\n",
       "          const url = new URL(\"/\", window.location);\n",
       "          const port = 6006;\n",
       "          if (port) {\n",
       "            url.port = port;\n",
       "          }\n",
       "          frame.src = url;\n",
       "        })();\n",
       "      </script>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Launch TensorBoard inline on the event files SummaryWriter wrote under ./runs.\n",
    "%load_ext tensorboard\n",
    "%tensorboard --logdir=runs"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "cc2d49e75558cae02d4ed89a9193209eda341c0ac37e391b9bd156c453375c2f"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
