{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initialize tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we perform some initial steps to set up the tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-18T08:56:30.663173Z",
     "start_time": "2018-07-18T08:56:29.825521Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import quantities as pq\n",
    "import neo\n",
    "\n",
    "import elephant.unitary_event_analysis as ue\n",
    "\n",
    "# Fix random seed to guarantee fixed output\n",
    "random.seed(1224)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we download a data file containing spike train data from multiple trials of two neurons."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-18T08:56:32.142189Z",
     "start_time": "2018-07-18T08:56:31.420462Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Download the tutorial dataset into the working directory.\n",
    "# wget flags: -N turns on timestamping (the file is only fetched again if the\n",
    "# remote copy is newer than the local one), -q suppresses wget's output.\n",
    "!wget -Nq https://github.com/INM-6/elephant-tutorial-data/raw/master/dataset-1/dataset-1.h5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "# Write a plotting function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-18T08:56:32.920355Z",
     "start_time": "2018-07-18T08:56:32.611208Z"
    },
    "collapsed": true,
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "def plot_UE(data,Js_dict,Js_sig,binsize,winsize,winstep, pat,N,t_winpos,**kwargs):\n",
    "    \"\"\"\n",
    "    Plot the result of a unitary event analysis as five stacked panels.\n",
    "\n",
    "    From top to bottom the panels show: the spike raster with unitary events\n",
    "    (coincidences that fall into significant windows), the spike raster with\n",
    "    all raw coincidences, the firing rate of each unit, the empirical and\n",
    "    expected coincidence counts, and the surprise.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : sequence\n",
    "        Spike trains indexed as data[trial][unit].\n",
    "    Js_dict : dict\n",
    "        Analysis result; the keys 'Js', 'indices' (with one entry per trial\n",
    "        named 'trial0', 'trial1', ...), 'rate_avg', 'n_emp' and 'n_exp' are read.\n",
    "    Js_sig : number\n",
    "        Surprise threshold; windows with Js >= Js_sig count as significant.\n",
    "    binsize : quantity\n",
    "        Bin width; the indices in Js_dict['indices'] are multiplied by it to\n",
    "        place coincidences on the time axis.\n",
    "    winsize : quantity\n",
    "        Window width; window j covers [t_winpos[j], t_winpos[j] + winsize).\n",
    "    winstep : quantity\n",
    "        Accepted for interface compatibility; not used by this function.\n",
    "    pat : array\n",
    "        Spike pattern, only shown in the figure title when 'suptitle' is set.\n",
    "    N : int\n",
    "        Number of units.\n",
    "    t_winpos : quantity array\n",
    "        Start time of every analysis window.\n",
    "    **kwargs\n",
    "        Plot options overriding the defaults in arg_dict: events, figsize,\n",
    "        top, bottom, right, left, hspace, wspace, fontsize, unit_ids,\n",
    "        ch_real_ids, showfig, lw, S_ylim, marker_size, suptitle, set_xticks,\n",
    "        save_fig, path_filename_format.\n",
    "\n",
    "    Examples\n",
    "    --------\n",
    "    plot_UE(data, Js_dict, Js_sig, binsize, winsize, winstep, pat, N, t_winpos,\n",
    "            events={'SO': [100*pq.ms]}, save_fig=True,\n",
    "            path_filename_format='UE1.pdf', showfig=True, suptitle=True,\n",
    "            figsize=(12, 10), unit_ids=[10, 19, 20], ch_real_ids=[1, 3, 4],\n",
    "            fontsize=15, lw=2, set_xticks=False, marker_size=8)\n",
    "    \"\"\"\n",
    "    # pyplot is the supported interface; pylab is discouraged by matplotlib.\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    arg_dict = {'events':{},'figsize':(12,10), 'top':0.9, 'bottom':0.05, 'right':0.95,'left':0.1,\n",
    "                'hspace':0.5,'wspace':0.5,'fontsize':15,'unit_ids':range(1,N+1,1),\n",
    "                'ch_real_ids':[],'showfig':False, 'lw':2,'S_ylim':[-3,3],\n",
    "                'marker_size':8, 'suptitle':False, 'set_xticks':False,\n",
    "                'save_fig':False,'path_filename_format':'UE.pdf'}\n",
    "    arg_dict.update(kwargs)\n",
    "\n",
    "    num_tr = len(data)\n",
    "    unit_real_ids = list(arg_dict['unit_ids'])\n",
    "    fontsize = arg_dict['fontsize']\n",
    "    hide_xticks = not arg_dict['set_xticks']\n",
    "\n",
    "    num_row = 5\n",
    "    ls = '-'\n",
    "    alpha = 0.5\n",
    "\n",
    "    # In the raster panels every unit occupies a band of num_tr + 1 rows.\n",
    "    band = num_tr + 1\n",
    "    y_top = band*N\n",
    "    y_ticks_pos = np.arange(num_tr/2 + 1, N*band, band)\n",
    "    x_max = (max(t_winpos) + winsize).rescale('ms').magnitude\n",
    "    win_centers = t_winpos + winsize/2.\n",
    "\n",
    "    # Loop-invariant: indices of the windows that reach the threshold.\n",
    "    sig_idx_win = np.where(Js_dict['Js']>= Js_sig)[0]\n",
    "    # Unique coincidence bin indices of every trial, shared by both rasters.\n",
    "    coinc_bins = [np.unique(Js_dict['indices']['trial'+str(tr)]) for tr in range(num_tr)]\n",
    "\n",
    "    def _mark_events(lw):\n",
    "        # Draw a vertical line at every event time on the current axes.\n",
    "        for key in arg_dict['events'].keys():\n",
    "            for e_val in arg_dict['events'][key]:\n",
    "                plt.axvline(e_val,ls = ls,color = 'r',lw = lw,alpha = alpha)\n",
    "\n",
    "    def _plot_spikes(data_tr, n, tr):\n",
    "        # Draw the spikes of unit n in trial tr as one raster row.\n",
    "        plt.plot(data_tr[n].rescale('ms').magnitude,\n",
    "                 np.ones_like(data_tr[n].magnitude)*tr + n*band + 1,\n",
    "                 '.', markersize=0.5, color = 'k')\n",
    "\n",
    "    plt.figure(1,figsize = arg_dict['figsize'])\n",
    "    if arg_dict['suptitle']:\n",
    "        plt.suptitle(\"Spike Pattern:\"+ str((pat.T)[0]),fontsize = 20)\n",
    "    print('plotting UEs ...')\n",
    "    plt.subplots_adjust(top=arg_dict['top'], right=arg_dict['right'], left=arg_dict['left'],\n",
    "                        bottom=arg_dict['bottom'], hspace=arg_dict['hspace'], wspace=arg_dict['wspace'])\n",
    "    ax = plt.subplot(num_row,1,1)\n",
    "    ax.set_title('Unitary Events',fontsize = fontsize,color = 'r')\n",
    "    for n in range(N):\n",
    "        for tr,data_tr in enumerate(data):\n",
    "            _plot_spikes(data_tr, n, tr)\n",
    "            x = coinc_bins[tr]\n",
    "            if len(sig_idx_win)>0 and len(x)>0:\n",
    "                # Keep only the coincidences that lie inside a significant window.\n",
    "                xx = []\n",
    "                for j in sig_idx_win:\n",
    "                    xx = np.append(xx,x[np.where((x*binsize>=t_winpos[j]) &(x*binsize<t_winpos[j] + winsize))])\n",
    "                ue_bins = np.unique(xx)\n",
    "                plt.plot(\n",
    "                    ue_bins*binsize, np.ones_like(ue_bins)*tr + n*band + 1,\n",
    "                    ms=arg_dict['marker_size'], marker = 's', ls = '',mfc='none', mec='r')\n",
    "        # Separator above the band of unit n (no dependence on leaked loop variables).\n",
    "        plt.axhline(band*(n+1) ,lw = 2, color = 'k')\n",
    "    plt.yticks(y_ticks_pos)\n",
    "    plt.gca().set_yticklabels(unit_real_ids,fontsize = fontsize)\n",
    "    for ch_cnt, ch_id in enumerate(arg_dict['ch_real_ids']):\n",
    "        plt.gca().text(x_max, y_ticks_pos[ch_cnt],'CH-'+str(ch_id),fontsize = fontsize)\n",
    "\n",
    "    plt.ylim(0, y_top + 1)\n",
    "    plt.xlim(0, x_max)\n",
    "    plt.xticks([])\n",
    "    plt.ylabel('Unit ID',fontsize = fontsize)\n",
    "    _mark_events(2)\n",
    "    if hide_xticks:\n",
    "        plt.xticks([])\n",
    "\n",
    "    print('plotting Raw Coincidences ...')\n",
    "    ax1 = plt.subplot(num_row,1,2,sharex = ax)\n",
    "    ax1.set_title('Raw Coincidences',fontsize = 20,color = 'c')\n",
    "    for n in range(N):\n",
    "        for tr,data_tr in enumerate(data):\n",
    "            _plot_spikes(data_tr, n, tr)\n",
    "            plt.plot(\n",
    "                coinc_bins[tr]*binsize,\n",
    "                np.ones_like(coinc_bins[tr])*tr + n*band + 1,\n",
    "                ls = '',ms=arg_dict['marker_size'], marker = 's', markerfacecolor='none', markeredgecolor='c')\n",
    "        plt.axhline(band*(n+1) ,lw = 2, color = 'k')\n",
    "    plt.ylim(0, y_top + 1)\n",
    "    plt.yticks(y_ticks_pos)\n",
    "    plt.gca().set_yticklabels(unit_real_ids,fontsize = fontsize)\n",
    "    plt.xlim(0, x_max)\n",
    "    plt.xticks([])\n",
    "    plt.ylabel('Unit ID',fontsize = fontsize)\n",
    "    _mark_events(2)\n",
    "\n",
    "    print('plotting PSTH ...')\n",
    "    plt.subplot(num_row,1,3,sharex=ax)\n",
    "    for n in range(N):\n",
    "        plt.plot(win_centers,Js_dict['rate_avg'][:,n].rescale('Hz'),\n",
    "                 label = 'unit '+str(unit_real_ids[n]),lw = arg_dict['lw'])\n",
    "    plt.ylabel('Rate [Hz]',fontsize = fontsize)\n",
    "    plt.xlim(0, x_max)\n",
    "    max_val_psth = plt.gca().get_ylim()[1]\n",
    "    plt.ylim(0, max_val_psth)\n",
    "    plt.yticks([0, int(max_val_psth/2),int(max_val_psth)],fontsize = fontsize)\n",
    "    plt.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)\n",
    "    _mark_events(arg_dict['lw'])\n",
    "    if hide_xticks:\n",
    "        plt.xticks([])\n",
    "\n",
    "    print( 'plotting emp. and exp. coincidences rate ...')\n",
    "    plt.subplot(num_row,1,4,sharex=ax)\n",
    "    plt.plot(win_centers,Js_dict['n_emp'],label = 'empirical',lw = arg_dict['lw'],color = 'c')\n",
    "    plt.plot(win_centers,Js_dict['n_exp'],label = 'expected',lw = arg_dict['lw'],color = 'm')\n",
    "    plt.xlim(0, x_max)\n",
    "    plt.ylabel('# Coinc.',fontsize = fontsize)\n",
    "    plt.legend(bbox_to_anchor=(1.12, 1.05), fancybox=True, shadow=True)\n",
    "    YTicks = plt.ylim(0,int(max(max(Js_dict['n_emp']), max(Js_dict['n_exp']))))\n",
    "    plt.yticks([0,YTicks[1]],fontsize = fontsize)\n",
    "    _mark_events(2)\n",
    "    if hide_xticks:\n",
    "        plt.xticks([])\n",
    "\n",
    "    print('plotting Surprise ...')\n",
    "    plt.subplot(num_row,1,5,sharex=ax)\n",
    "    plt.plot(win_centers, Js_dict['Js'],lw = arg_dict['lw'],color = 'k')\n",
    "    plt.xlim(0, x_max)\n",
    "    plt.axhline(Js_sig,ls = '-', color = 'gray')\n",
    "    plt.axhline(-Js_sig,ls = '-', color = 'gray')\n",
    "    # Roughly ten ticks; the step is clamped to 1 because a slice step of 0\n",
    "    # (fewer than 10 windows) raises a ValueError.\n",
    "    tick_step = max(1, len(t_winpos)//10)\n",
    "    plt.xticks(t_winpos.magnitude[::tick_step])\n",
    "    plt.yticks([-2,0,2],fontsize = fontsize)\n",
    "    plt.ylabel('S',fontsize = fontsize)\n",
    "    plt.xlabel('Time [ms]', fontsize = fontsize)\n",
    "    plt.ylim(arg_dict['S_ylim'])\n",
    "    for key in arg_dict['events'].keys():\n",
    "        for e_val in arg_dict['events'][key]:\n",
    "            plt.axvline(e_val,ls = ls,color = 'r',lw = arg_dict['lw'],alpha = alpha)\n",
    "            # Label the event with its key next to the line.\n",
    "            plt.gca().text(e_val - 10*pq.ms,2*arg_dict['S_ylim'][0],key,fontsize = fontsize,color = 'r')\n",
    "\n",
    "    if hide_xticks:\n",
    "        plt.xticks([])\n",
    "\n",
    "    if arg_dict['save_fig']:\n",
    "        plt.savefig(arg_dict['path_filename_format'])\n",
    "        if not arg_dict['showfig']:\n",
    "            # Nothing will be displayed, so release the figure after saving.\n",
    "            plt.cla()\n",
    "            plt.close()\n",
    "\n",
    "    if arg_dict['showfig']:\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data and extract spiketrains"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-18T08:56:33.836628Z",
     "start_time": "2018-07-18T08:56:33.647488Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Open the HDF5 file and read the Neo block once; both segments are taken from\n",
    "# the same in-memory block instead of reading the file a second time.\n",
    "# NOTE(review): NeoHdf5IO may be unavailable in recent neo releases - confirm\n",
    "# against the installed version.\n",
    "block = neo.io.NeoHdf5IO(\"./dataset-1.h5\")\n",
    "neo_block = block.read_block()\n",
    "sts1 = neo_block.segments[0].spiketrains\n",
    "sts2 = neo_block.segments[1].spiketrains\n",
    "# Stack the two segments and transpose; presumably this yields one row per\n",
    "# trial and one column per neuron, which is how plot_UE indexes its data\n",
    "# (data[trial][unit]) - TODO confirm.\n",
    "spiketrains = np.vstack((sts1,sts2)).T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calculate Unitary Events"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-18T08:56:37.042743Z",
     "start_time": "2018-07-18T08:56:34.926673Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Analysis parameters, defined once so the analysis and the plot agree.\n",
    "bin_size = 5*pq.ms\n",
    "window_size = 100*pq.ms\n",
    "window_step = 10*pq.ms\n",
    "pattern_hash = [3]\n",
    "n_neurons = 2\n",
    "\n",
    "UE = ue.unitary_event_analysis(\n",
    "    spiketrains, bin_size=bin_size, window_size=window_size,\n",
    "    window_step=window_step, pattern_hash=pattern_hash)\n",
    "\n",
    "# Inputs of the plot: the threshold passed as Js_sig, the spike pattern, and the\n",
    "# window positions from time zero up to the end of the first spike train.\n",
    "js_sig = ue.jointJ(0.05)\n",
    "spike_pattern = ue.inverse_hash_from_pattern(pattern_hash, N=n_neurons)\n",
    "window_positions = ue._winpos(\n",
    "    0*pq.ms, spiketrains[0][0].t_stop, winsize=window_size, winstep=window_step)\n",
    "\n",
    "plot_UE(\n",
    "    spiketrains, UE, js_sig, binsize=bin_size, winsize=window_size,\n",
    "    winstep=window_step, pat=spike_pattern, N=n_neurons,\n",
    "    t_winpos=window_positions)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [stdpy2]",
   "language": "python",
   "name": "Python [stdpy2]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
