{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trialwise Decoding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we will use a convolutional neural network on the [Physiobank EEG Motor Movement/Imagery Dataset](https://www.physionet.org/physiobank/database/eegmmidb/) to decode two classes:\n",
    "\n",
    "1. Executed and imagined opening and closing of both hands\n",
    "2. Executed and imagined opening and closing of both feet\n",
    "\n",
    "<div class=\"alert alert-warning\">\n",
    "\n",
    "We use only one subject (with 90 trials) in this tutorial for demonstration purposes. A more interesting decoding task with many more trials would be to do cross-subject decoding on the same dataset.\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Enable logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import importlib\n",
    "importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195\n",
    "log = logging.getLogger()\n",
    "log.setLevel('INFO')\n",
    "import sys\n",
    "\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n",
    "                     level=logging.INFO, stream=sys.stdout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can load and preprocess your EEG dataset in any way; Braindecode only expects a 3D array (trials, channels, timesteps) of input signals `X` and a vector of labels `y` (see below). In this tutorial, we use the [MNE](https://www.martinos.org/mne/stable/index.html) library to load an EEG motor imagery/motor execution dataset. For an MNE tutorial that decodes this data with Common Spatial Patterns, see [here](http://martinos.org/mne/stable/auto_examples/decoding/plot_decoding_csp_eeg.html). Another library useful for loading EEG data is [Neo IO](https://pythonhosted.org/neo/io.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mne\n",
    "from mne.io import concatenate_raws\n",
    "\n",
    "# Runs 5, 6, 9, 10, 13 and 14 contain executed and imagined hands/feet movements\n",
    "subject_id = 22 # carefully cherry-picked to give nice results on such limited data :)\n",
    "event_codes = [5,6,9,10,13,14]\n",
    "#event_codes = [3,4,5,6,7,8,9,10,11,12,13,14]\n",
    "\n",
    "# This will download the files if you don't have them yet,\n",
    "# and then return the paths to the files.\n",
    "physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)\n",
    "\n",
    "# Load each of the files\n",
    "parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')\n",
    "         for path in physionet_paths]\n",
    "\n",
    "# Concatenate them\n",
    "raw = concatenate_raws(parts)\n",
    "\n",
    "# Find the events in this dataset\n",
    "events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')\n",
    "\n",
    "# Use only EEG channels\n",
    "eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                   exclude='bads')\n",
    "\n",
    "# Extract trials, only using EEG channels\n",
    "epoched = mne.Epochs(raw, events, dict(hands_or_left=2, feet_or_right=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,\n",
    "                baseline=None, preload=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert data to Braindecode format"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Braindecode has a minimalistic `SignalAndTarget` class, with attributes `X` for the signal and `y` for the labels. `X` should have the dimensions trials x channels x timesteps, and `y` should have one label per trial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Convert data from volts to microvolts\n",
    "# PyTorch expects float32 for inputs and int64 for labels.\n",
    "X = (epoched.get_data() * 1e6).astype(np.float32)\n",
    "y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1"
   ]
  },
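  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both conversions above can be checked on a tiny synthetic array. The following is a minimal sketch and is not part of Braindecode. `fake_data` and `fake_event_codes` are hypothetical stand-ins. The first replaces `epoched.get_data()`, which MNE returns in volts with shape trials x channels x timesteps. The second replaces `epoched.events[:, 2]`, where event code 2 is hands and 3 is feet.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for epoched.get_data(): 3 trials, 2 channels, 4 timesteps, in volts\n",
    "fake_data = np.full((3, 2, 4), 2e-6)\n",
    "# Hypothetical stand-in for epoched.events[:, 2]: 2 = hands, 3 = feet\n",
    "fake_event_codes = np.array([2, 3, 2])\n",
    "\n",
    "# Volts -> microvolts, cast to the float32 dtype PyTorch expects for inputs\n",
    "X_demo = (fake_data * 1e6).astype(np.float32)\n",
    "# Event codes 2, 3 -> class labels 0, 1, cast to the int64 dtype expected for labels\n",
    "y_demo = (fake_event_codes - 2).astype(np.int64)\n",
    "\n",
    "print(X_demo.shape, X_demo.dtype, X_demo[0, 0, 0])\n",
    "print(y_demo, y_demo.dtype)\n",
    "```\n",
    "\n",
    "Each 2e-6 V sample becomes 2.0 microvolts, and the event codes 2, 3, 2 become the labels 0, 1, 0."
   ]
  },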
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the first 40 trials for training and the next 30 trials for validation. The validation accuracies can be used to tune hyperparameters such as the learning rate. The final 20 trials are kept apart as a hold-out evaluation set that is not used for any hyperparameter optimization. As mentioned before, this dataset is dangerously small for obtaining meaningful results and is only used here for quick demonstration purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datautil.signal_target import SignalAndTarget\n",
    "\n",
    "train_set = SignalAndTarget(X[:40], y=y[:40])\n",
    "valid_set = SignalAndTarget(X[40:70], y=y[40:70])\n"
   ]
  },
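  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split is purely positional: trials 0 to 39 are for training, 40 to 69 for validation, and 70 to 89 form the hold-out set. The minimal sketch below reproduces that slicing on a synthetic 90-trial array. `n_trials`, `X_all` and `y_all` are hypothetical stand-ins for the tutorial's `X` and `y`, and the channel and time sizes are arbitrary. The sketch also checks that the three parts are disjoint and together cover every trial.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-ins for X and y; channel and time sizes are arbitrary here\n",
    "n_trials = 90\n",
    "X_all = np.random.randn(n_trials, 2, 10).astype(np.float32)\n",
    "y_all = np.arange(n_trials) % 2\n",
    "\n",
    "# Same positional boundaries as in the cell above, plus the held-out remainder\n",
    "train_idx = np.arange(0, 40)\n",
    "valid_idx = np.arange(40, 70)\n",
    "test_idx = np.arange(70, n_trials)\n",
    "\n",
    "X_train, y_train = X_all[train_idx], y_all[train_idx]\n",
    "X_valid, y_valid = X_all[valid_idx], y_all[valid_idx]\n",
    "X_test, y_test = X_all[test_idx], y_all[test_idx]\n",
    "\n",
    "# The three index sets must be disjoint and together cover all trials\n",
    "all_idx = np.concatenate([train_idx, valid_idx, test_idx])\n",
    "assert len(np.unique(all_idx)) == n_trials\n",
    "print(len(X_train), len(X_valid), len(X_test))\n",
    "```\n",
    "\n",
    "Because the slices are positional, the trials stay in recording order rather than being shuffled. With the tutorial's own arrays, the held-out part corresponds to `X[70:]` and `y[70:]`."
   ]
  },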
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Braindecode comes with some predefined convolutional neural network architectures for raw time-domain EEG. Here, we use the shallow ConvNet model from [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
    "from torch import nn\n",
    "from braindecode.torch_ext.util import set_random_seeds\n",
    "\n",
    "# Set if you want to use GPU\n",
    "# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.\n",
    "cuda = False\n",
    "set_random_seeds(seed=20170629, cuda=cuda)\n",
    "n_classes = 2\n",
    "in_chans = train_set.X.shape[1]\n",
    "# final_conv_length = auto ensures we only get a single output in the time dimension\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,\n",
    "                        input_time_length=train_set.X.shape[2],\n",
    "                        final_conv_length='auto')\n",
    "if cuda:\n",
    "    model.cuda()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use [AdamW](https://arxiv.org/abs/1711.05101) to optimize the parameters of our network together with [Cosine Annealing](https://arxiv.org/abs/1608.03983) of the learning rate. We supply some default parameters that we have found to work well for motor decoding; however, we strongly encourage you to perform your own hyperparameter optimization using cross-validation on your training data."
   ]
  },
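  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Cosine annealing decays the learning rate from its initial value towards zero along half a cosine period over the course of training. Here is a minimal sketch of that schedule. It assumes a fixed number of schedule steps `T` and no warm restarts; the linked paper also describes restarts. Under those assumptions, the learning rate at step `t` is `lr_t = lr_0 * 0.5 * (1 + cos(pi * t / T))`. The function `cosine_annealing_lr` is hypothetical and not part of Braindecode's API. `initial_lr` mirrors the value passed to AdamW in this tutorial, and `n_steps = 30` matches the 30 training epochs purely for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def cosine_annealing_lr(initial_lr, step, n_steps):\n",
    "    # Half-cosine decay from initial_lr (step 0) to 0 (step n_steps), no restarts\n",
    "    return initial_lr * 0.5 * (1 + math.cos(math.pi * step / n_steps))\n",
    "\n",
    "initial_lr = 0.0625 * 0.01\n",
    "n_steps = 30\n",
    "schedule = [cosine_annealing_lr(initial_lr, t, n_steps) for t in range(n_steps + 1)]\n",
    "print(schedule[0], schedule[n_steps // 2], schedule[-1])\n",
    "```\n",
    "\n",
    "The rate starts at `initial_lr`, is roughly halved at the midpoint and reaches zero at the end. The decay is slow at the beginning and end and fastest in the middle of training."
   ]
  },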
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "We will now use the Braindecode model class directly to perform the training in a few lines of code. If you instead want to use your own training loop, have a look at the [Trialwise Low-Level Tutorial](./TrialWise_LowLevel.html).\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.torch_ext.optimizers import AdamW\n",
    "import torch.nn.functional as F\n",
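    "# These are good values for the deep model:\n",
    "#optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001)\n",
    "# These are good values for the shallow model used here:\n",
    "optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)\n",
    "# ShallowFBCSPNet outputs log-probabilities, so the negative log-likelihood loss is used\n",
    "model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1,)"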
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-08-08 11:49:59,336 INFO : Run until first stop...\n",
      "2018-08-08 11:50:02,270 INFO : Time only for training updates: 2.93s\n",
      "2018-08-08 11:50:04,211 INFO : Epoch 0\n",
      "2018-08-08 11:50:04,216 INFO : train_loss                1.88531\n",
      "2018-08-08 11:50:04,219 INFO : valid_loss                2.01343\n",
      "2018-08-08 11:50:04,223 INFO : train_misclass            0.52500\n",
      "2018-08-08 11:50:04,226 INFO : valid_misclass            0.53333\n",
      "2018-08-08 11:50:04,230 INFO : runtime                   0.00000\n",
      "2018-08-08 11:50:04,234 INFO : \n",
      "2018-08-08 11:50:06,959 INFO : Time only for training updates: 2.72s\n",
      "2018-08-08 11:50:08,869 INFO : Epoch 1\n",
      "2018-08-08 11:50:08,875 INFO : train_loss                0.66515\n",
      "2018-08-08 11:50:08,879 INFO : valid_loss                0.78154\n",
      "2018-08-08 11:50:08,883 INFO : train_misclass            0.35000\n",
      "2018-08-08 11:50:08,887 INFO : valid_misclass            0.40000\n",
      "2018-08-08 11:50:08,893 INFO : runtime                   4.68826\n",
      "2018-08-08 11:50:08,897 INFO : \n",
      "2018-08-08 11:50:11,670 INFO : Time only for training updates: 2.77s\n",
      "2018-08-08 11:50:13,611 INFO : Epoch 2\n",
      "2018-08-08 11:50:13,613 INFO : train_loss                0.41463\n",
      "2018-08-08 11:50:13,615 INFO : valid_loss                0.55128\n",
      "2018-08-08 11:50:13,616 INFO : train_misclass            0.20000\n",
      "2018-08-08 11:50:13,617 INFO : valid_misclass            0.23333\n",
      "2018-08-08 11:50:13,619 INFO : runtime                   4.71190\n",
      "2018-08-08 11:50:13,620 INFO : \n",
      "2018-08-08 11:50:16,332 INFO : Time only for training updates: 2.71s\n",
      "2018-08-08 11:50:18,278 INFO : Epoch 3\n",
      "2018-08-08 11:50:18,283 INFO : train_loss                0.32448\n",
      "2018-08-08 11:50:18,287 INFO : valid_loss                0.49955\n",
      "2018-08-08 11:50:18,291 INFO : train_misclass            0.17500\n",
      "2018-08-08 11:50:18,295 INFO : valid_misclass            0.16667\n",
      "2018-08-08 11:50:18,300 INFO : runtime                   4.66119\n",
      "2018-08-08 11:50:18,304 INFO : \n",
      "2018-08-08 11:50:21,011 INFO : Time only for training updates: 2.71s\n",
      "2018-08-08 11:50:22,936 INFO : Epoch 4\n",
      "2018-08-08 11:50:22,940 INFO : train_loss                0.23701\n",
      "2018-08-08 11:50:22,943 INFO : valid_loss                0.44736\n",
      "2018-08-08 11:50:22,947 INFO : train_misclass            0.10000\n",
      "2018-08-08 11:50:22,950 INFO : valid_misclass            0.20000\n",
      "2018-08-08 11:50:22,953 INFO : runtime                   4.67849\n",
      "2018-08-08 11:50:22,957 INFO : \n",
      "2018-08-08 11:50:25,645 INFO : Time only for training updates: 2.69s\n",
      "2018-08-08 11:50:27,538 INFO : Epoch 5\n",
      "2018-08-08 11:50:27,542 INFO : train_loss                0.21576\n",
      "2018-08-08 11:50:27,546 INFO : valid_loss                0.45930\n",
      "2018-08-08 11:50:27,550 INFO : train_misclass            0.07500\n",
      "2018-08-08 11:50:27,555 INFO : valid_misclass            0.16667\n",
      "2018-08-08 11:50:27,558 INFO : runtime                   4.63209\n",
      "2018-08-08 11:50:27,562 INFO : \n",
      "2018-08-08 11:50:30,174 INFO : Time only for training updates: 2.61s\n",
      "2018-08-08 11:50:32,077 INFO : Epoch 6\n",
      "2018-08-08 11:50:32,082 INFO : train_loss                0.18904\n",
      "2018-08-08 11:50:32,086 INFO : valid_loss                0.45700\n",
      "2018-08-08 11:50:32,091 INFO : train_misclass            0.07500\n",
      "2018-08-08 11:50:32,095 INFO : valid_misclass            0.13333\n",
      "2018-08-08 11:50:32,099 INFO : runtime                   4.53206\n",
      "2018-08-08 11:50:32,103 INFO : \n",
      "2018-08-08 11:50:34,804 INFO : Time only for training updates: 2.70s\n",
      "2018-08-08 11:50:36,701 INFO : Epoch 7\n",
      "2018-08-08 11:50:36,703 INFO : train_loss                0.15661\n",
      "2018-08-08 11:50:36,704 INFO : valid_loss                0.44282\n",
      "2018-08-08 11:50:36,705 INFO : train_misclass            0.05000\n",
      "2018-08-08 11:50:36,706 INFO : valid_misclass            0.13333\n",
      "2018-08-08 11:50:36,707 INFO : runtime                   4.62976\n",
      "2018-08-08 11:50:36,709 INFO : \n",
      "2018-08-08 11:50:39,389 INFO : Time only for training updates: 2.68s\n",
      "2018-08-08 11:50:41,334 INFO : Epoch 8\n",
      "2018-08-08 11:50:41,339 INFO : train_loss                0.11771\n",
      "2018-08-08 11:50:41,343 INFO : valid_loss                0.44731\n",
      "2018-08-08 11:50:41,347 INFO : train_misclass            0.05000\n",
      "2018-08-08 11:50:41,351 INFO : valid_misclass            0.16667\n",
      "2018-08-08 11:50:41,354 INFO : runtime                   4.58493\n",
      "2018-08-08 11:50:41,358 INFO : \n",
      "2018-08-08 11:50:44,141 INFO : Time only for training updates: 2.78s\n",
      "2018-08-08 11:50:46,064 INFO : Epoch 9\n",
      "2018-08-08 11:50:46,068 INFO : train_loss                0.09302\n",
      "2018-08-08 11:50:46,071 INFO : valid_loss                0.45860\n",
      "2018-08-08 11:50:46,075 INFO : train_misclass            0.05000\n",
      "2018-08-08 11:50:46,078 INFO : valid_misclass            0.16667\n",
      "2018-08-08 11:50:46,082 INFO : runtime                   4.75208\n",
      "2018-08-08 11:50:46,085 INFO : \n",
      "2018-08-08 11:50:48,835 INFO : Time only for training updates: 2.75s\n",
      "2018-08-08 11:50:50,793 INFO : Epoch 10\n",
      "2018-08-08 11:50:50,797 INFO : train_loss                0.07532\n",
      "2018-08-08 11:50:50,801 INFO : valid_loss                0.47792\n",
      "2018-08-08 11:50:50,804 INFO : train_misclass            0.02500\n",
      "2018-08-08 11:50:50,808 INFO : valid_misclass            0.23333\n",
      "2018-08-08 11:50:50,811 INFO : runtime                   4.69131\n",
      "2018-08-08 11:50:50,814 INFO : \n",
      "2018-08-08 11:50:53,557 INFO : Time only for training updates: 2.74s\n",
      "2018-08-08 11:50:55,441 INFO : Epoch 11\n",
      "2018-08-08 11:50:55,445 INFO : train_loss                0.06256\n",
      "2018-08-08 11:50:55,448 INFO : valid_loss                0.49903\n",
      "2018-08-08 11:50:55,452 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:50:55,456 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:50:55,459 INFO : runtime                   4.72475\n",
      "2018-08-08 11:50:55,462 INFO : \n",
      "2018-08-08 11:50:58,206 INFO : Time only for training updates: 2.74s\n",
      "2018-08-08 11:51:00,290 INFO : Epoch 12\n",
      "2018-08-08 11:51:00,295 INFO : train_loss                0.05512\n",
      "2018-08-08 11:51:00,299 INFO : valid_loss                0.51851\n",
      "2018-08-08 11:51:00,303 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:00,307 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:00,311 INFO : runtime                   4.64953\n",
      "2018-08-08 11:51:00,314 INFO : \n",
      "2018-08-08 11:51:03,079 INFO : Time only for training updates: 2.76s\n",
      "2018-08-08 11:51:05,008 INFO : Epoch 13\n",
      "2018-08-08 11:51:05,012 INFO : train_loss                0.05148\n",
      "2018-08-08 11:51:05,015 INFO : valid_loss                0.53031\n",
      "2018-08-08 11:51:05,019 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:05,022 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:05,024 INFO : runtime                   4.87176\n",
      "2018-08-08 11:51:05,025 INFO : \n",
      "2018-08-08 11:51:07,685 INFO : Time only for training updates: 2.65s\n",
      "2018-08-08 11:51:09,548 INFO : Epoch 14\n",
      "2018-08-08 11:51:09,554 INFO : train_loss                0.04834\n",
      "2018-08-08 11:51:09,559 INFO : valid_loss                0.53809\n",
      "2018-08-08 11:51:09,564 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:09,565 INFO : valid_misclass            0.30000\n",
      "2018-08-08 11:51:09,570 INFO : runtime                   4.60554\n",
      "2018-08-08 11:51:09,572 INFO : \n",
      "2018-08-08 11:51:12,250 INFO : Time only for training updates: 2.68s\n",
      "2018-08-08 11:51:14,197 INFO : Epoch 15\n",
      "2018-08-08 11:51:14,201 INFO : train_loss                0.04612\n",
      "2018-08-08 11:51:14,205 INFO : valid_loss                0.54513\n",
      "2018-08-08 11:51:14,209 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:14,212 INFO : valid_misclass            0.30000\n",
      "2018-08-08 11:51:14,215 INFO : runtime                   4.56585\n",
      "2018-08-08 11:51:14,219 INFO : \n",
      "2018-08-08 11:51:16,869 INFO : Time only for training updates: 2.65s\n",
      "2018-08-08 11:51:18,730 INFO : Epoch 16\n",
      "2018-08-08 11:51:18,735 INFO : train_loss                0.04390\n",
      "2018-08-08 11:51:18,739 INFO : valid_loss                0.55042\n",
      "2018-08-08 11:51:18,743 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:18,747 INFO : valid_misclass            0.30000\n",
      "2018-08-08 11:51:18,751 INFO : runtime                   4.61905\n",
      "2018-08-08 11:51:18,755 INFO : \n",
      "2018-08-08 11:51:21,534 INFO : Time only for training updates: 2.78s\n",
      "2018-08-08 11:51:23,350 INFO : Epoch 17\n",
      "2018-08-08 11:51:23,356 INFO : train_loss                0.04045\n",
      "2018-08-08 11:51:23,361 INFO : valid_loss                0.55286\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-08-08 11:51:23,366 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:23,370 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:23,375 INFO : runtime                   4.66496\n",
      "2018-08-08 11:51:23,377 INFO : \n",
      "2018-08-08 11:51:26,114 INFO : Time only for training updates: 2.74s\n",
      "2018-08-08 11:51:28,027 INFO : Epoch 18\n",
      "2018-08-08 11:51:28,032 INFO : train_loss                0.03734\n",
      "2018-08-08 11:51:28,035 INFO : valid_loss                0.55090\n",
      "2018-08-08 11:51:28,038 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:28,041 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:28,045 INFO : runtime                   4.57936\n",
      "2018-08-08 11:51:28,049 INFO : \n",
      "2018-08-08 11:51:30,782 INFO : Time only for training updates: 2.73s\n",
      "2018-08-08 11:51:32,686 INFO : Epoch 19\n",
      "2018-08-08 11:51:32,691 INFO : train_loss                0.03431\n",
      "2018-08-08 11:51:32,695 INFO : valid_loss                0.54657\n",
      "2018-08-08 11:51:32,699 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:32,703 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:32,707 INFO : runtime                   4.66918\n",
      "2018-08-08 11:51:32,710 INFO : \n",
      "2018-08-08 11:51:35,408 INFO : Time only for training updates: 2.69s\n",
      "2018-08-08 11:51:37,324 INFO : Epoch 20\n",
      "2018-08-08 11:51:37,329 INFO : train_loss                0.03115\n",
      "2018-08-08 11:51:37,333 INFO : valid_loss                0.54284\n",
      "2018-08-08 11:51:37,344 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:37,348 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:37,353 INFO : runtime                   4.62276\n",
      "2018-08-08 11:51:37,356 INFO : \n",
      "2018-08-08 11:51:39,948 INFO : Time only for training updates: 2.59s\n",
      "2018-08-08 11:51:41,864 INFO : Epoch 21\n",
      "2018-08-08 11:51:41,869 INFO : train_loss                0.02841\n",
      "2018-08-08 11:51:41,873 INFO : valid_loss                0.54000\n",
      "2018-08-08 11:51:41,877 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:41,881 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:41,882 INFO : runtime                   4.54261\n",
      "2018-08-08 11:51:41,883 INFO : \n",
      "2018-08-08 11:51:44,606 INFO : Time only for training updates: 2.72s\n",
      "2018-08-08 11:51:46,509 INFO : Epoch 22\n",
      "2018-08-08 11:51:46,514 INFO : train_loss                0.02606\n",
      "2018-08-08 11:51:46,517 INFO : valid_loss                0.53694\n",
      "2018-08-08 11:51:46,521 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:46,524 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:46,527 INFO : runtime                   4.65737\n",
      "2018-08-08 11:51:46,531 INFO : \n",
      "2018-08-08 11:51:49,151 INFO : Time only for training updates: 2.62s\n",
      "2018-08-08 11:51:51,057 INFO : Epoch 23\n",
      "2018-08-08 11:51:51,061 INFO : train_loss                0.02413\n",
      "2018-08-08 11:51:51,065 INFO : valid_loss                0.53370\n",
      "2018-08-08 11:51:51,068 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:51,072 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:51,075 INFO : runtime                   4.54569\n",
      "2018-08-08 11:51:51,079 INFO : \n",
      "2018-08-08 11:51:53,810 INFO : Time only for training updates: 2.73s\n",
      "2018-08-08 11:51:55,771 INFO : Epoch 24\n",
      "2018-08-08 11:51:55,776 INFO : train_loss                0.02262\n",
      "2018-08-08 11:51:55,783 INFO : valid_loss                0.53038\n",
      "2018-08-08 11:51:55,784 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:55,791 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:55,792 INFO : runtime                   4.65646\n",
      "2018-08-08 11:51:55,799 INFO : \n",
      "2018-08-08 11:51:58,476 INFO : Time only for training updates: 2.68s\n",
      "2018-08-08 11:52:00,433 INFO : Epoch 25\n",
      "2018-08-08 11:52:00,438 INFO : train_loss                0.02136\n",
      "2018-08-08 11:52:00,441 INFO : valid_loss                0.52729\n",
      "2018-08-08 11:52:00,445 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:00,449 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:52:00,453 INFO : runtime                   4.66856\n",
      "2018-08-08 11:52:00,457 INFO : \n",
      "2018-08-08 11:52:03,160 INFO : Time only for training updates: 2.70s\n",
      "2018-08-08 11:52:05,055 INFO : Epoch 26\n",
      "2018-08-08 11:52:05,059 INFO : train_loss                0.02039\n",
      "2018-08-08 11:52:05,063 INFO : valid_loss                0.52436\n",
      "2018-08-08 11:52:05,067 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:05,070 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:52:05,073 INFO : runtime                   4.68141\n",
      "2018-08-08 11:52:05,076 INFO : \n",
      "2018-08-08 11:52:07,806 INFO : Time only for training updates: 2.73s\n",
      "2018-08-08 11:52:09,685 INFO : Epoch 27\n",
      "2018-08-08 11:52:09,689 INFO : train_loss                0.01964\n",
      "2018-08-08 11:52:09,693 INFO : valid_loss                0.52170\n",
      "2018-08-08 11:52:09,696 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:09,700 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:52:09,703 INFO : runtime                   4.64933\n",
      "2018-08-08 11:52:09,707 INFO : \n",
      "2018-08-08 11:52:12,420 INFO : Time only for training updates: 2.71s\n",
      "2018-08-08 11:52:14,408 INFO : Epoch 28\n",
      "2018-08-08 11:52:14,413 INFO : train_loss                0.01904\n",
      "2018-08-08 11:52:14,417 INFO : valid_loss                0.51934\n",
      "2018-08-08 11:52:14,421 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:14,425 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:52:14,429 INFO : runtime                   4.61259\n",
      "2018-08-08 11:52:14,433 INFO : \n",
      "2018-08-08 11:52:17,092 INFO : Time only for training updates: 2.66s\n",
      "2018-08-08 11:52:18,984 INFO : Epoch 29\n",
      "2018-08-08 11:52:18,989 INFO : train_loss                0.01858\n",
      "2018-08-08 11:52:18,994 INFO : valid_loss                0.51724\n",
      "2018-08-08 11:52:18,999 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:19,003 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:52:19,007 INFO : runtime                   4.67249\n",
      "2018-08-08 11:52:19,011 INFO : \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<braindecode.experiments.experiment.Experiment at 0x7ffae9f407f0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(train_set.X, train_set.y, epochs=30, batch_size=64, scheduler='cosine',\n",
    "         validation_data=(valid_set.X, valid_set.y),)"
   ]
  },
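  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the log above, `train_misclass` and `valid_misclass` are misclassification rates, i.e. the fraction of wrongly predicted trials (1 - accuracy). The shallow ConvNet outputs log-probabilities per class, so the predicted class is the one with the largest output. The minimal sketch below computes the rate on made-up outputs. `log_probs` and `labels` are hypothetical values, not results from this run.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical network outputs: log-probabilities for 4 trials and 2 classes\n",
    "log_probs = np.log(np.array([[0.9, 0.1],\n",
    "                             [0.3, 0.7],\n",
    "                             [0.6, 0.4],\n",
    "                             [0.2, 0.8]]))\n",
    "# Hypothetical true labels (0 = hands, 1 = feet)\n",
    "labels = np.array([0, 1, 1, 1])\n",
    "\n",
    "# Predicted class = index of the largest log-probability per trial\n",
    "predictions = np.argmax(log_probs, axis=1)\n",
    "# Misclassification rate = fraction of trials where prediction != label\n",
    "misclass = np.mean(predictions != labels)\n",
    "print(predictions, misclass)\n",
    "```\n",
    "\n",
    "Here one of four trials is wrong, giving a rate of 0.25. Read the same way, the epoch 0 `train_misclass` of 0.52500 means 21 of the 40 training trials were misclassified."
   ]
  },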
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The monitored values are also stored in a pandas DataFrame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>train_misclass</th>\n",
       "      <th>valid_misclass</th>\n",
       "      <th>runtime</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1.885311</td>\n",
       "      <td>2.013432</td>\n",
       "      <td>0.525</td>\n",
       "      <td>0.533333</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.665155</td>\n",
       "      <td>0.781536</td>\n",
       "      <td>0.350</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>4.688265</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.414632</td>\n",
       "      <td>0.551279</td>\n",
       "      <td>0.200</td>\n",
       "      <td>0.233333</td>\n",
       "      <td>4.711897</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.324484</td>\n",
       "      <td>0.499546</td>\n",
       "      <td>0.175</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>4.661192</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.237008</td>\n",
       "      <td>0.447357</td>\n",
       "      <td>0.100</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>4.678494</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.215761</td>\n",
       "      <td>0.459302</td>\n",
       "      <td>0.075</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>4.632092</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.189035</td>\n",
       "      <td>0.457000</td>\n",
       "      <td>0.075</td>\n",
       "      <td>0.133333</td>\n",
       "      <td>4.532057</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.156614</td>\n",
       "      <td>0.442818</td>\n",
       "      <td>0.050</td>\n",
       "      <td>0.133333</td>\n",
       "      <td>4.629760</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.117707</td>\n",
       "      <td>0.447310</td>\n",
       "      <td>0.050</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>4.584930</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.093025</td>\n",
       "      <td>0.458601</td>\n",
       "      <td>0.050</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>4.752083</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.075319</td>\n",
       "      <td>0.477922</td>\n",
       "      <td>0.025</td>\n",
       "      <td>0.233333</td>\n",
       "      <td>4.691312</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>0.062565</td>\n",
       "      <td>0.499030</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.724754</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0.055121</td>\n",
       "      <td>0.518506</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.649527</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0.051483</td>\n",
       "      <td>0.530313</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.871762</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0.048344</td>\n",
       "      <td>0.538089</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.300000</td>\n",
       "      <td>4.605540</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0.046123</td>\n",
       "      <td>0.545129</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.300000</td>\n",
       "      <td>4.565845</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0.043898</td>\n",
       "      <td>0.550419</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.300000</td>\n",
       "      <td>4.619048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>0.040452</td>\n",
       "      <td>0.552857</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.664961</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>0.037340</td>\n",
       "      <td>0.550898</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.579360</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0.034310</td>\n",
       "      <td>0.546568</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.669176</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0.031152</td>\n",
       "      <td>0.542845</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.622759</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0.028410</td>\n",
       "      <td>0.539997</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.542608</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0.026065</td>\n",
       "      <td>0.536938</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.657368</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0.024133</td>\n",
       "      <td>0.533703</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.545693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>0.022622</td>\n",
       "      <td>0.530376</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.656460</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>0.021355</td>\n",
       "      <td>0.527293</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.668564</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>0.020391</td>\n",
       "      <td>0.524360</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.681408</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>0.019635</td>\n",
       "      <td>0.521704</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.649333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0.019043</td>\n",
       "      <td>0.519344</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.612593</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>0.018577</td>\n",
       "      <td>0.517245</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>4.672487</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    train_loss  valid_loss  train_misclass  valid_misclass   runtime\n",
       "0     1.885311    2.013432           0.525        0.533333  0.000000\n",
       "1     0.665155    0.781536           0.350        0.400000  4.688265\n",
       "2     0.414632    0.551279           0.200        0.233333  4.711897\n",
       "3     0.324484    0.499546           0.175        0.166667  4.661192\n",
       "4     0.237008    0.447357           0.100        0.200000  4.678494\n",
       "5     0.215761    0.459302           0.075        0.166667  4.632092\n",
       "6     0.189035    0.457000           0.075        0.133333  4.532057\n",
       "7     0.156614    0.442818           0.050        0.133333  4.629760\n",
       "8     0.117707    0.447310           0.050        0.166667  4.584930\n",
       "9     0.093025    0.458601           0.050        0.166667  4.752083\n",
       "10    0.075319    0.477922           0.025        0.233333  4.691312\n",
       "11    0.062565    0.499030           0.000        0.266667  4.724754\n",
       "12    0.055121    0.518506           0.000        0.266667  4.649527\n",
       "13    0.051483    0.530313           0.000        0.266667  4.871762\n",
       "14    0.048344    0.538089           0.000        0.300000  4.605540\n",
       "15    0.046123    0.545129           0.000        0.300000  4.565845\n",
       "16    0.043898    0.550419           0.000        0.300000  4.619048\n",
       "17    0.040452    0.552857           0.000        0.266667  4.664961\n",
       "18    0.037340    0.550898           0.000        0.266667  4.579360\n",
       "19    0.034310    0.546568           0.000        0.266667  4.669176\n",
       "20    0.031152    0.542845           0.000        0.266667  4.622759\n",
       "21    0.028410    0.539997           0.000        0.266667  4.542608\n",
       "22    0.026065    0.536938           0.000        0.266667  4.657368\n",
       "23    0.024133    0.533703           0.000        0.266667  4.545693\n",
       "24    0.022622    0.530376           0.000        0.266667  4.656460\n",
       "25    0.021355    0.527293           0.000        0.266667  4.668564\n",
       "26    0.020391    0.524360           0.000        0.266667  4.681408\n",
       "27    0.019635    0.521704           0.000        0.266667  4.649333\n",
       "28    0.019043    0.519344           0.000        0.266667  4.612593\n",
       "29    0.018577    0.517245           0.000        0.266667  4.672487"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.epochs_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
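    "In this run, the validation misclassification at the final epoch is 0.267, i.e. about 73.3% accuracy (22 of 30 validation trials correctly predicted). The lowest validation misclassification, 0.133 (26 of 30 correct), is reached at epochs 6 and 7. The [Cropped Decoding Tutorial](./Cropped_Decoding.html) shows how to achieve higher accuracies using cropped training."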
   ]
  },
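  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training log above also shows overfitting. Training misclassification reaches 0 from epoch 11 on, while validation loss rises from its minimum at epoch 7 until epoch 17. The following cell is a minimal sketch that locates the epoch with the lowest validation misclassification and converts misclassification rates into accuracies and trial counts. In the notebook you could use `model.epochs_df['valid_misclass']` directly. Here the column is copied from the output above so the cell runs on its own. The helper `misclass_to_correct` is a hypothetical illustration, not part of Braindecode. `n_valid_trials = 30` is an assumption inferred from the misclassification values being multiples of 1/30."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# valid_misclass per epoch, copied from the model.epochs_df output above\n",
    "valid_misclass = pd.Series([\n",
    "    0.533333, 0.400000, 0.233333, 0.166667, 0.200000, 0.166667,\n",
    "    0.133333, 0.133333, 0.166667, 0.166667, 0.233333, 0.266667,\n",
    "    0.266667, 0.266667, 0.300000, 0.300000, 0.300000, 0.266667,\n",
    "    0.266667, 0.266667, 0.266667, 0.266667, 0.266667, 0.266667,\n",
    "    0.266667, 0.266667, 0.266667, 0.266667, 0.266667, 0.266667,\n",
    "], name='valid_misclass')\n",
    "n_valid_trials = 30  # assumption: inferred from the misclass values being multiples of 1/30\n",
    "\n",
    "def misclass_to_correct(misclass, n_trials):\n",
    "    # Hypothetical helper: number of correctly predicted trials for a misclassification rate\n",
    "    return int(round((1 - misclass) * n_trials))\n",
    "\n",
    "best_epoch = int(valid_misclass.idxmin())  # idxmin returns the first epoch with the lowest value\n",
    "final_epoch = int(valid_misclass.index[-1])\n",
    "for label, epoch in [('best', best_epoch), ('final', final_epoch)]:\n",
    "    misclass = valid_misclass[epoch]\n",
    "    print('{:5s} epoch {:2d}: accuracy {:.1%}, {} of {} trials correct'.format(\n",
    "        label, epoch, 1 - misclass, misclass_to_correct(misclass, n_valid_trials), n_valid_trials))"
   ]
  },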
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once all hyperparameters and architectural choices are fixed, we can evaluate the model on the test set to obtain the accuracy to report in a publication:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': 0.2964690923690796,\n",
       " 'misclass': 0.15000000000000002,\n",
       " 'runtime': 0.0007402896881103516}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Trials from index 70 onwards form the test set\n",
    "test_set = SignalAndTarget(X[70:], y=y[70:])\n",
    "\n",
    "model.evaluate(test_set.X, test_set.y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also retrieve the predicted class of each individual trial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(test_set.X)"
   ]
  },
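  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 20 predictions above show that the test set has 20 trials. Assuming the reported misclassification is the fraction of wrongly predicted test trials, 0.15 corresponds to 3 wrong and 17 correct predictions, i.e. 85% test accuracy. The following cell sketches how to compute accuracy directly from predicted and true labels. In the notebook you would pass `model.predict(test_set.X)` and `test_set.y`. The helper `trial_accuracy` is a hypothetical illustration. Because the true test labels are not printed here, `hypothetical_targets` is made up so that exactly 3 of the 20 predictions disagree with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def trial_accuracy(predictions, targets):\n",
    "    # Fraction of trials whose predicted class equals the true class\n",
    "    predictions = np.asarray(predictions)\n",
    "    targets = np.asarray(targets)\n",
    "    if predictions.shape != targets.shape:\n",
    "        raise ValueError('predictions and targets must have the same shape')\n",
    "    return float(np.mean(predictions == targets))\n",
    "\n",
    "# Predictions printed by model.predict(test_set.X) above\n",
    "predictions = np.array([1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0])\n",
    "# Hypothetical labels for illustration only: the predictions with three trials flipped\n",
    "hypothetical_targets = predictions.copy()\n",
    "flipped = [0, 5, 10]\n",
    "hypothetical_targets[flipped] = 1 - hypothetical_targets[flipped]\n",
    "\n",
    "accuracy = trial_accuracy(predictions, hypothetical_targets)\n",
    "print('accuracy: {:.0%} ({} of {} trials correct)'.format(\n",
    "    accuracy, int(np.sum(predictions == hypothetical_targets)), len(predictions)))"
   ]
  },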
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "If you want to try cross-subject decoding, replace the loading code with the following. It performs cross-subject decoding of imagined left vs. right fist opening and closing, with 50 training and 5 validation subjects (warning: this may be very slow on a CPU):\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mne\n",
    "import numpy as np\n",
    "from mne.io import concatenate_raws\n",
    "from braindecode.datautil.signal_target import SignalAndTarget\n",
    "\n",
    "# Runs 4, 8 and 12: imagined opening/closing of the left (event 2) or right (event 3) fist\n",
    "runs = [4, 8, 12]\n",
    "\n",
    "# First 50 subjects as training subjects\n",
    "physionet_paths = [mne.datasets.eegbci.load_data(sub_id, runs) for sub_id in range(1, 51)]\n",
    "physionet_paths = np.concatenate(physionet_paths)\n",
    "parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')\n",
    "         for path in physionet_paths]\n",
    "\n",
    "raw = concatenate_raws(parts)\n",
    "\n",
    "picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                       exclude='bads')\n",
    "\n",
    "events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')\n",
    "\n",
    "# Cut epochs from 1 s to 4.1 s after each cue, without baseline correction\n",
    "epoched = mne.Epochs(raw, events, dict(left_hand=2, right_hand=3), tmin=1, tmax=4.1,\n",
    "                     proj=False, picks=picks, baseline=None, preload=True)\n",
    "\n",
    "# Subjects 51-55 as validation subjects\n",
    "physionet_paths_valid = [mne.datasets.eegbci.load_data(sub_id, runs) for sub_id in range(51, 56)]\n",
    "physionet_paths_valid = np.concatenate(physionet_paths_valid)\n",
    "parts_valid = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')\n",
    "               for path in physionet_paths_valid]\n",
    "raw_valid = concatenate_raws(parts_valid)\n",
    "\n",
    "picks_valid = mne.pick_types(raw_valid.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                             exclude='bads')\n",
    "\n",
    "events_valid = mne.find_events(raw_valid, shortest_event=0, stim_channel='STI 014')\n",
    "\n",
    "# Same epoching as for the training subjects\n",
    "epoched_valid = mne.Epochs(raw_valid, events_valid, dict(left_hand=2, right_hand=3), tmin=1, tmax=4.1,\n",
    "                           proj=False, picks=picks_valid, baseline=None, preload=True)\n",
    "\n",
    "# Convert volts to microvolts (float32) and event ids 2, 3 to class indices 0, 1 (int64)\n",
    "train_X = (epoched.get_data() * 1e6).astype(np.float32)\n",
    "train_y = (epoched.events[:, 2] - 2).astype(np.int64)\n",
    "valid_X = (epoched_valid.get_data() * 1e6).astype(np.float32)\n",
    "valid_y = (epoched_valid.events[:, 2] - 2).astype(np.int64)\n",
    "train_set = SignalAndTarget(train_X, y=train_y)\n",
    "valid_set = SignalAndTarget(valid_X, y=valid_y)"
   ]
  },
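  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last lines of the loading code convert the MNE epochs into the arrays Braindecode expects. Signals become `float32` values in microvolts, since MNE's `get_data()` returns volts. Labels become `int64` class indices 0 and 1. The following cell is a minimal sketch of this conversion on synthetic data. `to_braindecode_arrays` is a hypothetical helper name, and the random array only stands in for `epoched.get_data()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def to_braindecode_arrays(data_volts, event_codes, first_code=2):\n",
    "    # data_volts: (trials, channels, timesteps) in volts, like mne.Epochs.get_data()\n",
    "    # event_codes: event ids, like epoched.events[:, 2]\n",
    "    X = (np.asarray(data_volts) * 1e6).astype(np.float32)  # volts -> microvolts, float32\n",
    "    y = (np.asarray(event_codes) - first_code).astype(np.int64)  # 2, 3 -> 0, 1\n",
    "    return X, y\n",
    "\n",
    "# Synthetic stand-in for epoched data (illustrative shapes): 4 trials, 64 channels, 497 samples\n",
    "rng = np.random.RandomState(0)\n",
    "fake_data = rng.randn(4, 64, 497) * 1e-5\n",
    "fake_codes = np.array([2, 3, 3, 2])\n",
    "X_demo, y_demo = to_braindecode_arrays(fake_data, fake_codes)\n",
    "print(X_demo.shape, X_demo.dtype, y_demo)"
   ]
  },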
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset references\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This dataset was created and contributed to PhysioNet by the developers of the [BCI2000](http://www.schalklab.org/research/bci2000) instrumentation system, which they used in making these recordings. The system is described in:\n",
    "\n",
    "    Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.\n",
    "\n",
    "[PhysioBank](https://physionet.org/physiobank/) is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community, and is further described in:\n",
    "\n",
    "    Goldberger, A.L., Amaral, L.A.N., Glass, L., Hausdorff, J.M., Ivanov, P.Ch., Mark, R.G., Mietus, J.E., Moody, G.B., Peng, C.-K., Stanley, H.E. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
