{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trialwise Decoding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we will use a convolutional neural network on the [Physiobank EEG Motor Movement/Imagery Dataset](https://www.physionet.org/physiobank/database/eegmmidb/) to decode two classes:\n",
    "\n",
    "1. Executed and imagined opening and closing of both hands\n",
    "2. Executed and imagined opening and closing of both feet\n",
    "\n",
    "<div class=\"alert alert-warning\">\n",
    "\n",
    "We use only one subject (with 90 trials) in this tutorial for demonstration purposes. A more interesting decoding task with many more trials would be to do cross-subject decoding on the same dataset.\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can load and preprocess your EEG dataset in any way, Braindecode only expects a 3darray (trials, channels, timesteps) of input signals `X` and a vector of labels `y` later (see below). In this tutorial, we will use the [MNE](https://www.martinos.org/mne/stable/index.html) library to load an EEG motor imagery/motor execution dataset. For a tutorial from MNE using Common Spatial Patterns to decode this data, see [here](http://martinos.org/mne/stable/auto_examples/decoding/plot_decoding_csp_eeg.html). For another library useful for loading EEG data, take a look at [Neo IO](https://pythonhosted.org/neo/io.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Removing orphaned offset at the beginning of the file.\n",
      "179 events found\n",
      "Events id: [1 2 3]\n",
      "90 matching events found\n",
      "Loading data for 90 events and 497 original time points ...\n",
      "0 bad epochs dropped\n"
     ]
    }
   ],
   "source": [
    "import mne\n",
    "from mne.io import concatenate_raws\n",
    "\n",
    "# 5,6,7,10,13,14 are codes for executed and imagined hands/feet\n",
    "subject_id = 1\n",
    "event_codes = [5,6,9,10,13,14]\n",
    "\n",
    "# This will download the files if you don't have them yet,\n",
    "# and then return the paths to the files.\n",
    "physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)\n",
    "\n",
    "# Load each of the files\n",
    "parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')\n",
    "         for path in physionet_paths]\n",
    "\n",
    "# Concatenate them\n",
    "raw = concatenate_raws(parts)\n",
    "\n",
    "# Find the events in this dataset\n",
    "events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')\n",
    "\n",
    "# Use only EEG channels\n",
    "eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                   exclude='bads')\n",
    "\n",
    "# Extract trials, only using EEG channels\n",
    "epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,\n",
    "                baseline=None, preload=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert data to Braindecode Format"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Braindecode has a minimalistic ```SignalAndTarget``` class, with attributes `X` for the signal and `y` for the labels. `X` should have these dimensions: trials x channels x timesteps. `y` should have one label per trial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Convert data from volt to millivolt\n",
    "# Pytorch expects float32 for input and int64 for labels.\n",
    "X = (epoched.get_data() * 1e6).astype(np.float32)\n",
    "y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the first 60 trials for training and the last 30 for testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datautil.signal_target import SignalAndTarget\n",
    "\n",
    "train_set = SignalAndTarget(X[:60], y=y[:60])\n",
    "test_set = SignalAndTarget(X[60:], y=y[60:])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Braindecode comes with some predefined convolutional neural network architectures for raw time-domain EEG. Here, we use the shallow ConvNet model from [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
    "from torch import nn\n",
    "from braindecode.torch_ext.util import set_random_seeds\n",
    "\n",
    "# Set if you want to use GPU\n",
    "# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.\n",
    "cuda = False\n",
    "set_random_seeds(seed=20170629, cuda=cuda)\n",
    "n_classes = 2\n",
    "in_chans = train_set.X.shape[1]\n",
    "# final_conv_length = auto ensures we only get a single output in the time dimension\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,\n",
    "                        input_time_length=train_set.X.shape[2],\n",
    "                        final_conv_length='auto').create_network()\n",
    "if cuda:\n",
    "    model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use [Adam](https://arxiv.org/abs/1412.6980) to optimize the parameters of our network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a conventional mini-batch stochastic gradient descent training loop:\n",
    "\n",
    "1. Get randomly shuffled batches of trials\n",
    "2. Compute outputs, loss and gradients on the batches of trials\n",
    "3. Update your model\n",
    "4. After iterating through all batches of your dataset, report some statistics like mean accuracy and mean loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Train  Loss: 1.17760\n",
      "Train  Accuracy: 51.7%\n",
      "Test   Loss: 1.26024\n",
      "Test   Accuracy: 53.3%\n",
      "Epoch 1\n",
      "Train  Loss: 0.70688\n",
      "Train  Accuracy: 63.3%\n",
      "Test   Loss: 0.93679\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 2\n",
      "Train  Loss: 0.39426\n",
      "Train  Accuracy: 86.7%\n",
      "Test   Loss: 0.65984\n",
      "Test   Accuracy: 66.7%\n",
      "Epoch 3\n",
      "Train  Loss: 0.33593\n",
      "Train  Accuracy: 90.0%\n",
      "Test   Loss: 0.65626\n",
      "Test   Accuracy: 63.3%\n",
      "Epoch 4\n",
      "Train  Loss: 0.27905\n",
      "Train  Accuracy: 96.7%\n",
      "Test   Loss: 0.61044\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 5\n",
      "Train  Loss: 0.27319\n",
      "Train  Accuracy: 96.7%\n",
      "Test   Loss: 0.59499\n",
      "Test   Accuracy: 66.7%\n",
      "Epoch 6\n",
      "Train  Loss: 0.26624\n",
      "Train  Accuracy: 88.3%\n",
      "Test   Loss: 0.61360\n",
      "Test   Accuracy: 70.0%\n",
      "Epoch 7\n",
      "Train  Loss: 0.23338\n",
      "Train  Accuracy: 91.7%\n",
      "Test   Loss: 0.65626\n",
      "Test   Accuracy: 73.3%\n",
      "Epoch 8\n",
      "Train  Loss: 0.21903\n",
      "Train  Accuracy: 90.0%\n",
      "Test   Loss: 0.68762\n",
      "Test   Accuracy: 63.3%\n",
      "Epoch 9\n",
      "Train  Loss: 0.18902\n",
      "Train  Accuracy: 96.7%\n",
      "Test   Loss: 0.67174\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 10\n",
      "Train  Loss: 0.17774\n",
      "Train  Accuracy: 96.7%\n",
      "Test   Loss: 0.68129\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 11\n",
      "Train  Loss: 0.16396\n",
      "Train  Accuracy: 98.3%\n",
      "Test   Loss: 0.69190\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 12\n",
      "Train  Loss: 0.15179\n",
      "Train  Accuracy: 98.3%\n",
      "Test   Loss: 0.69641\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 13\n",
      "Train  Loss: 0.15035\n",
      "Train  Accuracy: 98.3%\n",
      "Test   Loss: 0.70514\n",
      "Test   Accuracy: 60.0%\n",
      "Epoch 14\n",
      "Train  Loss: 0.14255\n",
      "Train  Accuracy: 98.3%\n",
      "Test   Loss: 0.74239\n",
      "Test   Accuracy: 53.3%\n",
      "Epoch 15\n",
      "Train  Loss: 0.13597\n",
      "Train  Accuracy: 98.3%\n",
      "Test   Loss: 0.77918\n",
      "Test   Accuracy: 53.3%\n",
      "Epoch 16\n",
      "Train  Loss: 0.13680\n",
      "Train  Accuracy: 98.3%\n",
      "Test   Loss: 0.84159\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 17\n",
      "Train  Loss: 0.12191\n",
      "Train  Accuracy: 98.3%\n",
      "Test   Loss: 0.87802\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 18\n",
      "Train  Loss: 0.10604\n",
      "Train  Accuracy: 98.3%\n",
      "Test   Loss: 0.90945\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 19\n",
      "Train  Loss: 0.10445\n",
      "Train  Accuracy: 98.3%\n",
      "Test   Loss: 0.93717\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 20\n",
      "Train  Loss: 0.10497\n",
      "Train  Accuracy: 98.3%\n",
      "Test   Loss: 0.95502\n",
      "Test   Accuracy: 60.0%\n",
      "Epoch 21\n",
      "Train  Loss: 0.10946\n",
      "Train  Accuracy: 100.0%\n",
      "Test   Loss: 0.99562\n",
      "Test   Accuracy: 60.0%\n",
      "Epoch 22\n",
      "Train  Loss: 0.05721\n",
      "Train  Accuracy: 100.0%\n",
      "Test   Loss: 0.84824\n",
      "Test   Accuracy: 60.0%\n",
      "Epoch 23\n",
      "Train  Loss: 0.04464\n",
      "Train  Accuracy: 100.0%\n",
      "Test   Loss: 0.78784\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 24\n",
      "Train  Loss: 0.04181\n",
      "Train  Accuracy: 100.0%\n",
      "Test   Loss: 0.78525\n",
      "Test   Accuracy: 56.7%\n",
      "Epoch 25\n",
      "Train  Loss: 0.03730\n",
      "Train  Accuracy: 100.0%\n",
      "Test   Loss: 0.75642\n",
      "Test   Accuracy: 63.3%\n",
      "Epoch 26\n",
      "Train  Loss: 0.03549\n",
      "Train  Accuracy: 100.0%\n",
      "Test   Loss: 0.71706\n",
      "Test   Accuracy: 66.7%\n",
      "Epoch 27\n",
      "Train  Loss: 0.03415\n",
      "Train  Accuracy: 100.0%\n",
      "Test   Loss: 0.69795\n",
      "Test   Accuracy: 66.7%\n",
      "Epoch 28\n",
      "Train  Loss: 0.03205\n",
      "Train  Accuracy: 100.0%\n",
      "Test   Loss: 0.69323\n",
      "Test   Accuracy: 66.7%\n",
      "Epoch 29\n",
      "Train  Loss: 0.02805\n",
      "Train  Accuracy: 100.0%\n",
      "Test   Loss: 0.69241\n",
      "Test   Accuracy: 63.3%\n"
     ]
    }
   ],
   "source": [
    "from braindecode.torch_ext.util import np_to_var, var_to_np\n",
    "from braindecode.datautil.iterators import get_balanced_batches\n",
    "import torch.nn.functional as F\n",
    "from numpy.random import RandomState\n",
    "rng = RandomState((2017,6,30))\n",
    "for i_epoch in range(30):\n",
    "    i_trials_in_batch = get_balanced_batches(len(train_set.X), rng, shuffle=True,\n",
    "                                            batch_size=30)\n",
    "    # Set model to training mode\n",
    "    model.train()\n",
    "    for i_trials in i_trials_in_batch:\n",
    "        # Have to add empty fourth dimension to X\n",
    "        batch_X = train_set.X[i_trials][:,:,:,None]\n",
    "        batch_y = train_set.y[i_trials]\n",
    "        net_in = np_to_var(batch_X)\n",
    "        if cuda:\n",
    "            net_in = net_in.cuda()\n",
    "        net_target = np_to_var(batch_y)\n",
    "        if cuda:\n",
    "            net_target = net_target.cuda()\n",
    "        # Remove gradients of last backward pass from all parameters \n",
    "        optimizer.zero_grad()\n",
    "        # Compute outputs of the network\n",
    "        outputs = model(net_in)\n",
    "        # Compute the loss\n",
    "        loss = F.nll_loss(outputs, net_target)\n",
    "        # Do the backpropagation\n",
    "        loss.backward()\n",
    "        # Update parameters with the optimizer\n",
    "        optimizer.step()\n",
    "    \n",
    "    # Print some statistics each epoch\n",
    "    model.eval()\n",
    "    print(\"Epoch {:d}\".format(i_epoch))\n",
    "    for setname, dataset in (('Train', train_set), ('Test', test_set)):\n",
    "        # Here, we will use the entire dataset at once, which is still possible\n",
    "        # for such smaller datasets. Otherwise we would have to use batches.\n",
    "        net_in = np_to_var(dataset.X[:,:,:,None])\n",
    "        if cuda:\n",
    "            net_in = net_in.cuda()\n",
    "        net_target = np_to_var(dataset.y)\n",
    "        if cuda:\n",
    "            net_target = net_target.cuda()\n",
    "        outputs = model(net_in)\n",
    "        loss = F.nll_loss(outputs, net_target)\n",
    "        print(\"{:6s} Loss: {:.5f}\".format(\n",
    "            setname, float(var_to_np(loss))))\n",
    "        predicted_labels = np.argmax(var_to_np(outputs), axis=1)\n",
    "        accuracy = np.mean(dataset.y  == predicted_labels)\n",
    "        print(\"{:6s} Accuracy: {:.1f}%\".format(\n",
    "            setname, accuracy * 100))\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Eventually, we arrive at 63.3% accuracy, so 19 from 30 trials are correctly predicted. In the [Cropped Decoding Tutorial](./Cropped_Decoding.html), we can learn how to achieve higher accuracies using cropped training."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "If you want to try cross-subject decoding, changing the loading code to the following will perform cross-subject decoding on imagined left vs right hand closing, with 50 training and 5 test subjects (Warning, might be very slow):\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mne\n",
    "import numpy as np\n",
    "from mne.io import concatenate_raws\n",
    "from braindecode.datautil.signal_target import SignalAndTarget\n",
    "\n",
    "physionet_paths = [ mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(1,51)]\n",
    "physionet_paths = np.concatenate(physionet_paths)\n",
    "parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto')\n",
    "         for path in physionet_paths] \n",
    "\n",
    "raw = concatenate_raws(parts)\n",
    "\n",
    "picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                   exclude='bads')\n",
    "\n",
    "events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')\n",
    "\n",
    "# Read epochs (train will be done only between 1 and 2s)\n",
    "# Testing will be done with a running classifier\n",
    "epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=picks,\n",
    "                baseline=None, preload=True)\n",
    "\n",
    "physionet_paths_test = [mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(51,56)]\n",
    "physionet_paths_test = np.concatenate(physionet_paths_test)\n",
    "parts_test = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto')\n",
    "         for path in physionet_paths_test]\n",
    "raw_test = concatenate_raws(parts_test)\n",
    "\n",
    "picks_test = mne.pick_types(raw_test.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                   exclude='bads')\n",
    "\n",
    "events_test = mne.find_events(raw_test, shortest_event=0, stim_channel='STI 014')\n",
    "\n",
    "# Read epochs (train will be done only between 1 and 2s)\n",
    "# Testing will be done with a running classifier\n",
    "epoched_test = mne.Epochs(raw_test, events_test, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=picks_test,\n",
    "                baseline=None, preload=True)\n",
    "\n",
    "train_X = (epoched.get_data() * 1e6).astype(np.float32)\n",
    "train_y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1\n",
    "test_X = (epoched_test.get_data() * 1e6).astype(np.float32)\n",
    "test_y = (epoched_test.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1\n",
    "train_set = SignalAndTarget(train_X, y=train_y)\n",
    "test_set = SignalAndTarget(test_X, y=test_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset References\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " This dataset was created and contributed to PhysioNet by the developers of the [BCI2000](http://www.schalklab.org/research/bci2000) instrumentation system, which they used in making these recordings. The system is described in:\n",
    " \n",
    "     Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.\n",
    "\n",
    "[PhysioBank](https://physionet.org/physiobank/) is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community and further described in:\n",
    "\n",
    "    Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
