{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read and Decode BBCI Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
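    "This tutorial shows how to read BBCI data into an MNE raw object, preprocess it, cut it into trials and decode it with a ShallowFBCSPNet trained on crops."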
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up logging to see outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n",
    "                     level=logging.DEBUG, stream=sys.stdout)\n",
    "log = logging.getLogger()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and preprocess data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, set the filename and the names of the sensors you want to load. If you set\n",
    "\n",
    "```python\n",
    "load_sensor_names=None\n",
    "```\n",
    "\n",
    "or omit the parameter from the function call, all sensors will be loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating RawArray with float64 data, n_channels=3, n_times=3451320\n",
      "    Range : 0 ... 3451319 =      0.000 ...  6902.638 secs\n",
      "Ready.\n"
     ]
    }
   ],
   "source": [
    "from braindecode.datasets.bbci import BBCIDataset\n",
    "train_filename = '/home/schirrmr/data/BBCI-without-last-runs/BhNoMoSc1S001R01_ds10_1-12.BBCI.mat'\n",
    "cnt = BBCIDataset(train_filename, load_sensor_names=['C3', 'CPz', 'C4']).load()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing on continuous data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First remove the stimulus channel, then apply any preprocessing you like. Braindecode provides only a few preprocessing functions of its own, such as `resample_cnt`. However, you can apply any function to the channels x time data matrix of the MNE raw object (`cnt` in the code) by calling `mne_apply` with two arguments:\n",
    "\n",
    "1. your function (2d array -> 2d array), which transforms the channels x time data array\n",
    "2. the MNE raw object itself"
   ]
  },
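  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A function passed to `mne_apply` only has to take a channels x time array and return a channels x time array again. As a minimal sketch, the hypothetical helper `remove_channel_means` below (not part of Braindecode) subtracts each channel's mean over time; the commented-out last line shows how it could be applied to `cnt` once the stimulus channel has been dropped.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def remove_channel_means(data):\n",
    "    # data: channels x time, as passed in by mne_apply\n",
    "    # subtract each channel's mean over time; the shape stays the same\n",
    "    return data - data.mean(axis=1, keepdims=True)\n",
    "\n",
    "demo = np.array([[1.0, 2.0, 3.0], [10.0, 10.0, 16.0]])\n",
    "centered = remove_channel_means(demo)\n",
    "# cnt = mne_apply(remove_channel_means, cnt)\n",
    "```"
   ]
  },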
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2017-11-03 17:56:32,890 WARNING : This is not causal, uses future data....\n",
      "2017-11-03 17:56:32,891 INFO : Resampling from 500.000000 to 250.000000 Hz.\n",
      "Creating RawArray with float64 data, n_channels=3, n_times=1725660\n",
      "    Range : 0 ... 1725659 =      0.000 ...  6902.636 secs\n",
      "Ready.\n"
     ]
    }
   ],
   "source": [
    "from braindecode.mne_ext.signalproc import resample_cnt, mne_apply\n",
    "from braindecode.datautil.signalproc import exponential_running_standardize\n",
    "# Remove stimulus channel\n",
    "cnt = cnt.drop_channels(['STI 014'])\n",
    "cnt = resample_cnt(cnt, 250)\n",
    "# mne_apply applies the function to the raw data (a 2d numpy array).\n",
    "# The data has to be transposed back and forth, since\n",
    "# exponential_running_standardize expects time x channels order,\n",
    "# while the MNE raw object stores channels x time.\n",
    "cnt = mne_apply(lambda a: exponential_running_standardize(\n",
    "    a.T, init_block_size=1000, factor_new=0.001, eps=1e-4).T,\n",
    "    cnt)"
   ]
  },
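  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`exponential_running_standardize` standardizes each channel using exponentially weighted running estimates of its mean and variance, with `factor_new` as the weight of each new sample and `eps` as a lower bound on the standard deviation. The NumPy sketch below illustrates that idea; `running_standardize_sketch` is a hypothetical name, not Braindecode's code, and its handling of the first `init_block_size` samples (standardized with the block's own statistics, which also initialize the running estimates) is an assumption that may differ from the library.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def running_standardize_sketch(data, factor_new=0.001, init_block_size=1000, eps=1e-4):\n",
    "    # data: time x channels, the same order exponential_running_standardize expects\n",
    "    data = np.asarray(data, dtype=np.float64)\n",
    "    block = data[:init_block_size]\n",
    "    # running estimates start from the statistics of the first block\n",
    "    mean = block.mean(axis=0)\n",
    "    var = block.var(axis=0)\n",
    "    out = np.empty_like(data)\n",
    "    # the first block is standardized with its own mean and standard deviation\n",
    "    out[:len(block)] = (block - mean) / np.maximum(np.sqrt(var), eps)\n",
    "    for t in range(len(block), len(data)):\n",
    "        # exponentially weighted updates: each new sample gets weight factor_new\n",
    "        mean = factor_new * data[t] + (1 - factor_new) * mean\n",
    "        var = factor_new * (data[t] - mean) ** 2 + (1 - factor_new) * var\n",
    "        # after the first block, only past and current samples are used\n",
    "        out[t] = (data[t] - mean) / np.maximum(np.sqrt(var), eps)\n",
    "    return out\n",
    "\n",
    "signal = np.random.RandomState(0).randn(5000, 3) * 5 + 10\n",
    "standardized = running_standardize_sketch(signal)\n",
    "```\n",
    "\n",
    "A small `factor_new` such as 0.001 makes the estimates adapt slowly, over roughly the last thousand samples (4 s at 250 Hz)."
   ]
  },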
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transform to an epoched dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
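    "Braindecode supplies the `create_signal_target_from_raw_mne` function, which cuts the MNE raw object into trials around the class markers and returns a `SignalAndTarget` object for use in Braindecode.\n",
    "`name_to_code` should be an `OrderedDict` that maps each class name to either a single marker code or a list of marker codes for that class. `segment_ival_ms` gives the trial window in milliseconds relative to each marker, here from 500 ms before to 4000 ms after it."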
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2017-11-03 17:56:34,795 INFO : Trial per class:\n",
      "Counter({'Feet': 225, 'Right': 224, 'Rest': 224, 'Left': 224})\n"
     ]
    }
   ],
   "source": [
    "from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne\n",
    "from collections import OrderedDict\n",
    "# Values can also be lists, in case a class has multiple marker codes\n",
    "name_to_code = OrderedDict([('Right', 1), ('Left', 2), ('Rest', 3), ('Feet', 4)])\n",
    "# Trial window in ms relative to each marker\n",
    "segment_ival_ms = [-500, 4000]\n",
    "\n",
    "train_set = create_signal_target_from_raw_mne(cnt, name_to_code, segment_ival_ms)"
   ]
  },
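  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, `segment_ival_ms` is turned into sample offsets around each marker at the current sampling rate: after resampling to 250 Hz, `[-500, 4000]` ms corresponds to offsets of -125 and 1000 samples. The hypothetical `cut_trials` helper below is a minimal NumPy sketch of this windowing, not Braindecode's implementation; it treats the window end as exclusive, giving 1125 samples per trial, while the library's exact boundary handling may differ.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def cut_trials(data, marker_samples, ival_ms, fs):\n",
    "    # data: channels x time; marker_samples: sample index of each trial marker\n",
    "    # convert the window from milliseconds to sample offsets\n",
    "    start = int(round(ival_ms[0] * fs / 1000.0))\n",
    "    stop = int(round(ival_ms[1] * fs / 1000.0))\n",
    "    # one window per marker, end exclusive -> trials x channels x time\n",
    "    return np.stack([data[:, m + start:m + stop] for m in marker_samples])\n",
    "\n",
    "X = cut_trials(np.zeros((3, 10000)), [1000, 3000], [-500, 4000], fs=250)\n",
    "```\n",
    "\n",
    "Here `X` has shape (2, 3, 1125): two trials, three channels and 1125 samples each."
   ]
  },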
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Same for the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating RawArray with float64 data, n_channels=3, n_times=617090\n",
      "    Range : 0 ... 617089 =      0.000 ...  1234.178 secs\n",
      "Ready.\n",
      "2017-11-03 17:56:35,707 WARNING : This is not causal, uses future data....\n",
      "2017-11-03 17:56:35,708 INFO : Resampling from 500.000000 to 250.000000 Hz.\n",
      "Creating RawArray with float64 data, n_channels=3, n_times=308545\n",
      "    Range : 0 ... 308544 =      0.000 ...  1234.176 secs\n",
      "Ready.\n",
      "2017-11-03 17:56:36,026 INFO : Trial per class:\n",
      "Counter({'Feet': 40, 'Left': 40, 'Rest': 40, 'Right': 40})\n"
     ]
    }
   ],
   "source": [
    "test_filename = '/home/schirrmr/data/BBCI-only-last-runs/BhNoMoSc1S001R13_ds10_1-2BBCI.mat'\n",
    "cnt = BBCIDataset(test_filename, load_sensor_names=['C3', 'CPz', 'C4']).load()\n",
    "# Remove stimulus channel\n",
    "cnt = cnt.drop_channels(['STI 014'])\n",
    "cnt = resample_cnt(cnt, 250)\n",
    "cnt = mne_apply(lambda a: exponential_running_standardize(\n",
    "    a.T, init_block_size=1000, factor_new=0.001, eps=1e-4).T,\n",
    "    cnt)\n",
    "test_set = create_signal_target_from_raw_mne(cnt, name_to_code, segment_ival_ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "If your data contains start and stop markers, pass a `name_to_stop_codes` dictionary (structured like `name_to_code` for the start codes) as an additional argument after `segment_ival_ms` to `create_signal_target_from_raw_mne`. See the [Read and Decode BBCI Data with Start-Stop-Markers Tutorial](BBCI_Data_Start_Stop.html).\n",
    "\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split off the last 20% of the training trials as a validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datautil.splitters import split_into_two_sets\n",
    "\n",
    "train_set, valid_set = split_into_two_sets(train_set, first_set_fraction=0.8)\n"
   ]
  },
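  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`split_into_two_sets` does not shuffle: the first trials form the training set and the last ones the validation set. A minimal sketch of such an order-preserving split, using a hypothetical `split_first_fraction` helper on dummy arrays, is shown below. For the 897 training trials counted above (225 + 224 + 224 + 224) it yields 718 training and 179 validation trials, consistent with the valid_misclass values logged during training, which all correspond to whole numbers of errors out of 179 trials.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def split_first_fraction(X, y, first_set_fraction):\n",
    "    # no shuffling: the first trials form the first set, the rest the second\n",
    "    n_first = int(round(len(X) * first_set_fraction))\n",
    "    return (X[:n_first], y[:n_first]), (X[n_first:], y[n_first:])\n",
    "\n",
    "X = np.zeros((897, 3, 10), dtype=np.float32)\n",
    "y = np.zeros(897, dtype=np.int64)\n",
    "(train_X, train_y), (valid_X, valid_y) = split_first_fraction(X, y, 0.8)\n",
    "```"
   ]
  },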
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
    "from torch import nn\n",
    "from braindecode.torch_ext.util import set_random_seeds\n",
    "from braindecode.models.util import to_dense_prediction_model\n",
    "\n",
    "# Set to True to train on the GPU.\n",
    "# torch.cuda.is_available() tells you whether CUDA is available on your machine.\n",
    "cuda = True\n",
    "set_random_seeds(seed=20170629, cuda=cuda)\n",
    "\n",
    "# Number of samples per input window; together with the receptive field,\n",
    "# this determines how many crops are predicted in parallel per forward pass\n",
    "input_time_length = 800\n",
    "in_chans = 3\n",
    "n_classes = 4\n",
    "# final_conv_length determines the size of the receptive field of the ConvNet\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,\n",
    "                        final_conv_length=30).create_network()\n",
    "# Turn strides into dilations so that the model outputs dense crop predictions\n",
    "to_dense_prediction_model(model)\n",
    "\n",
    "if cuda:\n",
    "    model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the optimizer and iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "267 predictions per input/trial\n"
     ]
    }
   ],
   "source": [
    "from torch import optim\n",
    "import numpy as np\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "from braindecode.torch_ext.util import np_to_var\n",
    "# Determine the output size by feeding a dummy batch\n",
    "# (2 inputs x 3 channels x input_time_length samples x 1)\n",
    "test_input = np_to_var(np.ones((2, 3, input_time_length, 1), dtype=np.float32))\n",
    "if cuda:\n",
    "    test_input = test_input.cuda()\n",
    "out = model(test_input)\n",
    "n_preds_per_input = out.cpu().data.numpy().shape[2]\n",
    "print(\"{:d} predictions per input/trial\".format(n_preds_per_input))\n",
    "\n",
    "from braindecode.datautil.iterators import CropsFromTrialsIterator\n",
    "iterator = CropsFromTrialsIterator(batch_size=32, input_time_length=input_time_length,\n",
    "                                   n_preds_per_input=n_preds_per_input)"
   ]
  },
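  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed 267 follows from the receptive field of the network. Assuming ShallowFBCSPNet's default `filter_time_length=25`, `pool_time_length=75` and `pool_time_stride=15` (only `final_conv_length=30` is set explicitly above), and that `to_dense_prediction_model` turns the pooling stride into a dilation of the final convolution, the receptive field is 25 + (75 - 1) + (30 - 1) * 15 = 534 samples (about 2.1 s at 250 Hz). An 800-sample input then yields 800 - 534 + 1 = 267 predictions. The hypothetical helpers below just reproduce that arithmetic.\n",
    "\n",
    "```python\n",
    "def receptive_field(filter_time_length=25, pool_time_length=75,\n",
    "                    pool_time_stride=15, final_conv_length=30):\n",
    "    # temporal conv window, then pooling window (stride set to 1 in the dense model),\n",
    "    # then the final conv, whose taps are dilated by pool_time_stride\n",
    "    return (filter_time_length + (pool_time_length - 1)\n",
    "            + (final_conv_length - 1) * pool_time_stride)\n",
    "\n",
    "def count_preds(input_time_length, rf):\n",
    "    # one prediction per position at which the receptive field fits into the input\n",
    "    return input_time_length - rf + 1\n",
    "\n",
    "rf = receptive_field()\n",
    "n_preds = count_preds(800, rf)\n",
    "```\n",
    "\n",
    "A larger `input_time_length` therefore yields more crop predictions per forward pass while the receptive field stays the same."
   ]
  },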
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up monitors, loss function and stop criterion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.experiments.experiment import Experiment\n",
    "from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, CroppedTrialMisclassMonitor, MisclassMonitor\n",
    "from braindecode.experiments.stopcriteria import MaxEpochs\n",
    "import torch.nn.functional as F\n",
    "import torch as th\n",
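    "\n",
    "# Average the crop predictions over time (dim 2) to get one prediction per input window,\n",
    "# then apply the negative log-likelihood loss\n",
    "loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2).squeeze(), targets)\n",
    "\n",
    "model_constraint = None\n",
    "# sample_misclass: misclassification rate of the individual crop predictions\n",
    "# misclass: trial-wise misclassification rate from CroppedTrialMisclassMonitor\n",
    "monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),\n",
    "            CroppedTrialMisclassMonitor(input_time_length), RuntimeMonitor()]\n",
    "stop_criterion = MaxEpochs(20)\n",
    "# Remember the model with the lowest valid_misclass. With run_after_early_stop=True,\n",
    "# training continues on train + valid after the first stop, starting from that model.\n",
    "exp = Experiment(model, train_set, valid_set, test_set, iterator, loss_function, optimizer, model_constraint,\n",
    "                 monitors, stop_criterion, remember_best_column='valid_misclass',\n",
    "                 run_after_early_stop=True, batch_modifier=None, cuda=cuda)"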
   ]
  },
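  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss function above averages the per-crop outputs over the time axis (dim 2) and applies the negative log-likelihood loss to the averaged values. The NumPy sketch below shows the same computation, assuming outputs of shape batch x classes x n_preds that already hold log-probabilities, as the use of `F.nll_loss` implies; `cropped_nll_loss` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def cropped_nll_loss(log_probs, targets):\n",
    "    # log_probs: batch x classes x n_preds, log-probabilities per crop\n",
    "    mean_log_probs = log_probs.mean(axis=2)  # average over crops -> batch x classes\n",
    "    # negative log-likelihood of the true class, averaged over the batch\n",
    "    return -mean_log_probs[np.arange(len(targets)), targets].mean()\n",
    "\n",
    "uniform = np.log(np.full((2, 4, 267), 0.25))\n",
    "loss = cropped_nll_loss(uniform, np.array([0, 3]))\n",
    "```\n",
    "\n",
    "For uniform predictions over the four classes, the loss is log(4), about 1.386."
   ]
  },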
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2017-11-03 17:56:39,458 INFO : Run until first stop...\n",
      "2017-11-03 17:56:40,298 INFO : Epoch 0\n",
      "2017-11-03 17:56:40,299 INFO : train_loss                7.89184\n",
      "2017-11-03 17:56:40,300 INFO : valid_loss                7.72731\n",
      "2017-11-03 17:56:40,301 INFO : test_loss                 7.75617\n",
      "2017-11-03 17:56:40,303 INFO : train_sample_misclass     0.75013\n",
      "2017-11-03 17:56:40,304 INFO : valid_sample_misclass     0.74856\n",
      "2017-11-03 17:56:40,305 INFO : test_sample_misclass      0.75115\n",
      "2017-11-03 17:56:40,306 INFO : train_misclass            0.75070\n",
      "2017-11-03 17:56:40,308 INFO : valid_misclass            0.74860\n",
      "2017-11-03 17:56:40,309 INFO : test_misclass             0.75000\n",
      "2017-11-03 17:56:40,310 INFO : runtime                   0.00000\n",
      "2017-11-03 17:56:40,311 INFO : \n",
      "2017-11-03 17:56:40,313 INFO : New best valid_misclass: 0.748603\n",
      "2017-11-03 17:56:40,314 INFO : \n",
      "2017-11-03 17:56:41,475 INFO : Time only for training updates: 0.94s\n",
      "2017-11-03 17:56:42,262 INFO : Epoch 1\n",
      "2017-11-03 17:56:42,263 INFO : train_loss                0.79550\n",
      "2017-11-03 17:56:42,264 INFO : valid_loss                0.79273\n",
      "2017-11-03 17:56:42,265 INFO : test_loss                 0.83673\n",
      "2017-11-03 17:56:42,266 INFO : train_sample_misclass     0.36961\n",
      "2017-11-03 17:56:42,267 INFO : valid_sample_misclass     0.37374\n",
      "2017-11-03 17:56:42,268 INFO : test_sample_misclass      0.42969\n",
      "2017-11-03 17:56:42,268 INFO : train_misclass            0.30223\n",
      "2017-11-03 17:56:42,269 INFO : valid_misclass            0.26257\n",
      "2017-11-03 17:56:42,270 INFO : test_misclass             0.33750\n",
      "2017-11-03 17:56:42,271 INFO : runtime                   2.01650\n",
      "2017-11-03 17:56:42,271 INFO : \n",
      "2017-11-03 17:56:42,274 INFO : New best valid_misclass: 0.262570\n",
      "2017-11-03 17:56:42,274 INFO : \n",
      "2017-11-03 17:56:43,423 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:56:44,209 INFO : Epoch 2\n",
      "2017-11-03 17:56:44,211 INFO : train_loss                0.68180\n",
      "2017-11-03 17:56:44,212 INFO : valid_loss                0.65888\n",
      "2017-11-03 17:56:44,213 INFO : test_loss                 0.74155\n",
      "2017-11-03 17:56:44,214 INFO : train_sample_misclass     0.29325\n",
      "2017-11-03 17:56:44,215 INFO : valid_sample_misclass     0.30236\n",
      "2017-11-03 17:56:44,216 INFO : test_sample_misclass      0.37317\n",
      "2017-11-03 17:56:44,217 INFO : train_misclass            0.22563\n",
      "2017-11-03 17:56:44,218 INFO : valid_misclass            0.21229\n",
      "2017-11-03 17:56:44,219 INFO : test_misclass             0.29375\n",
      "2017-11-03 17:56:44,220 INFO : runtime                   1.94802\n",
      "2017-11-03 17:56:44,221 INFO : \n",
      "2017-11-03 17:56:44,223 INFO : New best valid_misclass: 0.212291\n",
      "2017-11-03 17:56:44,224 INFO : \n",
      "2017-11-03 17:56:45,377 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:56:46,152 INFO : Epoch 3\n",
      "2017-11-03 17:56:46,153 INFO : train_loss                0.65244\n",
      "2017-11-03 17:56:46,154 INFO : valid_loss                0.68508\n",
      "2017-11-03 17:56:46,155 INFO : test_loss                 0.81607\n",
      "2017-11-03 17:56:46,156 INFO : train_sample_misclass     0.27737\n",
      "2017-11-03 17:56:46,157 INFO : valid_sample_misclass     0.32108\n",
      "2017-11-03 17:56:46,157 INFO : test_sample_misclass      0.43695\n",
      "2017-11-03 17:56:46,158 INFO : train_misclass            0.20195\n",
      "2017-11-03 17:56:46,159 INFO : valid_misclass            0.25140\n",
      "2017-11-03 17:56:46,160 INFO : test_misclass             0.36250\n",
      "2017-11-03 17:56:46,160 INFO : runtime                   1.95389\n",
      "2017-11-03 17:56:46,161 INFO : \n",
      "2017-11-03 17:56:47,319 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:56:48,106 INFO : Epoch 4\n",
      "2017-11-03 17:56:48,107 INFO : train_loss                0.59123\n",
      "2017-11-03 17:56:48,108 INFO : valid_loss                0.57887\n",
      "2017-11-03 17:56:48,109 INFO : test_loss                 0.71750\n",
      "2017-11-03 17:56:48,110 INFO : train_sample_misclass     0.25470\n",
      "2017-11-03 17:56:48,110 INFO : valid_sample_misclass     0.25364\n",
      "2017-11-03 17:56:48,111 INFO : test_sample_misclass      0.35189\n",
      "2017-11-03 17:56:48,112 INFO : train_misclass            0.20474\n",
      "2017-11-03 17:56:48,113 INFO : valid_misclass            0.20670\n",
      "2017-11-03 17:56:48,113 INFO : test_misclass             0.30625\n",
      "2017-11-03 17:56:48,114 INFO : runtime                   1.94214\n",
      "2017-11-03 17:56:48,115 INFO : \n",
      "2017-11-03 17:56:48,117 INFO : New best valid_misclass: 0.206704\n",
      "2017-11-03 17:56:48,118 INFO : \n",
      "2017-11-03 17:56:49,274 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:56:50,058 INFO : Epoch 5\n",
      "2017-11-03 17:56:50,060 INFO : train_loss                0.76997\n",
      "2017-11-03 17:56:50,060 INFO : valid_loss                0.72850\n",
      "2017-11-03 17:56:50,061 INFO : test_loss                 0.83681\n",
      "2017-11-03 17:56:50,062 INFO : train_sample_misclass     0.33041\n",
      "2017-11-03 17:56:50,063 INFO : valid_sample_misclass     0.32293\n",
      "2017-11-03 17:56:50,063 INFO : test_sample_misclass      0.39756\n",
      "2017-11-03 17:56:50,064 INFO : train_misclass            0.25487\n",
      "2017-11-03 17:56:50,065 INFO : valid_misclass            0.29050\n",
      "2017-11-03 17:56:50,066 INFO : test_misclass             0.34375\n",
      "2017-11-03 17:56:50,066 INFO : runtime                   1.95470\n",
      "2017-11-03 17:56:50,067 INFO : \n",
      "2017-11-03 17:56:51,213 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:56:51,992 INFO : Epoch 6\n",
      "2017-11-03 17:56:51,994 INFO : train_loss                0.61221\n",
      "2017-11-03 17:56:51,995 INFO : valid_loss                0.62389\n",
      "2017-11-03 17:56:51,996 INFO : test_loss                 0.72417\n",
      "2017-11-03 17:56:51,996 INFO : train_sample_misclass     0.26388\n",
      "2017-11-03 17:56:51,997 INFO : valid_sample_misclass     0.28059\n",
      "2017-11-03 17:56:51,998 INFO : test_sample_misclass      0.34581\n",
      "2017-11-03 17:56:51,999 INFO : train_misclass            0.21031\n",
      "2017-11-03 17:56:51,999 INFO : valid_misclass            0.22346\n",
      "2017-11-03 17:56:52,000 INFO : test_misclass             0.30625\n",
      "2017-11-03 17:56:52,001 INFO : runtime                   1.93918\n",
      "2017-11-03 17:56:52,002 INFO : \n",
      "2017-11-03 17:56:53,152 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:56:53,932 INFO : Epoch 7\n",
      "2017-11-03 17:56:53,933 INFO : train_loss                0.57982\n",
      "2017-11-03 17:56:53,934 INFO : valid_loss                0.60830\n",
      "2017-11-03 17:56:53,936 INFO : test_loss                 0.73615\n",
      "2017-11-03 17:56:53,937 INFO : train_sample_misclass     0.25238\n",
      "2017-11-03 17:56:53,938 INFO : valid_sample_misclass     0.30217\n",
      "2017-11-03 17:56:53,939 INFO : test_sample_misclass      0.36385\n",
      "2017-11-03 17:56:53,941 INFO : train_misclass            0.17827\n",
      "2017-11-03 17:56:53,942 INFO : valid_misclass            0.21788\n",
      "2017-11-03 17:56:53,943 INFO : test_misclass             0.31875\n",
      "2017-11-03 17:56:53,944 INFO : runtime                   1.93903\n",
      "2017-11-03 17:56:53,946 INFO : \n",
      "2017-11-03 17:56:55,088 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:56:55,866 INFO : Epoch 8\n",
      "2017-11-03 17:56:55,867 INFO : train_loss                0.53394\n",
      "2017-11-03 17:56:55,868 INFO : valid_loss                0.54075\n",
      "2017-11-03 17:56:55,868 INFO : test_loss                 0.69350\n",
      "2017-11-03 17:56:55,869 INFO : train_sample_misclass     0.22455\n",
      "2017-11-03 17:56:55,870 INFO : valid_sample_misclass     0.24399\n",
      "2017-11-03 17:56:55,871 INFO : test_sample_misclass      0.35034\n",
      "2017-11-03 17:56:55,872 INFO : train_misclass            0.15460\n",
      "2017-11-03 17:56:55,872 INFO : valid_misclass            0.14525\n",
      "2017-11-03 17:56:55,873 INFO : test_misclass             0.21250\n",
      "2017-11-03 17:56:55,874 INFO : runtime                   1.93581\n",
      "2017-11-03 17:56:55,875 INFO : \n",
      "2017-11-03 17:56:55,877 INFO : New best valid_misclass: 0.145251\n",
      "2017-11-03 17:56:55,878 INFO : \n",
      "2017-11-03 17:56:57,025 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:56:57,808 INFO : Epoch 9\n",
      "2017-11-03 17:56:57,809 INFO : train_loss                0.52253\n",
      "2017-11-03 17:56:57,810 INFO : valid_loss                0.55157\n",
      "2017-11-03 17:56:57,811 INFO : test_loss                 0.66259\n",
      "2017-11-03 17:56:57,812 INFO : train_sample_misclass     0.20662\n",
      "2017-11-03 17:56:57,813 INFO : valid_sample_misclass     0.23934\n",
      "2017-11-03 17:56:57,813 INFO : test_sample_misclass      0.31630\n",
      "2017-11-03 17:56:57,814 INFO : train_misclass            0.12535\n",
      "2017-11-03 17:56:57,815 INFO : valid_misclass            0.15084\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2017-11-03 17:56:57,816 INFO : test_misclass             0.19375\n",
      "2017-11-03 17:56:57,816 INFO : runtime                   1.93780\n",
      "2017-11-03 17:56:57,817 INFO : \n",
      "2017-11-03 17:56:58,966 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:56:59,747 INFO : Epoch 10\n",
      "2017-11-03 17:56:59,749 INFO : train_loss                0.63435\n",
      "2017-11-03 17:56:59,749 INFO : valid_loss                0.54363\n",
      "2017-11-03 17:56:59,750 INFO : test_loss                 0.66703\n",
      "2017-11-03 17:56:59,751 INFO : train_sample_misclass     0.23284\n",
      "2017-11-03 17:56:59,752 INFO : valid_sample_misclass     0.25216\n",
      "2017-11-03 17:56:59,752 INFO : test_sample_misclass      0.31337\n",
      "2017-11-03 17:56:59,753 INFO : train_misclass            0.19081\n",
      "2017-11-03 17:56:59,754 INFO : valid_misclass            0.16760\n",
      "2017-11-03 17:56:59,755 INFO : test_misclass             0.21875\n",
      "2017-11-03 17:56:59,755 INFO : runtime                   1.94030\n",
      "2017-11-03 17:56:59,756 INFO : \n",
      "2017-11-03 17:57:00,903 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:57:01,683 INFO : Epoch 11\n",
      "2017-11-03 17:57:01,684 INFO : train_loss                0.50823\n",
      "2017-11-03 17:57:01,685 INFO : valid_loss                0.57940\n",
      "2017-11-03 17:57:01,686 INFO : test_loss                 0.65863\n",
      "2017-11-03 17:57:01,686 INFO : train_sample_misclass     0.21697\n",
      "2017-11-03 17:57:01,687 INFO : valid_sample_misclass     0.26513\n",
      "2017-11-03 17:57:01,688 INFO : test_sample_misclass      0.33127\n",
      "2017-11-03 17:57:01,689 INFO : train_misclass            0.13510\n",
      "2017-11-03 17:57:01,689 INFO : valid_misclass            0.18994\n",
      "2017-11-03 17:57:01,690 INFO : test_misclass             0.24375\n",
      "2017-11-03 17:57:01,691 INFO : runtime                   1.93741\n",
      "2017-11-03 17:57:01,692 INFO : \n",
      "2017-11-03 17:57:02,839 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:57:03,619 INFO : Epoch 12\n",
      "2017-11-03 17:57:03,621 INFO : train_loss                0.51300\n",
      "2017-11-03 17:57:03,622 INFO : valid_loss                0.54631\n",
      "2017-11-03 17:57:03,623 INFO : test_loss                 0.64827\n",
      "2017-11-03 17:57:03,624 INFO : train_sample_misclass     0.20628\n",
      "2017-11-03 17:57:03,624 INFO : valid_sample_misclass     0.23820\n",
      "2017-11-03 17:57:03,625 INFO : test_sample_misclass      0.32021\n",
      "2017-11-03 17:57:03,626 INFO : train_misclass            0.13649\n",
      "2017-11-03 17:57:03,627 INFO : valid_misclass            0.15642\n",
      "2017-11-03 17:57:03,628 INFO : test_misclass             0.24375\n",
      "2017-11-03 17:57:03,628 INFO : runtime                   1.93566\n",
      "2017-11-03 17:57:03,629 INFO : \n",
      "2017-11-03 17:57:04,777 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:57:05,568 INFO : Epoch 13\n",
      "2017-11-03 17:57:05,569 INFO : train_loss                0.53439\n",
      "2017-11-03 17:57:05,570 INFO : valid_loss                0.57364\n",
      "2017-11-03 17:57:05,571 INFO : test_loss                 0.75146\n",
      "2017-11-03 17:57:05,572 INFO : train_sample_misclass     0.22174\n",
      "2017-11-03 17:57:05,573 INFO : valid_sample_misclass     0.25399\n",
      "2017-11-03 17:57:05,573 INFO : test_sample_misclass      0.36589\n",
      "2017-11-03 17:57:05,574 INFO : train_misclass            0.14903\n",
      "2017-11-03 17:57:05,575 INFO : valid_misclass            0.16201\n",
      "2017-11-03 17:57:05,576 INFO : test_misclass             0.27500\n",
      "2017-11-03 17:57:05,576 INFO : runtime                   1.93826\n",
      "2017-11-03 17:57:05,577 INFO : \n",
      "2017-11-03 17:57:06,724 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:57:07,510 INFO : Epoch 14\n",
      "2017-11-03 17:57:07,511 INFO : train_loss                0.49450\n",
      "2017-11-03 17:57:07,512 INFO : valid_loss                0.50904\n",
      "2017-11-03 17:57:07,513 INFO : test_loss                 0.59711\n",
      "2017-11-03 17:57:07,514 INFO : train_sample_misclass     0.19409\n",
      "2017-11-03 17:57:07,514 INFO : valid_sample_misclass     0.23406\n",
      "2017-11-03 17:57:07,515 INFO : test_sample_misclass      0.29099\n",
      "2017-11-03 17:57:07,516 INFO : train_misclass            0.11421\n",
      "2017-11-03 17:57:07,517 INFO : valid_misclass            0.15084\n",
      "2017-11-03 17:57:07,517 INFO : test_misclass             0.21875\n",
      "2017-11-03 17:57:07,518 INFO : runtime                   1.94703\n",
      "2017-11-03 17:57:07,519 INFO : \n",
      "2017-11-03 17:57:08,667 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:57:09,447 INFO : Epoch 15\n",
      "2017-11-03 17:57:09,448 INFO : train_loss                0.55814\n",
      "2017-11-03 17:57:09,449 INFO : valid_loss                0.52381\n",
      "2017-11-03 17:57:09,450 INFO : test_loss                 0.63100\n",
      "2017-11-03 17:57:09,450 INFO : train_sample_misclass     0.19992\n",
      "2017-11-03 17:57:09,451 INFO : valid_sample_misclass     0.22669\n",
      "2017-11-03 17:57:09,452 INFO : test_sample_misclass      0.30698\n",
      "2017-11-03 17:57:09,453 INFO : train_misclass            0.14624\n",
      "2017-11-03 17:57:09,453 INFO : valid_misclass            0.16201\n",
      "2017-11-03 17:57:09,454 INFO : test_misclass             0.23750\n",
      "2017-11-03 17:57:09,455 INFO : runtime                   1.94305\n",
      "2017-11-03 17:57:09,456 INFO : \n",
      "2017-11-03 17:57:10,604 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:57:11,388 INFO : Epoch 16\n",
      "2017-11-03 17:57:11,389 INFO : train_loss                0.54798\n",
      "2017-11-03 17:57:11,390 INFO : valid_loss                0.60483\n",
      "2017-11-03 17:57:11,391 INFO : test_loss                 0.70932\n",
      "2017-11-03 17:57:11,392 INFO : train_sample_misclass     0.23531\n",
      "2017-11-03 17:57:11,393 INFO : valid_sample_misclass     0.26090\n",
      "2017-11-03 17:57:11,394 INFO : test_sample_misclass      0.35440\n",
      "2017-11-03 17:57:11,396 INFO : train_misclass            0.19359\n",
      "2017-11-03 17:57:11,397 INFO : valid_misclass            0.22905\n",
      "2017-11-03 17:57:11,398 INFO : test_misclass             0.32500\n",
      "2017-11-03 17:57:11,399 INFO : runtime                   1.93673\n",
      "2017-11-03 17:57:11,399 INFO : \n",
      "2017-11-03 17:57:12,550 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:57:13,342 INFO : Epoch 17\n",
      "2017-11-03 17:57:13,343 INFO : train_loss                0.57735\n",
      "2017-11-03 17:57:13,344 INFO : valid_loss                0.60624\n",
      "2017-11-03 17:57:13,344 INFO : test_loss                 0.71293\n",
      "2017-11-03 17:57:13,345 INFO : train_sample_misclass     0.23263\n",
      "2017-11-03 17:57:13,346 INFO : valid_sample_misclass     0.25866\n",
      "2017-11-03 17:57:13,347 INFO : test_sample_misclass      0.35750\n",
      "2017-11-03 17:57:13,347 INFO : train_misclass            0.18524\n",
      "2017-11-03 17:57:13,348 INFO : valid_misclass            0.19553\n",
      "2017-11-03 17:57:13,349 INFO : test_misclass             0.29375\n",
      "2017-11-03 17:57:13,350 INFO : runtime                   1.94563\n",
      "2017-11-03 17:57:13,350 INFO : \n",
      "2017-11-03 17:57:14,500 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:57:15,288 INFO : Epoch 18\n",
      "2017-11-03 17:57:15,289 INFO : train_loss                0.46534\n",
      "2017-11-03 17:57:15,290 INFO : valid_loss                0.52169\n",
      "2017-11-03 17:57:15,291 INFO : test_loss                 0.64741\n",
      "2017-11-03 17:57:15,292 INFO : train_sample_misclass     0.19295\n",
      "2017-11-03 17:57:15,292 INFO : valid_sample_misclass     0.23086\n",
      "2017-11-03 17:57:15,293 INFO : test_sample_misclass      0.33224\n",
      "2017-11-03 17:57:15,294 INFO : train_misclass            0.11978\n",
      "2017-11-03 17:57:15,295 INFO : valid_misclass            0.18436\n",
      "2017-11-03 17:57:15,295 INFO : test_misclass             0.24375\n",
      "2017-11-03 17:57:15,296 INFO : runtime                   1.95093\n",
      "2017-11-03 17:57:15,297 INFO : \n",
      "2017-11-03 17:57:16,445 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:57:17,230 INFO : Epoch 19\n",
      "2017-11-03 17:57:17,232 INFO : train_loss                0.47752\n",
      "2017-11-03 17:57:17,232 INFO : valid_loss                0.55626\n",
      "2017-11-03 17:57:17,233 INFO : test_loss                 0.68759\n",
      "2017-11-03 17:57:17,234 INFO : train_sample_misclass     0.20194\n",
      "2017-11-03 17:57:17,235 INFO : valid_sample_misclass     0.24273\n",
      "2017-11-03 17:57:17,236 INFO : test_sample_misclass      0.35215\n",
      "2017-11-03 17:57:17,236 INFO : train_misclass            0.13928\n",
      "2017-11-03 17:57:17,237 INFO : valid_misclass            0.16201\n",
      "2017-11-03 17:57:17,238 INFO : test_misclass             0.23750\n",
      "2017-11-03 17:57:17,239 INFO : runtime                   1.94504\n",
      "2017-11-03 17:57:17,240 INFO : \n",
      "2017-11-03 17:57:18,387 INFO : Time only for training updates: 0.92s\n",
      "2017-11-03 17:57:19,174 INFO : Epoch 20\n",
      "2017-11-03 17:57:19,175 INFO : train_loss                0.46220\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2017-11-03 17:57:19,176 INFO : valid_loss                0.54947\n",
      "2017-11-03 17:57:19,177 INFO : test_loss                 0.62923\n",
      "2017-11-03 17:57:19,178 INFO : train_sample_misclass     0.18707\n",
      "2017-11-03 17:57:19,179 INFO : valid_sample_misclass     0.24050\n",
      "2017-11-03 17:57:19,179 INFO : test_sample_misclass      0.29137\n",
      "2017-11-03 17:57:19,180 INFO : train_misclass            0.10864\n",
      "2017-11-03 17:57:19,181 INFO : valid_misclass            0.18436\n",
      "2017-11-03 17:57:19,182 INFO : test_misclass             0.23750\n",
      "2017-11-03 17:57:19,182 INFO : runtime                   1.94147\n",
      "2017-11-03 17:57:19,183 INFO : \n",
      "2017-11-03 17:57:19,184 INFO : Setup for second stop...\n",
      "2017-11-03 17:57:19,188 INFO : Train loss to reach 0.53394\n",
      "2017-11-03 17:57:19,189 INFO : Run until second stop...\n",
      "2017-11-03 17:57:20,151 INFO : Epoch 9\n",
      "2017-11-03 17:57:20,153 INFO : train_loss                0.53530\n",
      "2017-11-03 17:57:20,154 INFO : valid_loss                0.54075\n",
      "2017-11-03 17:57:20,154 INFO : test_loss                 0.69350\n",
      "2017-11-03 17:57:20,155 INFO : train_sample_misclass     0.22843\n",
      "2017-11-03 17:57:20,156 INFO : valid_sample_misclass     0.24399\n",
      "2017-11-03 17:57:20,157 INFO : test_sample_misclass      0.35034\n",
      "2017-11-03 17:57:20,157 INFO : train_misclass            0.15273\n",
      "2017-11-03 17:57:20,158 INFO : valid_misclass            0.14525\n",
      "2017-11-03 17:57:20,159 INFO : test_misclass             0.21250\n",
      "2017-11-03 17:57:20,160 INFO : runtime                   0.80763\n",
      "2017-11-03 17:57:20,161 INFO : \n",
      "2017-11-03 17:57:21,603 INFO : Time only for training updates: 1.16s\n",
      "2017-11-03 17:57:22,533 INFO : Epoch 10\n",
      "2017-11-03 17:57:22,534 INFO : train_loss                0.54183\n",
      "2017-11-03 17:57:22,535 INFO : valid_loss                0.50682\n",
      "2017-11-03 17:57:22,536 INFO : test_loss                 0.64068\n",
      "2017-11-03 17:57:22,537 INFO : train_sample_misclass     0.22782\n",
      "2017-11-03 17:57:22,537 INFO : valid_sample_misclass     0.23420\n",
      "2017-11-03 17:57:22,538 INFO : test_sample_misclass      0.31642\n",
      "2017-11-03 17:57:22,539 INFO : train_misclass            0.18283\n",
      "2017-11-03 17:57:22,540 INFO : valid_misclass            0.15642\n",
      "2017-11-03 17:57:22,540 INFO : test_misclass             0.24375\n",
      "2017-11-03 17:57:22,541 INFO : runtime                   2.40818\n",
      "2017-11-03 17:57:22,542 INFO : \n"
     ]
    }
   ],
   "source": [
    "exp.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We arrive at ca. 76% test accuracy (a `test_misclass` of 0.24375 after the last epoch)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To switch from cropped decoding to trialwise decoding, make the following changes:\n",
    "\n",
    "Change:\n",
    "```python\n",
    "# This will determine how many crops are processed in parallel\n",
    "input_time_length = 800\n",
    "in_chans = 3\n",
    "n_classes = 4\n",
    "# final_conv_length determines the size of the receptive field of the ConvNet\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,\n",
    "                        final_conv_length=30).create_network()\n",
    "```\n",
    "\n",
    "to:\n",
    "```python\n",
    "# Use the full trial length as input, so each input is one whole trial\n",
    "input_time_length = train_set.X.shape[2]\n",
    "in_chans = 3\n",
    "n_classes = 4\n",
    "# final_conv_length='auto' makes the final layer span the whole input, giving one prediction per trial\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,\n",
    "                        final_conv_length='auto').create_network()\n",
    "```\n",
    "\n",
    "Remove:\n",
    "\n",
    "```python\n",
    "to_dense_prediction_model(model)\n",
    "```\n",
    "\n",
    "Remove:\n",
    "\n",
    "```python\n",
    "from braindecode.torch_ext.util import np_to_var\n",
    "# determine output size\n",
    "test_input = np_to_var(np.ones((2, 3, input_time_length, 1), dtype=np.float32))\n",
    "if cuda:\n",
    "    test_input = test_input.cuda()\n",
    "out = model(test_input)\n",
    "n_preds_per_input = out.cpu().data.numpy().shape[2]\n",
    "print(\"{:d} predictions per input/trial\".format(n_preds_per_input))\n",
    "```\n",
    "\n",
    "Change:\n",
    "\n",
    "```python\n",
    "from braindecode.datautil.iterators import CropsFromTrialsIterator\n",
    "iterator = CropsFromTrialsIterator(batch_size=32, input_time_length=input_time_length,\n",
    "                                   n_preds_per_input=n_preds_per_input)\n",
    "```\n",
    "\n",
    "to:\n",
    "\n",
    "```python\n",
    "from braindecode.datautil.iterators import BalancedBatchSizeIterator\n",
    "iterator = BalancedBatchSizeIterator(batch_size=32)\n",
    "```\n",
    "\n",
    "Change:\n",
    "\n",
    "```python\n",
    "loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2)[:,:,0], targets)\n",
    "```\n",
    "\n",
    "to:\n",
    "\n",
    "```python\n",
    "loss_function = F.nll_loss\n",
    "```\n",
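    "\n",
    "To see why the loss changes, the following minimal NumPy sketch (not braindecode code) assumes the cropped network outputs log-probabilities of shape (batch, n_classes, n_preds_per_input, 1), as the indexing `th.mean(preds, dim=2)[:,:,0]` implies. Averaging over the prediction axis leaves one (batch, n_classes) prediction per trial, which is the shape the trialwise network already outputs, so plain `F.nll_loss` suffices there. The `nll_loss` and `log_softmax` functions below are hypothetical NumPy stand-ins, and the random outputs are illustrative only.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def nll_loss(log_probs, targets):\n",
    "    # NumPy stand-in for F.nll_loss: mean negative log-probability of the target classes\n",
    "    return -np.mean(log_probs[np.arange(len(targets)), targets])\n",
    "\n",
    "\n",
    "def log_softmax(x, axis):\n",
    "    # Numerically stable log-softmax along the given axis\n",
    "    x = x - x.max(axis=axis, keepdims=True)\n",
    "    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))\n",
    "\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "batch_size, n_classes, n_preds_per_input = 2, 4, 5\n",
    "targets = np.array([0, 3])\n",
    "\n",
    "# Cropped: dense output with one prediction per crop position (assumed shape)\n",
    "cropped_out = log_softmax(rng.randn(batch_size, n_classes, n_preds_per_input, 1), axis=1)\n",
    "# Equivalent of th.mean(preds, dim=2)[:,:,0]: average over predictions, drop the last axis\n",
    "averaged = cropped_out.mean(axis=2)[:, :, 0]\n",
    "cropped_loss = nll_loss(averaged, targets)\n",
    "\n",
    "# Trialwise: the network already outputs one prediction per trial\n",
    "trialwise_out = log_softmax(rng.randn(batch_size, n_classes), axis=1)\n",
    "trialwise_loss = nll_loss(trialwise_out, targets)\n",
    "\n",
    "print(averaged.shape)  # (2, 4)\n",
    "```\n",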
    "\n",
    "Change:\n",
    "\n",
    "```python\n",
    "monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),\n",
    "            CroppedTrialMisclassMonitor(input_time_length), RuntimeMonitor(),]\n",
    "```\n",
    "\n",
    "to:\n",
    "\n",
    "```python\n",
    "monitors = [LossMonitor(), MisclassMonitor(col_suffix='misclass'),\n",
    "            RuntimeMonitor(),]\n",
    "```\n",
    "\n",
    "The resulting code can be found in [BBCI Data Epoched](BBCI_Data_Epoched.html)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
