{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cropped Decoding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we use cropped decoding. Cropped decoding means that the ConvNet is trained on time windows (*crops*) within the trials instead of on complete trials. The figures below illustrate the difference between trialwise and cropped decoding.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Trialwise Decoding | Cropped Decoding\n",
    "--- | ---\n",
    "![Trialwise Decoding](./trialwise_explanation.png \"Trialwise Decoding\") | ![Cropped Decoding](./cropped_explanation.png \"Cropped Decoding\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the left, you see trialwise decoding:\n",
    "\n",
    "1. A complete trial is pushed through the network\n",
    "2. The network produces a prediction\n",
    "3. The prediction is compared to the target (label) for that trial to compute the loss\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the right, you see cropped decoding:\n",
    "\n",
    "1. Instead of a complete trial, windows within the trial, here called *crops*, are pushed through the network\n",
    "2. For computational efficiency, multiple neighbouring crops are pushed through the network simultaneously (these neighbouring crops are called a *supercrop*)\n",
    "3. Therefore, the network produces multiple predictions (one per crop in the supercrop)\n",
    "4. The individual crop predictions are averaged before computing the loss function\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notes:\n",
    "\n",
    "* The network architecture implicitly defines the crop size (it is the receptive field size, i.e., the number of timesteps the network uses to make a single prediction)\n",
    "* The supercrop size is a user-defined hyperparameter, called `input_time_length` in Braindecode. It mostly affects runtime: larger supercrops are typically faster, since more neighbouring crops share their computations (as long as they fit into memory). As a rule of thumb, you can set it to two times the crop size.\n",
    "* Crop size and supercrop size together define how many predictions the network makes per supercrop: $n_{\\mathrm{predictions}} = n_{\\mathrm{supercrop}} - n_{\\mathrm{crop}} + 1$, where $n_{\\mathrm{supercrop}}$ and $n_{\\mathrm{crop}}$ are the supercrop and crop sizes in samples."
   ]
  },
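  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prediction count can be checked with a toy example. The sketch below is an illustration under simplifying assumptions, not Braindecode code: a single 1-D filter of length `crop_size` stands in for the ConvNet, so its receptive field is exactly `crop_size` samples. Sliding this filter over a supercrop in one pass gives the same values as pushing every crop through it separately, with one prediction per crop position (12 - 5 + 1 = 8 for the sizes chosen here).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "crop_size = 5        # receptive field of the toy 'network'\n",
    "supercrop_size = 12  # plays the role of input_time_length\n",
    "weights = rng.normal(size=crop_size)  # toy network parameters\n",
    "supercrop = rng.normal(size=supercrop_size)\n",
    "\n",
    "# Crop by crop: push each crop through the toy network separately\n",
    "per_crop = np.array([\n",
    "    np.dot(weights, supercrop[i:i + crop_size])\n",
    "    for i in range(supercrop_size - crop_size + 1)\n",
    "])\n",
    "\n",
    "# Dense: one sliding pass of the same weights over the whole supercrop\n",
    "dense = np.correlate(supercrop, weights, mode='valid')\n",
    "\n",
    "print(len(dense))                    # 8\n",
    "print(np.allclose(per_crop, dense))  # True\n",
    "```\n",
    "\n",
    "The toy filter has no stride, so plain sliding is enough to share computation between neighbouring crops; a real ConvNet with strided pooling needs the dilated-convolution approach described in the paper linked below."
   ]
  },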
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For cropped decoding, the above training setup is mathematically identical to sampling crops from your dataset, pushing them through the network and training directly on the individual crops. At the same time, it is much faster, as it avoids redundant computations by using dilated convolutions; see our paper [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051). However, the two setups are only mathematically identical if (1) your network does not use any padding and (2) your loss function leads to the same gradients when using the averaged output. The first condition holds for our shallow and deep ConvNet models, and the second holds for the log-softmax outputs and negative log likelihood loss that are typically used for classification in PyTorch."
   ]
  },
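  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Condition (2) can be checked numerically. In the NumPy-only sketch below, random numbers stand in for the network's per-crop log-softmax outputs for a single trial (an assumption for illustration). Applying the negative log likelihood to the averaged log-probabilities gives the same value as averaging the per-crop losses, because the NLL loss is linear in the log-probabilities. Since both are the same function of the network outputs, their gradients agree as well. This would not hold if, for example, the raw scores were averaged before the log-softmax.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def log_softmax(x, axis):\n",
    "    # Numerically stable log-softmax along the class axis\n",
    "    x = x - x.max(axis=axis, keepdims=True)\n",
    "    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_classes, n_preds = 2, 8\n",
    "scores = rng.normal(size=(n_classes, n_preds))  # stand-in network scores\n",
    "log_probs = log_softmax(scores, axis=0)          # per-crop log-softmax\n",
    "target = 1\n",
    "\n",
    "# (a) Average the crop predictions, then apply the NLL loss\n",
    "loss_avg_then_nll = -log_probs.mean(axis=1)[target]\n",
    "# (b) Apply the NLL loss to each crop, then average the losses\n",
    "loss_per_crop = np.mean([-log_probs[target, t] for t in range(n_preds)])\n",
    "\n",
    "print(np.isclose(loss_avg_then_nll, loss_per_crop))  # True\n",
    "```"
   ]
  },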
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Most of the code for cropped decoding is identical to the [Trialwise Decoding Tutorial](Trialwise_Decoding.html); the differences are explained in the text."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Enable logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import importlib\n",
    "importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195\n",
    "log = logging.getLogger()\n",
    "log.setLevel('INFO')\n",
    "import sys\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n",
    "                     level=logging.INFO, stream=sys.stdout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mne\n",
    "from mne.io import concatenate_raws\n",
    "\n",
    "# Run numbers (stored in event_codes) for executed (5, 9, 13) and imagined (6, 10, 14) hands/feet movements\n",
    "subject_id = 22\n",
    "event_codes = [5,6,9,10,13,14]\n",
    "#event_codes = [3,4,5,6,7,8,9,10,11,12,13,14]\n",
    "\n",
    "# This will download the files if you don't have them yet,\n",
    "# and then return the paths to the files.\n",
    "physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)\n",
    "\n",
    "# Load each of the files\n",
    "parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')\n",
    "         for path in physionet_paths]\n",
    "\n",
    "# Concatenate them\n",
    "raw = concatenate_raws(parts)\n",
    "\n",
    "# Find the events in this dataset\n",
    "events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')\n",
    "\n",
    "# Use only EEG channels\n",
    "eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                   exclude='bads')\n",
    "\n",
    "# Extract trials, only using EEG channels\n",
    "epoched = mne.Epochs(raw, events, dict(hands_or_left=2, feet_or_right=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,\n",
    "                baseline=None, preload=True)"
   ]
  },
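  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To relate the epoch window to the crop and supercrop sizes used later, it helps to know how many samples each trial contains. A quick check, assuming the 160 Hz sampling rate of the PhysioNet EEG Motor Movement/Imagery recordings (confirm with `raw.info['sfreq']`) and that MNE includes both the `tmin` and the `tmax` sample:\n",
    "\n",
    "```python\n",
    "sfreq = 160.0          # assumed sampling rate in Hz\n",
    "tmin, tmax = 1.0, 4.1  # epoch window used above, in seconds\n",
    "\n",
    "# Sample indices of the window borders; both borders are included\n",
    "n_times = int(round(tmax * sfreq)) - int(round(tmin * sfreq)) + 1\n",
    "print(n_times)  # 497\n",
    "```\n",
    "\n",
    "`epoched.get_data().shape[2]` should give the same number."
   ]
  },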
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert data to Braindecode format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Convert data from volts to microvolts\n",
    "# PyTorch expects float32 for input and int64 for labels.\n",
    "X = (epoched.get_data() * 1e6).astype(np.float32)\n",
    "y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datautil.signal_target import SignalAndTarget\n",
    "\n",
    "train_set = SignalAndTarget(X[:40], y=y[:40])\n",
    "valid_set = SignalAndTarget(X[40:70], y=y[40:70])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "As in the trialwise decoding tutorial, we will use the Braindecode model class directly to perform the training in a few lines of code. If you instead want to use your own training loop, have a look at the [Cropped Manual Training Loop Tutorial](./Cropped_Manual_Training_Loop.html).\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For cropped decoding, we now transform the model into a model that outputs a dense time series of predictions.\n",
    "For this, we manually set the length of the final convolution layer to a value that makes the receptive field of the ConvNet smaller than the number of samples in a trial (see `final_conv_length=12` in the model definition)."
   ]
  },
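  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The crop size is the receptive field of the network, which can be computed layer by layer from the temporal kernel sizes and strides. The sketch below assumes the default temporal hyperparameters of `ShallowFBCSPNet` (`filter_time_length=25`, `pool_time_length=75`, `pool_time_stride=15`) together with `final_conv_length=12`; please verify these against your installed Braindecode version. When the model is turned into a dense-prediction model, strides are replaced by dilations, which keeps the receptive field unchanged.\n",
    "\n",
    "```python\n",
    "def receptive_field(layers):\n",
    "    # layers: sequence of (kernel_size, stride) along the time axis\n",
    "    rf, jump = 1, 1\n",
    "    for kernel_size, stride in layers:\n",
    "        rf += (kernel_size - 1) * jump  # widen by the new kernel\n",
    "        jump *= stride                  # input distance between outputs\n",
    "    return rf\n",
    "\n",
    "# Assumed temporal layers of ShallowFBCSPNet\n",
    "shallow_layers = [\n",
    "    (25, 1),   # temporal convolution (filter_time_length=25)\n",
    "    (1, 1),    # spatial convolution, no temporal extent\n",
    "    (75, 15),  # mean pooling (pool_time_length=75, pool_time_stride=15)\n",
    "    (12, 1),   # final classification convolution (final_conv_length=12)\n",
    "]\n",
    "crop_size = receptive_field(shallow_layers)\n",
    "print(crop_size)            # 264\n",
    "print(450 - crop_size + 1)  # 187 predictions per 450-sample supercrop\n",
    "```\n",
    "\n",
    "With `input_time_length=450`, as used for fitting below, this would give 187 predictions per supercrop."
   ]
  },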
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
    "from torch import nn\n",
    "from braindecode.torch_ext.util import set_random_seeds\n",
    "\n",
    "# Set to True if you want to use a GPU.\n",
    "# You can also use torch.cuda.is_available() to determine if CUDA is available on your machine.\n",
    "cuda = False\n",
    "set_random_seeds(seed=20170629, cuda=cuda)\n",
    "n_classes = 2\n",
    "in_chans = train_set.X.shape[1]\n",
    "\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,\n",
    "                        input_time_length=None,\n",
    "                        final_conv_length=12)\n",
    "if cuda:\n",
    "    model.cuda()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we pass `cropped=True` to the `compile` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.torch_ext.optimizers import AdamW\n",
    "import torch.nn.functional as F\n",
    "#optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model\n",
    "optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)\n",
    "model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1, cropped=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For fitting, we must supply the supercrop size. Here, we set it to 450 via `input_time_length = 450`."
   ]
  },
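  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Trials are usually longer than one supercrop, so several (possibly overlapping) supercrops are needed to obtain a prediction for every crop of a trial. The helper below is a hypothetical sketch of one simple tiling scheme, not necessarily the exact one used by Braindecode's iterator: supercrops advance by the number of predictions per supercrop, and a final supercrop is aligned to the end of the trial. The sizes are assumptions: a 497-sample trial (3.1 s at an assumed 160 Hz) and a 264-sample crop (the receptive field of `ShallowFBCSPNet` under its assumed default hyperparameters).\n",
    "\n",
    "```python\n",
    "def supercrop_starts(n_times, input_time_length, n_preds):\n",
    "    # Hypothetical helper: start indices of supercrops that together\n",
    "    # produce a prediction for every crop in a trial of n_times samples\n",
    "    assert input_time_length <= n_times\n",
    "    last_start = n_times - input_time_length\n",
    "    starts = list(range(0, last_start + 1, n_preds))\n",
    "    if starts[-1] != last_start:\n",
    "        starts.append(last_start)  # final supercrop aligned to trial end\n",
    "    return starts\n",
    "\n",
    "n_times, input_time_length, crop_size = 497, 450, 264  # assumed sizes\n",
    "n_preds = input_time_length - crop_size + 1  # predictions per supercrop\n",
    "starts = supercrop_starts(n_times, input_time_length, n_preds)\n",
    "\n",
    "# Crop start positions covered by the supercrops\n",
    "covered = {s + i for s in starts for i in range(n_preds)}\n",
    "print(starts)                                   # [0, 47]\n",
    "print(len(covered) == n_times - crop_size + 1)  # True\n",
    "```\n",
    "\n",
    "Crops covered by both supercrops are simply predicted twice; every crop of the trial receives at least one prediction."
   ]
  },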
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-08-08 11:49:48,805 INFO : Run until first stop...\n",
      "2018-08-08 11:49:52,805 INFO : Time only for training updates: 3.99s\n",
      "2018-08-08 11:49:54,825 INFO : Epoch 0\n",
      "2018-08-08 11:49:54,826 INFO : train_loss                4.16812\n",
      "2018-08-08 11:49:54,830 INFO : valid_loss                3.39068\n",
      "2018-08-08 11:49:54,830 INFO : train_misclass            0.50000\n",
      "2018-08-08 11:49:54,831 INFO : valid_misclass            0.53333\n",
      "2018-08-08 11:49:54,833 INFO : runtime                   0.00000\n",
      "2018-08-08 11:49:54,835 INFO : \n",
      "2018-08-08 11:49:59,395 INFO : Time only for training updates: 4.55s\n",
      "2018-08-08 11:50:03,372 INFO : Epoch 1\n",
      "2018-08-08 11:50:03,379 INFO : train_loss                3.04388\n",
      "2018-08-08 11:50:03,384 INFO : valid_loss                2.38477\n",
      "2018-08-08 11:50:03,389 INFO : train_misclass            0.52500\n",
      "2018-08-08 11:50:03,394 INFO : valid_misclass            0.53333\n",
      "2018-08-08 11:50:03,398 INFO : runtime                   6.59494\n",
      "2018-08-08 11:50:03,403 INFO : \n",
      "2018-08-08 11:50:09,317 INFO : Time only for training updates: 5.90s\n",
      "2018-08-08 11:50:13,263 INFO : Epoch 2\n",
      "2018-08-08 11:50:13,268 INFO : train_loss                2.06639\n",
      "2018-08-08 11:50:13,273 INFO : valid_loss                1.61044\n",
      "2018-08-08 11:50:13,277 INFO : train_misclass            0.47500\n",
      "2018-08-08 11:50:13,281 INFO : valid_misclass            0.50000\n",
      "2018-08-08 11:50:13,285 INFO : runtime                   9.91786\n",
      "2018-08-08 11:50:13,289 INFO : \n",
      "2018-08-08 11:50:19,343 INFO : Time only for training updates: 6.03s\n",
      "2018-08-08 11:50:23,245 INFO : Epoch 3\n",
      "2018-08-08 11:50:23,251 INFO : train_loss                1.41936\n",
      "2018-08-08 11:50:23,255 INFO : valid_loss                1.18789\n",
      "2018-08-08 11:50:23,259 INFO : train_misclass            0.47500\n",
      "2018-08-08 11:50:23,263 INFO : valid_misclass            0.50000\n",
      "2018-08-08 11:50:23,267 INFO : runtime                   10.02883\n",
      "2018-08-08 11:50:23,271 INFO : \n",
      "2018-08-08 11:50:29,225 INFO : Time only for training updates: 5.94s\n",
      "2018-08-08 11:50:33,218 INFO : Epoch 4\n",
      "2018-08-08 11:50:33,223 INFO : train_loss                1.00751\n",
      "2018-08-08 11:50:33,227 INFO : valid_loss                0.95534\n",
      "2018-08-08 11:50:33,232 INFO : train_misclass            0.40000\n",
      "2018-08-08 11:50:33,236 INFO : valid_misclass            0.43333\n",
      "2018-08-08 11:50:33,242 INFO : runtime                   9.88115\n",
      "2018-08-08 11:50:33,248 INFO : \n",
      "2018-08-08 11:50:39,206 INFO : Time only for training updates: 5.94s\n",
      "2018-08-08 11:50:43,233 INFO : Epoch 5\n",
      "2018-08-08 11:50:43,237 INFO : train_loss                0.76304\n",
      "2018-08-08 11:50:43,240 INFO : valid_loss                0.83303\n",
      "2018-08-08 11:50:43,243 INFO : train_misclass            0.32500\n",
      "2018-08-08 11:50:43,247 INFO : valid_misclass            0.43333\n",
      "2018-08-08 11:50:43,250 INFO : runtime                   9.97928\n",
      "2018-08-08 11:50:43,255 INFO : \n",
      "2018-08-08 11:50:49,142 INFO : Time only for training updates: 5.87s\n",
      "2018-08-08 11:50:53,124 INFO : Epoch 6\n",
      "2018-08-08 11:50:53,130 INFO : train_loss                0.61960\n",
      "2018-08-08 11:50:53,134 INFO : valid_loss                0.78367\n",
      "2018-08-08 11:50:53,139 INFO : train_misclass            0.30000\n",
      "2018-08-08 11:50:53,143 INFO : valid_misclass            0.43333\n",
      "2018-08-08 11:50:53,148 INFO : runtime                   9.93809\n",
      "2018-08-08 11:50:53,152 INFO : \n",
      "2018-08-08 11:50:59,114 INFO : Time only for training updates: 5.95s\n",
      "2018-08-08 11:51:03,191 INFO : Epoch 7\n",
      "2018-08-08 11:51:03,197 INFO : train_loss                0.51544\n",
      "2018-08-08 11:51:03,202 INFO : valid_loss                0.75722\n",
      "2018-08-08 11:51:03,207 INFO : train_misclass            0.25000\n",
      "2018-08-08 11:51:03,211 INFO : valid_misclass            0.36667\n",
      "2018-08-08 11:51:03,215 INFO : runtime                   9.97154\n",
      "2018-08-08 11:51:03,220 INFO : \n",
      "2018-08-08 11:51:09,159 INFO : Time only for training updates: 5.92s\n",
      "2018-08-08 11:51:13,178 INFO : Epoch 8\n",
      "2018-08-08 11:51:13,183 INFO : train_loss                0.41659\n",
      "2018-08-08 11:51:13,187 INFO : valid_loss                0.70797\n",
      "2018-08-08 11:51:13,191 INFO : train_misclass            0.20000\n",
      "2018-08-08 11:51:13,194 INFO : valid_misclass            0.33333\n",
      "2018-08-08 11:51:13,197 INFO : runtime                   10.04526\n",
      "2018-08-08 11:51:13,201 INFO : \n",
      "2018-08-08 11:51:19,142 INFO : Time only for training updates: 5.93s\n",
      "2018-08-08 11:51:23,173 INFO : Epoch 9\n",
      "2018-08-08 11:51:23,195 INFO : train_loss                0.33328\n",
      "2018-08-08 11:51:23,197 INFO : valid_loss                0.65720\n",
      "2018-08-08 11:51:23,198 INFO : train_misclass            0.17500\n",
      "2018-08-08 11:51:23,199 INFO : valid_misclass            0.30000\n",
      "2018-08-08 11:51:23,200 INFO : runtime                   9.98219\n",
      "2018-08-08 11:51:23,201 INFO : \n",
      "2018-08-08 11:51:29,120 INFO : Time only for training updates: 5.89s\n",
      "2018-08-08 11:51:32,990 INFO : Epoch 10\n",
      "2018-08-08 11:51:32,995 INFO : train_loss                0.27232\n",
      "2018-08-08 11:51:32,999 INFO : valid_loss                0.62543\n",
      "2018-08-08 11:51:33,004 INFO : train_misclass            0.15000\n",
      "2018-08-08 11:51:33,008 INFO : valid_misclass            0.30000\n",
      "2018-08-08 11:51:33,012 INFO : runtime                   9.97608\n",
      "2018-08-08 11:51:33,015 INFO : \n",
      "2018-08-08 11:51:38,980 INFO : Time only for training updates: 5.95s\n",
      "2018-08-08 11:51:42,894 INFO : Epoch 11\n",
      "2018-08-08 11:51:42,899 INFO : train_loss                0.22401\n",
      "2018-08-08 11:51:42,903 INFO : valid_loss                0.59923\n",
      "2018-08-08 11:51:42,907 INFO : train_misclass            0.07500\n",
      "2018-08-08 11:51:42,911 INFO : valid_misclass            0.30000\n",
      "2018-08-08 11:51:42,914 INFO : runtime                   9.86210\n",
      "2018-08-08 11:51:42,918 INFO : \n",
      "2018-08-08 11:51:48,798 INFO : Time only for training updates: 5.87s\n",
      "2018-08-08 11:51:52,831 INFO : Epoch 12\n",
      "2018-08-08 11:51:52,835 INFO : train_loss                0.18465\n",
      "2018-08-08 11:51:52,839 INFO : valid_loss                0.56904\n",
      "2018-08-08 11:51:52,843 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:51:52,847 INFO : valid_misclass            0.26667\n",
      "2018-08-08 11:51:52,850 INFO : runtime                   9.81899\n",
      "2018-08-08 11:51:52,854 INFO : \n",
      "2018-08-08 11:51:58,798 INFO : Time only for training updates: 5.93s\n",
      "2018-08-08 11:52:02,766 INFO : Epoch 13\n",
      "2018-08-08 11:52:02,771 INFO : train_loss                0.15289\n",
      "2018-08-08 11:52:02,775 INFO : valid_loss                0.53387\n",
      "2018-08-08 11:52:02,780 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:02,784 INFO : valid_misclass            0.23333\n",
      "2018-08-08 11:52:02,788 INFO : runtime                   9.99905\n",
      "2018-08-08 11:52:02,793 INFO : \n",
      "2018-08-08 11:52:08,751 INFO : Time only for training updates: 5.94s\n",
      "2018-08-08 11:52:12,614 INFO : Epoch 14\n",
      "2018-08-08 11:52:12,620 INFO : train_loss                0.12854\n",
      "2018-08-08 11:52:12,627 INFO : valid_loss                0.49827\n",
      "2018-08-08 11:52:12,633 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:12,637 INFO : valid_misclass            0.20000\n",
      "2018-08-08 11:52:12,642 INFO : runtime                   9.95297\n",
      "2018-08-08 11:52:12,644 INFO : \n",
      "2018-08-08 11:52:18,474 INFO : Time only for training updates: 5.82s\n",
      "2018-08-08 11:52:21,438 INFO : Epoch 15\n",
      "2018-08-08 11:52:21,439 INFO : train_loss                0.11070\n",
      "2018-08-08 11:52:21,440 INFO : valid_loss                0.46770\n",
      "2018-08-08 11:52:21,441 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:21,443 INFO : valid_misclass            0.16667\n",
      "2018-08-08 11:52:21,444 INFO : runtime                   9.72386\n",
      "2018-08-08 11:52:21,445 INFO : \n",
      "2018-08-08 11:52:24,518 INFO : Time only for training updates: 3.06s\n",
      "2018-08-08 11:52:26,554 INFO : Epoch 16\n",
      "2018-08-08 11:52:26,556 INFO : train_loss                0.09734\n",
      "2018-08-08 11:52:26,557 INFO : valid_loss                0.44107\n",
      "2018-08-08 11:52:26,558 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:26,559 INFO : valid_misclass            0.16667\n",
      "2018-08-08 11:52:26,560 INFO : runtime                   6.04018\n",
      "2018-08-08 11:52:26,561 INFO : \n",
      "2018-08-08 11:52:29,640 INFO : Time only for training updates: 3.07s\n",
      "2018-08-08 11:52:31,673 INFO : Epoch 17\n",
      "2018-08-08 11:52:31,674 INFO : train_loss                0.08650\n",
      "2018-08-08 11:52:31,678 INFO : valid_loss                0.41685\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2018-08-08 11:52:31,680 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:31,681 INFO : valid_misclass            0.16667\n",
      "2018-08-08 11:52:31,682 INFO : runtime                   5.12189\n",
      "2018-08-08 11:52:31,683 INFO : \n",
      "2018-08-08 11:52:34,754 INFO : Time only for training updates: 3.06s\n",
      "2018-08-08 11:52:36,786 INFO : Epoch 18\n",
      "2018-08-08 11:52:36,787 INFO : train_loss                0.07784\n",
      "2018-08-08 11:52:36,791 INFO : valid_loss                0.39502\n",
      "2018-08-08 11:52:36,792 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:36,793 INFO : valid_misclass            0.13333\n",
      "2018-08-08 11:52:36,794 INFO : runtime                   5.11526\n",
      "2018-08-08 11:52:36,796 INFO : \n",
      "2018-08-08 11:52:39,876 INFO : Time only for training updates: 3.07s\n",
      "2018-08-08 11:52:41,910 INFO : Epoch 19\n",
      "2018-08-08 11:52:41,911 INFO : train_loss                0.07096\n",
      "2018-08-08 11:52:41,915 INFO : valid_loss                0.37620\n",
      "2018-08-08 11:52:41,916 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:41,920 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:52:41,921 INFO : runtime                   5.12150\n",
      "2018-08-08 11:52:41,923 INFO : \n",
      "2018-08-08 11:52:45,000 INFO : Time only for training updates: 3.07s\n",
      "2018-08-08 11:52:47,032 INFO : Epoch 20\n",
      "2018-08-08 11:52:47,033 INFO : train_loss                0.06583\n",
      "2018-08-08 11:52:47,037 INFO : valid_loss                0.36116\n",
      "2018-08-08 11:52:47,038 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:47,040 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:52:47,042 INFO : runtime                   5.12370\n",
      "2018-08-08 11:52:47,045 INFO : \n",
      "2018-08-08 11:52:50,116 INFO : Time only for training updates: 3.06s\n",
      "2018-08-08 11:52:52,148 INFO : Epoch 21\n",
      "2018-08-08 11:52:52,149 INFO : train_loss                0.06186\n",
      "2018-08-08 11:52:52,151 INFO : valid_loss                0.34902\n",
      "2018-08-08 11:52:52,152 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:52,153 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:52:52,154 INFO : runtime                   5.11570\n",
      "2018-08-08 11:52:52,156 INFO : \n",
      "2018-08-08 11:52:55,236 INFO : Time only for training updates: 3.07s\n",
      "2018-08-08 11:52:57,268 INFO : Epoch 22\n",
      "2018-08-08 11:52:57,269 INFO : train_loss                0.05873\n",
      "2018-08-08 11:52:57,273 INFO : valid_loss                0.33921\n",
      "2018-08-08 11:52:57,274 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:52:57,278 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:52:57,279 INFO : runtime                   5.12058\n",
      "2018-08-08 11:52:57,282 INFO : \n",
      "2018-08-08 11:53:00,444 INFO : Time only for training updates: 3.15s\n",
      "2018-08-08 11:53:02,601 INFO : Epoch 23\n",
      "2018-08-08 11:53:02,602 INFO : train_loss                0.05626\n",
      "2018-08-08 11:53:02,606 INFO : valid_loss                0.33130\n",
      "2018-08-08 11:53:02,607 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:53:02,607 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:53:02,610 INFO : runtime                   5.20754\n",
      "2018-08-08 11:53:02,612 INFO : \n",
      "2018-08-08 11:53:05,691 INFO : Time only for training updates: 3.07s\n",
      "2018-08-08 11:53:07,723 INFO : Epoch 24\n",
      "2018-08-08 11:53:07,724 INFO : train_loss                0.05439\n",
      "2018-08-08 11:53:07,727 INFO : valid_loss                0.32496\n",
      "2018-08-08 11:53:07,728 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:53:07,729 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:53:07,730 INFO : runtime                   5.24707\n",
      "2018-08-08 11:53:07,731 INFO : \n",
      "2018-08-08 11:53:10,805 INFO : Time only for training updates: 3.06s\n",
      "2018-08-08 11:53:12,832 INFO : Epoch 25\n",
      "2018-08-08 11:53:12,833 INFO : train_loss                0.05298\n",
      "2018-08-08 11:53:12,835 INFO : valid_loss                0.31987\n",
      "2018-08-08 11:53:12,836 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:53:12,837 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:53:12,838 INFO : runtime                   5.11422\n",
      "2018-08-08 11:53:12,839 INFO : \n",
      "2018-08-08 11:53:15,915 INFO : Time only for training updates: 3.06s\n",
      "2018-08-08 11:53:17,942 INFO : Epoch 26\n",
      "2018-08-08 11:53:17,943 INFO : train_loss                0.05193\n",
      "2018-08-08 11:53:17,947 INFO : valid_loss                0.31576\n",
      "2018-08-08 11:53:17,948 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:53:17,949 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:53:17,950 INFO : runtime                   5.10983\n",
      "2018-08-08 11:53:17,951 INFO : \n",
      "2018-08-08 11:53:21,025 INFO : Time only for training updates: 3.06s\n",
      "2018-08-08 11:53:23,051 INFO : Epoch 27\n",
      "2018-08-08 11:53:23,052 INFO : train_loss                0.05121\n",
      "2018-08-08 11:53:23,055 INFO : valid_loss                0.31244\n",
      "2018-08-08 11:53:23,056 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:53:23,057 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:53:23,058 INFO : runtime                   5.10984\n",
      "2018-08-08 11:53:23,060 INFO : \n",
      "2018-08-08 11:53:26,138 INFO : Time only for training updates: 3.07s\n",
      "2018-08-08 11:53:28,171 INFO : Epoch 28\n",
      "2018-08-08 11:53:28,172 INFO : train_loss                0.05073\n",
      "2018-08-08 11:53:28,175 INFO : valid_loss                0.30970\n",
      "2018-08-08 11:53:28,176 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:53:28,177 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:53:28,179 INFO : runtime                   5.11324\n",
      "2018-08-08 11:53:28,180 INFO : \n",
      "2018-08-08 11:53:31,252 INFO : Time only for training updates: 3.06s\n",
      "2018-08-08 11:53:33,283 INFO : Epoch 29\n",
      "2018-08-08 11:53:33,284 INFO : train_loss                0.05041\n",
      "2018-08-08 11:53:33,287 INFO : valid_loss                0.30740\n",
      "2018-08-08 11:53:33,288 INFO : train_misclass            0.00000\n",
      "2018-08-08 11:53:33,289 INFO : valid_misclass            0.10000\n",
      "2018-08-08 11:53:33,290 INFO : runtime                   5.11368\n",
      "2018-08-08 11:53:33,292 INFO : \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<braindecode.experiments.experiment.Experiment at 0x7f75b82e8940>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_time_length = 450\n",
    "model.fit(train_set.X, train_set.y, epochs=30, batch_size=64, scheduler='cosine',\n",
    "          input_time_length=input_time_length,\n",
    "          validation_data=(valid_set.X, valid_set.y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>train_misclass</th>\n",
       "      <th>valid_misclass</th>\n",
       "      <th>runtime</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4.168118</td>\n",
       "      <td>3.390682</td>\n",
       "      <td>0.500</td>\n",
       "      <td>0.533333</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3.043880</td>\n",
       "      <td>2.384769</td>\n",
       "      <td>0.525</td>\n",
       "      <td>0.533333</td>\n",
       "      <td>6.594939</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2.066392</td>\n",
       "      <td>1.610444</td>\n",
       "      <td>0.475</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>9.917859</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.419359</td>\n",
       "      <td>1.187893</td>\n",
       "      <td>0.475</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>10.028832</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.007514</td>\n",
       "      <td>0.955341</td>\n",
       "      <td>0.400</td>\n",
       "      <td>0.433333</td>\n",
       "      <td>9.881147</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0.763041</td>\n",
       "      <td>0.833030</td>\n",
       "      <td>0.325</td>\n",
       "      <td>0.433333</td>\n",
       "      <td>9.979284</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.619604</td>\n",
       "      <td>0.783671</td>\n",
       "      <td>0.300</td>\n",
       "      <td>0.433333</td>\n",
       "      <td>9.938087</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0.515436</td>\n",
       "      <td>0.757223</td>\n",
       "      <td>0.250</td>\n",
       "      <td>0.366667</td>\n",
       "      <td>9.971540</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>0.416595</td>\n",
       "      <td>0.707971</td>\n",
       "      <td>0.200</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>10.045257</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>0.333283</td>\n",
       "      <td>0.657200</td>\n",
       "      <td>0.175</td>\n",
       "      <td>0.300000</td>\n",
       "      <td>9.982185</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>0.272317</td>\n",
       "      <td>0.625433</td>\n",
       "      <td>0.150</td>\n",
       "      <td>0.300000</td>\n",
       "      <td>9.976075</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>0.224010</td>\n",
       "      <td>0.599227</td>\n",
       "      <td>0.075</td>\n",
       "      <td>0.300000</td>\n",
       "      <td>9.862099</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>0.184646</td>\n",
       "      <td>0.569042</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>9.818995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>0.152886</td>\n",
       "      <td>0.533874</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.233333</td>\n",
       "      <td>9.999050</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>0.128538</td>\n",
       "      <td>0.498271</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>9.952966</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>0.110699</td>\n",
       "      <td>0.467696</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>9.723856</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>0.097337</td>\n",
       "      <td>0.441067</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>6.040177</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>0.086503</td>\n",
       "      <td>0.416848</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>5.121891</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>0.077835</td>\n",
       "      <td>0.395015</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.133333</td>\n",
       "      <td>5.115263</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>0.070961</td>\n",
       "      <td>0.376196</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.121497</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>0.065832</td>\n",
       "      <td>0.361156</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.123702</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0.061856</td>\n",
       "      <td>0.349016</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.115699</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>0.058733</td>\n",
       "      <td>0.339215</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.120580</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>0.056258</td>\n",
       "      <td>0.331300</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.207545</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>0.054392</td>\n",
       "      <td>0.324961</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.247069</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>0.052977</td>\n",
       "      <td>0.319866</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.114219</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>0.051933</td>\n",
       "      <td>0.315760</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.109835</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>0.051210</td>\n",
       "      <td>0.312438</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.109844</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>0.050732</td>\n",
       "      <td>0.309702</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.113238</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>0.050408</td>\n",
       "      <td>0.307397</td>\n",
       "      <td>0.000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>5.113682</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    train_loss  valid_loss  train_misclass  valid_misclass    runtime\n",
       "0     4.168118    3.390682           0.500        0.533333   0.000000\n",
       "1     3.043880    2.384769           0.525        0.533333   6.594939\n",
       "2     2.066392    1.610444           0.475        0.500000   9.917859\n",
       "3     1.419359    1.187893           0.475        0.500000  10.028832\n",
       "4     1.007514    0.955341           0.400        0.433333   9.881147\n",
       "5     0.763041    0.833030           0.325        0.433333   9.979284\n",
       "6     0.619604    0.783671           0.300        0.433333   9.938087\n",
       "7     0.515436    0.757223           0.250        0.366667   9.971540\n",
       "8     0.416595    0.707971           0.200        0.333333  10.045257\n",
       "9     0.333283    0.657200           0.175        0.300000   9.982185\n",
       "10    0.272317    0.625433           0.150        0.300000   9.976075\n",
       "11    0.224010    0.599227           0.075        0.300000   9.862099\n",
       "12    0.184646    0.569042           0.000        0.266667   9.818995\n",
       "13    0.152886    0.533874           0.000        0.233333   9.999050\n",
       "14    0.128538    0.498271           0.000        0.200000   9.952966\n",
       "15    0.110699    0.467696           0.000        0.166667   9.723856\n",
       "16    0.097337    0.441067           0.000        0.166667   6.040177\n",
       "17    0.086503    0.416848           0.000        0.166667   5.121891\n",
       "18    0.077835    0.395015           0.000        0.133333   5.115263\n",
       "19    0.070961    0.376196           0.000        0.100000   5.121497\n",
       "20    0.065832    0.361156           0.000        0.100000   5.123702\n",
       "21    0.061856    0.349016           0.000        0.100000   5.115699\n",
       "22    0.058733    0.339215           0.000        0.100000   5.120580\n",
       "23    0.056258    0.331300           0.000        0.100000   5.207545\n",
       "24    0.054392    0.324961           0.000        0.100000   5.247069\n",
       "25    0.052977    0.319866           0.000        0.100000   5.114219\n",
       "26    0.051933    0.315760           0.000        0.100000   5.109835\n",
       "27    0.051210    0.312438           0.000        0.100000   5.109844\n",
       "28    0.050732    0.309702           0.000        0.100000   5.113238\n",
       "29    0.050408    0.307397           0.000        0.100000   5.113682"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.epochs_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Eventually, we arrive at 90% validation accuracy (a `valid_misclass` of 0.1), so 27 of the 30 validation trials are correctly predicted."
   ]
  },
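  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic behind that sentence is just `accuracy = 1 - misclass`, scaled by the number of validation trials. Here is a minimal sketch. It assumes the validation set holds the 30 trials implied by the `valid_misclass` values in the table above, which move in steps of 1/30. The helper `n_correct_trials` is hypothetical and not part of braindecode:\n",
    "\n",
    "```python\n",
    "def n_correct_trials(misclass, n_trials):\n",
    "    # accuracy is the complement of the misclassification rate\n",
    "    accuracy = 1.0 - misclass\n",
    "    # round to absorb floating-point error, e.g. 1 - 0.9 == 0.09999999999999998\n",
    "    return accuracy, round(accuracy * n_trials)\n",
    "\n",
    "# final valid_misclass from model.epochs_df; n_trials=30 is the assumed validation set size\n",
    "accuracy, n_correct = n_correct_trials(0.1, 30)\n",
    "print(accuracy, n_correct)  # 0.9 27\n",
    "```"
   ]
  },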
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': 0.4325282573699951,\n",
       " 'misclass': 0.09999999999999998,\n",
       " 'runtime': 0.0004916191101074219}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_set = SignalAndTarget(X[70:], y=y[70:])\n",
    "\n",
    "model.evaluate(test_set.X, test_set.y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(test_set.X)"
   ]
  },
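  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`misclass` is the fraction of test trials whose predicted class differs from the label. This means the test accuracy can also be recomputed from the per-trial labels that `model.predict` returns. The sketch below assumes those labels are the same ones `model.evaluate` scores. The helper name `accuracy_from_predictions` and the toy arrays are hypothetical. With the real data, you would pass `model.predict(test_set.X)` and `test_set.y`, and the result should then match `1 - misclass`:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def accuracy_from_predictions(predictions, targets):\n",
    "    # fraction of trials where the predicted class equals the label\n",
    "    predictions = np.asarray(predictions)\n",
    "    targets = np.asarray(targets)\n",
    "    return float(np.mean(predictions == targets))\n",
    "\n",
    "# hypothetical toy example: 4 of 5 trials are predicted correctly\n",
    "acc = accuracy_from_predictions([1, 0, 1, 1, 0], [1, 0, 0, 1, 0])\n",
    "print(acc, 1 - acc)  # accuracy and the corresponding misclassification rate\n",
    "```"
   ]
  },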
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset references\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This dataset was created and contributed to PhysioNet by the developers of the [BCI2000](http://www.schalklab.org/research/bci2000) instrumentation system, which they used to make these recordings. The system is described in:\n",
    "\n",
    "    Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.\n",
    "\n",
    "[PhysioBank](https://physionet.org/physiobank/) is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community, and is further described in:\n",
    "\n",
    "    Goldberger, A.L., Amaral, L.A.N., Glass, L., Hausdorff, J.M., Ivanov, P.Ch., Mark, R.G., Mietus, J.E., Moody, G.B., Peng, C.-K., Stanley, H.E. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
