Trialwise Manual Training Loop

Here, we show how to do trialwise decoding with your own training loop. For simpler code that uses a predefined training loop, see the Trialwise Decoding Tutorial.

In this example, we will use a convolutional neural network on the Physiobank EEG Motor Movement/Imagery Dataset to decode two classes:

  1. Executed and imagined opening and closing of both hands
  2. Executed and imagined opening and closing of both feet
We use only one subject (with 90 trials) in this tutorial for demonstration purposes. A more interesting decoding task with many more trials would be to do cross-subject decoding on the same dataset.

Enable logging

In [2]:
import importlib
import logging
import sys

importlib.reload(logging)  # see https://stackoverflow.com/a/21475297/1469195
log = logging.getLogger()
log.setLevel('INFO')
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.INFO, stream=sys.stdout)

Load data

You can load and preprocess your EEG dataset in any way; Braindecode only expects a 3D array (trials, channels, timesteps) of input signals X and a vector of labels y (see below). In this tutorial, we use the MNE library to load an EEG motor imagery/motor execution dataset. For an MNE tutorial that decodes this data with Common Spatial Patterns (CSP), see the MNE examples. For another library useful for loading EEG data, take a look at Neo IO.
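If your own loader returns one (channels, timesteps) array per trial, stacking them gives the 3D layout described above. The sketch below uses random data as a hypothetical stand-in for real recordings; it assumes all trials have the same number of channels and samples, which `np.stack` requires.

```python
import numpy as np

# Hypothetical per-trial recordings from an arbitrary loader:
# each trial is a (n_channels, n_timesteps) array in volts.
rng = np.random.RandomState(0)
n_trials, n_chans, n_times = 6, 64, 497
trials = [rng.randn(n_chans, n_times) * 1e-5 for _ in range(n_trials)]
labels = [0, 1, 0, 1, 1, 0]

# Stack into the 3D layout (trials, channels, timesteps) ...
X = np.stack(trials, axis=0).astype(np.float32)
# ... and a 1D vector with one integer label per trial.
y = np.asarray(labels, dtype=np.int64)

assert X.ndim == 3 and y.shape == (X.shape[0],)
print(X.shape, X.dtype, y.shape, y.dtype)  # (6, 64, 497) float32 (6,) int64
```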

In [4]:
import mne
from mne.io import concatenate_raws

# Runs 5, 9, 13 (executed) and 6, 10, 14 (imagined) contain the hands vs. feet trials
subject_id = 22 # carefully cherry-picked to give nice results on such limited data :)
event_codes = [5,6,9,10,13,14]

# This will download the files if you don't have them yet,
# and then return the paths to the files.
physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

# Load each of the files
parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')
         for path in physionet_paths]

# Concatenate them
raw = concatenate_raws(parts)

# Find the events in this dataset
events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')

# Use only EEG channels
eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')

# Extract trials, only using EEG channels
epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,
                baseline=None, preload=True)
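As a quick check of the epoch length, the arithmetic below counts the samples in each trial. It assumes the PhysioNet recordings are sampled at 160 Hz and that the epoch includes both the `tmin` and `tmax` samples; `n_epoch_samples` is a hypothetical helper, not an MNE function.

```python
def n_epoch_samples(tmin, tmax, sfreq):
    """Number of samples in an epoch spanning [tmin, tmax] seconds, both ends included."""
    first = int(round(tmin * sfreq))
    last = int(round(tmax * sfreq))
    return last - first + 1

SFREQ = 160.0  # assumed sampling rate of the PhysioNet EEGBCI recordings
n_times = n_epoch_samples(tmin=1.0, tmax=4.1, sfreq=SFREQ)
print(n_times)  # 497
```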

Convert data to Braindecode format

Braindecode has a minimalistic SignalAndTarget class, with attributes X for the signal and y for the labels. X should have these dimensions: trials x channels x timesteps. y should have one label per trial.
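To make the expected shapes concrete, here is a hypothetical stand-in container with the same `X` and `y` attributes. It is not Braindecode's implementation; the shape checks are illustrative additions of this sketch.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SignalAndTargetSketch:
    """Hypothetical stand-in mirroring the X / y attributes of SignalAndTarget."""
    X: np.ndarray  # (trials, channels, timesteps)
    y: np.ndarray  # (trials,)

    def __post_init__(self):
        if self.X.ndim != 3:
            raise ValueError("X must be 3D (trials, channels, timesteps), got shape %s"
                             % (self.X.shape,))
        if len(self.y) != len(self.X):
            raise ValueError("need one label per trial: %d trials, %d labels"
                             % (len(self.X), len(self.y)))


ds = SignalAndTargetSketch(np.zeros((4, 64, 497), dtype=np.float32),
                           np.array([0, 1, 0, 1], dtype=np.int64))
print(ds.X.shape, ds.y.shape)  # (4, 64, 497) (4,)
```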

In [5]:
import numpy as np
# Convert data from volts to microvolts.
# PyTorch expects float32 for inputs and int64 for labels.
X = (epoched.get_data() * 1e6).astype(np.float32)
y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1
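The two conversions above can be checked on tiny synthetic arrays. The arrays below are hypothetical stand-ins for `epoched.get_data()` (in volts) and `epoched.events[:, 2]`; the factor 1e6 is exact for volts to microvolts.

```python
import numpy as np

# Hypothetical stand-ins for epoched.get_data() (volts) and epoched.events[:, 2].
data_volts = np.array([[[1e-6, -2.5e-6, 4e-6]]])  # shape (1 trial, 1 channel, 3 samples)
event_ids = np.array([2, 3, 3, 2])                 # 2 = hands, 3 = feet

# 1 V = 1e6 microvolts, so multiplying by 1e6 converts volts to microvolts.
X_uv = (data_volts * 1e6).astype(np.float32)
# Shift event IDs 2, 3 to class indices 0, 1 (nll_loss expects labels 0..n_classes-1).
y = (event_ids - 2).astype(np.int64)

print(X_uv)
print(y)  # [0 1 1 0]
```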

We use the first 40 trials for training and the next 30 trials for validation. The validation accuracies can be used to tune hyperparameters such as the learning rate. The final 20 trials are held out as a final evaluation set that is not used in any hyperparameter optimization. As mentioned before, this dataset is far too small to give meaningful results and is only used here for quick demonstration purposes.
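The split is purely by trial order, matching the slices `X[:40]`, `X[40:70]` and `X[70:]`. The hypothetical helper below reproduces that split on indices and checks that the three sets are disjoint and cover all 90 trials.

```python
import numpy as np


def split_indices(n_trials, n_train, n_valid):
    """Contiguous train/valid/test split by trial order (no shuffling)."""
    idx = np.arange(n_trials)
    train = idx[:n_train]
    valid = idx[n_train:n_train + n_valid]
    test = idx[n_train + n_valid:]
    return train, valid, test


train_idx, valid_idx, test_idx = split_indices(90, 40, 30)
print(len(train_idx), len(valid_idx), len(test_idx))  # 40 30 20
```

Because no shuffling happens, the held-out trials are the last ones in the concatenated recording rather than a random subset.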

In [6]:
from braindecode.datautil.signal_target import SignalAndTarget

train_set = SignalAndTarget(X[:40], y=y[:40])
valid_set = SignalAndTarget(X[40:70], y=y[40:70])

Create the model

Braindecode comes with some predefined convolutional neural network architectures for raw time-domain EEG. Here, we use the shallow ConvNet model from "Deep learning with convolutional neural networks for EEG decoding and visualization" (Schirrmeister et al., 2017).

In [20]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds

# Set to True to use a GPU.
# You can also use torch.cuda.is_available() to determine whether CUDA is available on your machine.
cuda = False
set_random_seeds(seed=20170629, cuda=cuda)
n_classes = 2
in_chans = train_set.X.shape[1]
# final_conv_length = auto ensures we only get a single output in the time dimension
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
                        input_time_length=train_set.X.shape[2],
                        final_conv_length='auto').create_network()
if cuda:
    model.cuda()
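To see why `final_conv_length='auto'` yields a single output in time, the arithmetic below traces the time axis through the network. It assumes ShallowFBCSPNet defaults of `filter_time_length=25`, `pool_time_length=75` and `pool_time_stride=15` (check your Braindecode version), 497 input samples per trial, unpadded convolutions, and a spatial convolution that does not change the time axis; `conv_out_len` is a hypothetical helper.

```python
def conv_out_len(n_in, kernel, stride=1):
    """Output length of an unpadded ('valid') convolution or pooling along time."""
    return (n_in - kernel) // stride + 1


# Assumed ShallowFBCSPNet defaults (may differ between Braindecode versions).
FILTER_TIME_LENGTH = 25
POOL_TIME_LENGTH = 75
POOL_TIME_STRIDE = 15

input_time_length = 497  # assumed samples per trial (1 s to 4.1 s at 160 Hz)
after_temporal_conv = conv_out_len(input_time_length, FILTER_TIME_LENGTH)
after_pool = conv_out_len(after_temporal_conv, POOL_TIME_LENGTH, POOL_TIME_STRIDE)
# A final conv whose kernel spans the whole remaining time axis gives 1 output in time.
final_conv_length = after_pool
n_out_time = conv_out_len(after_pool, final_conv_length)
print(after_temporal_conv, after_pool, n_out_time)  # 473 27 1
```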

We use AdamW to optimize the parameters of our network, together with cosine annealing of the learning rate. We supply some default parameters that we have found to work well for motor decoding; however, we strongly encourage you to perform your own hyperparameter optimization using cross-validation on your training data.

In [21]:
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.schedulers import ScheduledOptimizer, CosineAnnealing
from braindecode.datautil.iterators import get_balanced_batches
from numpy.random import RandomState
rng = RandomState((2018,8,7))
#optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model
optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)
# Need to determine number of batch passes per epoch for cosine annealing
n_epochs = 30
n_updates_per_epoch = len(list(get_balanced_batches(len(train_set.X), rng, shuffle=True,
                                            batch_size=30)))
scheduler = CosineAnnealing(n_epochs * n_updates_per_epoch)
# schedule_weight_decay must be True for AdamW
optimizer = ScheduledOptimizer(scheduler, optimizer, schedule_weight_decay=True)
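The schedule multiplies the initial learning rate by a factor that decays from 1 to 0 along half a cosine period over all updates. In the AdamW paper (Loshchilov & Hutter), the same schedule multiplier is applied to the decoupled weight decay, which is presumably why `schedule_weight_decay=True` is required. The sketch below uses the standard cosine formula; it is not Braindecode's code, and the per-epoch update count is an assumption (the real value comes from `get_balanced_batches`).

```python
import math


def cosine_factor(i_update, n_updates):
    """Schedule multiplier that decays from 1 to 0 over half a cosine period."""
    return 0.5 * (1.0 + math.cos(math.pi * i_update / n_updates))


initial_lr = 0.0625 * 0.01   # learning rate used in the tutorial
initial_weight_decay = 0.0   # weight decay used in the tutorial (shallow model)
n_epochs = 30
n_updates_per_epoch = 1      # assumption: 40 trials at batch_size=30 -> one batch
n_updates = n_epochs * n_updates_per_epoch

for i_update in (0, n_updates // 2, n_updates):
    factor = cosine_factor(i_update, n_updates)
    # With schedule_weight_decay=True, weight decay is assumed to follow the same factor.
    lr = initial_lr * factor
    weight_decay = initial_weight_decay * factor
    print(i_update, round(factor, 3), lr, weight_decay)
```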

Training loop

This is a conventional mini-batch stochastic gradient descent training loop:

  1. Get randomly shuffled batches of trials
  2. Compute outputs, loss and gradients on each batch of trials
  3. Update your model
  4. After iterating through all batches of your dataset, report some statistics like mean accuracy and mean loss.
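The four steps above can be sketched framework-free with NumPy. This toy swaps in logistic regression with binary cross-entropy on synthetic 2D data (the tutorial itself uses a ConvNet with `nll_loss` on log-softmax outputs); `iterate_batches` is a hypothetical plain shuffler, not Braindecode's class-balanced `get_balanced_batches`.

```python
import numpy as np


def iterate_batches(n_items, batch_size, rng):
    """Yield shuffled index arrays covering every item once per epoch."""
    order = rng.permutation(n_items)
    for start in range(0, n_items, batch_size):
        yield order[start:start + batch_size]


rng = np.random.RandomState(0)
# Toy, well-separated 2-class data standing in for (trials, features).
X = np.concatenate([rng.randn(50, 2) + 2.0, rng.randn(50, 2) - 2.0])
y = np.concatenate([np.ones(50), np.zeros(50)])

w, b, lr = np.zeros(2), 0.0, 0.1
for epoch in range(20):
    for idx in iterate_batches(len(X), batch_size=16, rng=rng):  # 1. shuffled batches
        p = 1.0 / (1.0 + np.exp(-(X[idx] @ w + b)))              # 2. outputs
        grad_logit = (p - y[idx]) / len(idx)                      #    dLoss/dlogit, mean BCE
        grad_w, grad_b = X[idx].T @ grad_logit, grad_logit.sum()  #    gradients
        w -= lr * grad_w                                          # 3. update
        b -= lr * grad_b
    p_all = 1.0 / (1.0 + np.exp(-(X @ w + b)))                    # 4. epoch statistics
    loss = -np.mean(y * np.log(p_all + 1e-12) + (1 - y) * np.log(1 - p_all + 1e-12))
    acc = np.mean((p_all > 0.5) == y)
print("final loss {:.3f}, accuracy {:.1%}".format(loss, acc))
```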
In [22]:
from braindecode.torch_ext.util import np_to_var, var_to_np
import torch.nn.functional as F
for i_epoch in range(n_epochs):
    i_trials_in_batch = get_balanced_batches(len(train_set.X), rng, shuffle=True,
                                            batch_size=30)
    # Set model to training mode
    model.train()
    for i_trials in i_trials_in_batch:
        # Have to add empty fourth dimension to X
        batch_X = train_set.X[i_trials][:,:,:,None]
        batch_y = train_set.y[i_trials]
        net_in = np_to_var(batch_X)
        if cuda:
            net_in = net_in.cuda()
        net_target = np_to_var(batch_y)
        if cuda:
            net_target = net_target.cuda()
        # Remove gradients of last backward pass from all parameters
        optimizer.zero_grad()
        # Compute outputs of the network
        outputs = model(net_in)
        # Compute the loss
        loss = F.nll_loss(outputs, net_target)
        # Do the backpropagation
        loss.backward()
        # Update parameters with the optimizer
        optimizer.step()

    # Print some statistics each epoch
    model.eval()
    print("Epoch {:d}".format(i_epoch))
    for setname, dataset in (('Train', train_set), ('Valid', valid_set)):
        # Here, we use the entire dataset at once, which is still possible
        # for such small datasets. Otherwise we would have to use batches.
        net_in = np_to_var(dataset.X[:,:,:,None])
        if cuda:
            net_in = net_in.cuda()
        net_target = np_to_var(dataset.y)
        if cuda:
            net_target = net_target.cuda()
        outputs = model(net_in)
        loss = F.nll_loss(outputs, net_target)
        print("{:6s} Loss: {:.5f}".format(
            setname, float(var_to_np(loss))))
        predicted_labels = np.argmax(var_to_np(outputs), axis=1)
        accuracy = np.mean(dataset.y  == predicted_labels)
        print("{:6s} Accuracy: {:.1f}%".format(
            setname, accuracy * 100))
Epoch 0
Train  Loss: 1.38360
Train  Accuracy: 47.5%
Valid  Loss: 1.41585
Valid  Accuracy: 50.0%
Epoch 1
Train  Loss: 0.88988
Train  Accuracy: 60.0%
Valid  Loss: 1.03939
Valid  Accuracy: 56.7%
Epoch 2
Train  Loss: 0.67633
Train  Accuracy: 67.5%
Valid  Loss: 0.94319
Valid  Accuracy: 60.0%
Epoch 3
Train  Loss: 0.41825
Train  Accuracy: 80.0%
Valid  Loss: 0.75822
Valid  Accuracy: 63.3%
Epoch 4
Train  Loss: 0.22624
Train  Accuracy: 90.0%
Valid  Loss: 0.67704
Valid  Accuracy: 66.7%
Epoch 5
Train  Loss: 0.13072
Train  Accuracy: 97.5%
Valid  Loss: 0.62466
Valid  Accuracy: 73.3%
Epoch 6
Train  Loss: 0.10054
Train  Accuracy: 97.5%
Valid  Loss: 0.62027
Valid  Accuracy: 73.3%
Epoch 7
Train  Loss: 0.08371
Train  Accuracy: 97.5%
Valid  Loss: 0.62787
Valid  Accuracy: 73.3%
Epoch 8
Train  Loss: 0.07234
Train  Accuracy: 97.5%
Valid  Loss: 0.62938
Valid  Accuracy: 70.0%
Epoch 9
Train  Loss: 0.06713
Train  Accuracy: 97.5%
Valid  Loss: 0.63169
Valid  Accuracy: 70.0%
Epoch 10
Train  Loss: 0.06175
Train  Accuracy: 97.5%
Valid  Loss: 0.63279
Valid  Accuracy: 70.0%
Epoch 11
Train  Loss: 0.05743
Train  Accuracy: 97.5%
Valid  Loss: 0.62944
Valid  Accuracy: 70.0%
Epoch 12
Train  Loss: 0.05303
Train  Accuracy: 97.5%
Valid  Loss: 0.62220
Valid  Accuracy: 70.0%
Epoch 13
Train  Loss: 0.04724
Train  Accuracy: 97.5%
Valid  Loss: 0.61127
Valid  Accuracy: 70.0%
Epoch 14
Train  Loss: 0.04223
Train  Accuracy: 100.0%
Valid  Loss: 0.59926
Valid  Accuracy: 70.0%
Epoch 15
Train  Loss: 0.03736
Train  Accuracy: 100.0%
Valid  Loss: 0.58510
Valid  Accuracy: 70.0%
Epoch 16
Train  Loss: 0.03282
Train  Accuracy: 100.0%
Valid  Loss: 0.56685
Valid  Accuracy: 70.0%
Epoch 17
Train  Loss: 0.02917
Train  Accuracy: 100.0%
Valid  Loss: 0.54962
Valid  Accuracy: 73.3%
Epoch 18
Train  Loss: 0.02635
Train  Accuracy: 100.0%
Valid  Loss: 0.53484
Valid  Accuracy: 73.3%
Epoch 19
Train  Loss: 0.02404
Train  Accuracy: 100.0%
Valid  Loss: 0.52472
Valid  Accuracy: 73.3%
Epoch 20
Train  Loss: 0.02188
Train  Accuracy: 100.0%
Valid  Loss: 0.51820
Valid  Accuracy: 73.3%
Epoch 21
Train  Loss: 0.02015
Train  Accuracy: 100.0%
Valid  Loss: 0.51563
Valid  Accuracy: 73.3%
Epoch 22
Train  Loss: 0.01874
Train  Accuracy: 100.0%
Valid  Loss: 0.51309
Valid  Accuracy: 73.3%
Epoch 23
Train  Loss: 0.01756
Train  Accuracy: 100.0%
Valid  Loss: 0.51196
Valid  Accuracy: 73.3%
Epoch 24
Train  Loss: 0.01658
Train  Accuracy: 100.0%
Valid  Loss: 0.51188
Valid  Accuracy: 73.3%
Epoch 25
Train  Loss: 0.01572
Train  Accuracy: 100.0%
Valid  Loss: 0.51296
Valid  Accuracy: 73.3%
Epoch 26
Train  Loss: 0.01505
Train  Accuracy: 100.0%
Valid  Loss: 0.51424
Valid  Accuracy: 73.3%
Epoch 27
Train  Loss: 0.01454
Train  Accuracy: 100.0%
Valid  Loss: 0.51548
Valid  Accuracy: 73.3%
Epoch 28
Train  Loss: 0.01416
Train  Accuracy: 100.0%
Valid  Loss: 0.51677
Valid  Accuracy: 73.3%
Epoch 29
Train  Loss: 0.01388
Train  Accuracy: 100.0%
Valid  Loss: 0.51794
Valid  Accuracy: 73.3%

Eventually, we arrive at 73.3% validation accuracy, so 22 of 30 trials are correctly predicted. The Cropped Decoding Tutorial shows how to do the same decoding using cropped decoding.
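The loss and accuracy printed in the loop are computed from the model's log-probability outputs. A NumPy sketch of the same two quantities, on hypothetical outputs for three trials, assuming `F.nll_loss` with its default mean reduction:

```python
import numpy as np


def nll_loss(log_probs, targets):
    """Mean negative log-likelihood of the target class (like F.nll_loss, mean reduction)."""
    return -np.mean(log_probs[np.arange(len(targets)), targets])


# Hypothetical log-softmax outputs for 3 trials and 2 classes.
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
log_probs = np.log(probs)
targets = np.array([0, 1, 1])

loss = nll_loss(log_probs, targets)
predicted = np.argmax(log_probs, axis=1)  # the log does not change the argmax
accuracy = np.mean(predicted == targets)
print(predicted, accuracy)                # [0 1 0] and 2/3
print("{:.1f}%".format(100 * 22 / 30))    # 73.3%
```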

Evaluation

Once all hyperparameters and architectural choices are fixed, we can compute the accuracies to report in a publication by evaluating on the held-out test set:

In [27]:
test_set = SignalAndTarget(X[70:], y=y[70:])

model.eval()
# Here, we use the entire dataset at once, which is still possible
# for such small datasets. Otherwise we would have to use batches.
net_in = np_to_var(test_set.X[:,:,:,None])
if cuda:
    net_in = net_in.cuda()
net_target = np_to_var(test_set.y)
if cuda:
    net_target = net_target.cuda()
outputs = model(net_in)
loss = F.nll_loss(outputs, net_target)
print("Test Loss: {:.5f}".format(float(var_to_np(loss))))
predicted_labels = np.argmax(var_to_np(outputs), axis=1)
accuracy = np.mean(test_set.y  == predicted_labels)
print("Test Accuracy: {:.1f}%".format(accuracy * 100))
Test Loss: 0.22339
Test Accuracy: 95.0%
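For datasets too large to pass through the network at once, evaluation can be done in batches. The sketch below sums the per-trial loss and the number of correct predictions over batches and divides by the total trial count at the end; `evaluate_in_batches` and `toy_model` are hypothetical, with `toy_model` standing in for a network that returns log-probabilities.

```python
import numpy as np


def evaluate_in_batches(predict_log_probs, X, y, batch_size):
    """Accumulate summed NLL and correct predictions over batches, then average."""
    total_nll, n_correct = 0.0, 0
    for start in range(0, len(X), batch_size):
        log_probs = predict_log_probs(X[start:start + batch_size])
        targets = y[start:start + batch_size]
        total_nll += float(-log_probs[np.arange(len(targets)), targets].sum())
        n_correct += int((log_probs.argmax(axis=1) == targets).sum())
    return total_nll / len(X), n_correct / len(X)


def toy_model(batch_X):
    """Hypothetical model: 2-class log-softmax of the per-trial mean and its negation."""
    m = batch_X.mean(axis=(1, 2))
    logits = np.stack([m, -m], axis=1)
    return logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))


rng = np.random.RandomState(0)
X_test = rng.randn(20, 64, 497).astype(np.float32)
y_test = rng.randint(0, 2, size=20)
print(evaluate_in_batches(toy_model, X_test, y_test, batch_size=8))
```

Summing before dividing by the total count keeps the result identical to the full-batch version; averaging per-batch means instead would overweight the smaller final batch.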
If you want to try cross-subject decoding, changing the loading code to the following will perform cross-subject decoding on imagined left vs. right hand opening and closing, with 50 training and 5 validation subjects (warning: this may be very slow on a CPU):
In [ ]:
import mne
import numpy as np
from mne.io import concatenate_raws
from braindecode.datautil.signal_target import SignalAndTarget

# First 50 subjects as train
physionet_paths = [ mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(1,51)]
physionet_paths = np.concatenate(physionet_paths)
parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto')
         for path in physionet_paths]

raw = concatenate_raws(parts)

picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')

events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')

# Extract trials from 1 s to 4.1 s after each cue, only using EEG channels.
# In runs 4, 8 and 12, event 2 (T1) is imagined left fist and event 3 (T2) imagined right fist.
epoched = mne.Epochs(raw, events, dict(left_hand=2, right_hand=3), tmin=1, tmax=4.1, proj=False, picks=picks,
                     baseline=None, preload=True)

# 51-55 as validation subjects
physionet_paths_valid = [mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(51,56)]
physionet_paths_valid = np.concatenate(physionet_paths_valid)
parts_valid = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto')
         for path in physionet_paths_valid]
raw_valid = concatenate_raws(parts_valid)

picks_valid = mne.pick_types(raw_valid.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')

events_valid = mne.find_events(raw_valid, shortest_event=0, stim_channel='STI 014')

# Extract trials from 1 s to 4.1 s after each cue, only using EEG channels.
# In runs 4, 8 and 12, event 2 (T1) is imagined left fist and event 3 (T2) imagined right fist.
epoched_valid = mne.Epochs(raw_valid, events_valid, dict(left_hand=2, right_hand=3), tmin=1, tmax=4.1, proj=False, picks=picks_valid,
                           baseline=None, preload=True)

train_X = (epoched.get_data() * 1e6).astype(np.float32)
train_y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1
valid_X = (epoched_valid.get_data() * 1e6).astype(np.float32)
valid_y = (epoched_valid.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1
train_set = SignalAndTarget(train_X, y=train_y)
valid_set = SignalAndTarget(valid_X, y=valid_y)
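The run numbers passed to `load_data` select the task, following the PhysioNet EEG Motor Movement/Imagery dataset description. The `RUNS` dictionary below is a hypothetical helper summarizing that mapping; it also checks that the cross-subject training and validation subjects do not overlap.

```python
# PhysioNet EEG Motor Movement/Imagery runs per task (from the dataset description).
RUNS = {
    "executed_left_vs_right_fist": [3, 7, 11],
    "imagined_left_vs_right_fist": [4, 8, 12],
    "executed_fists_vs_feet": [5, 9, 13],
    "imagined_fists_vs_feet": [6, 10, 14],
}

# Runs used by the single-subject part of this tutorial (executed + imagined hands vs. feet).
hands_feet_runs = sorted(RUNS["executed_fists_vs_feet"] + RUNS["imagined_fists_vs_feet"])
print(hands_feet_runs)  # [5, 6, 9, 10, 13, 14]

# Cross-subject split: no subject appears in both sets.
train_subjects = list(range(1, 51))
valid_subjects = list(range(51, 56))
assert not set(train_subjects) & set(valid_subjects)
print(len(train_subjects), len(valid_subjects))  # 50 5
```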

Dataset references

This dataset was created and contributed to PhysioNet by the developers of the BCI2000 instrumentation system, which they used in making these recordings. The system is described in:

Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.

PhysioBank is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community, and is further described in:

Goldberger, A.L., Amaral, L.A.N., Glass, L., Hausdorff, J.M., Ivanov, P.Ch., Mark, R.G., Mietus, J.E., Moody, G.B., Peng, C.-K., Stanley, H.E. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220.