Cropped Manual Training Loop

Here, we show the cropped decoding when you want to write your own training loop. For more simple code with a predefined training loop and an explanation of cropped decoding in general, see the Cropped Decoding Tutorial.

Most of the code for cropped decoding is identical to the Trialwise Manual Training Loop Tutorial, differences are explained in the text.

Load data

In [2]:
import mne
from mne.io import concatenate_raws

# 5,6,7,10,13,14 are codes for executed and imagined hands/feet
subject_id = 22 # carefully cherry-picked to give nice results on such limited data :)
event_codes = [5,6,9,10,13,14]

# This will download the files if you don't have them yet,
# and then return the paths to the files.
physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

# Load each of the files
parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')
         for path in physionet_paths]

# Concatenate them
raw = concatenate_raws(parts)

# Find the events in this dataset
events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')

# Use only EEG channels
eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')

# Extract trials, only using EEG channels
epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,
                baseline=None, preload=True)

Convert data to Braindecode format

In [3]:
import numpy as np
from braindecode.datautil.signal_target import SignalAndTarget
# Convert data from volt to millivolt
# Pytorch expects float32 for input and int64 for labels.
X = (epoched.get_data() * 1e6).astype(np.float32)
y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1

train_set = SignalAndTarget(X[:40], y=y[:40])
valid_set = SignalAndTarget(X[40:70], y=y[40:70])

Create the model

For cropped decoding, we now transform the model into a model that outputs a dense time series of predictions. For this, we manually set the length of the final convolution layer to some length that makes the receptive field of the ConvNet smaller than the number of samples in a trial. Also, we use to_dense_prediction_model, which removes the strides in the ConvNet and instead uses dilated convolutions to get a dense output (see Multi-Scale Context Aggregation by Dilated Convolutions and our paper Deep learning with convolutional neural networks for EEG decoding and visualization Section 2.5.4 for some background on this).

In [4]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model

# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = False
set_random_seeds(seed=20170629, cuda=cuda)

# This will determine how many crops are processed in parallel
input_time_length = 450
n_classes = 2
in_chans = train_set.X.shape[1]
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,
                        final_conv_length=12).create_network()
to_dense_prediction_model(model)

if cuda:
    model.cuda()


Create cropped iterator

For extracting crops from the trials, Braindecode provides the CropsFromTrialsIterator? class. This class needs to know the input time length of the inputs you put into the network and the number of predictions that the ConvNet will output per input. You can determine the number of predictions by passing dummy data through the ConvNet:

In [5]:
from braindecode.torch_ext.util import np_to_var
# determine output size
test_input = np_to_var(np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))
187 predictions per input/trial
In [6]:
from braindecode.datautil.iterators import CropsFromTrialsIterator
iterator = CropsFromTrialsIterator(batch_size=32,input_time_length=input_time_length,
                                  n_preds_per_input=n_preds_per_input)

The iterator has the method get_batches, which can be used to get randomly shuffled training batches with shuffle=True or ordered batches (i.e. first from trial 1, then from trial 2, etc.) with shuffle=False. Additionally, Braindecode provides the compute_preds_per_trial_for_set method, which accepts predictions from the ordered batches and returns predictions per trial. It removes any overlapping predictions, which occur if the number of predictions per input is not a divisor of the number of samples in a trial.

These methods can also work with trials of different lengths! For different-length trials, set X to be a list of 2d-arrays instead of a 3d-array.

We now can set the optimizer, since we can compute the number of batches per epoch using the iterator.

In [7]:
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.schedulers import ScheduledOptimizer, CosineAnnealing
from braindecode.datautil.iterators import get_balanced_batches
from numpy.random import RandomState
rng = RandomState((2018,8,7))
#optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model
optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)
# Need to determine number of batch passes per epoch for cosine annealing
n_epochs = 30
n_updates_per_epoch = len([None for b in iterator.get_batches(train_set, True)])
scheduler = CosineAnnealing(n_epochs * n_updates_per_epoch)
# schedule_weight_decay must be True for AdamW
optimizer = ScheduledOptimizer(scheduler, optimizer, schedule_weight_decay=True)

Training loop

The code below uses both the cropped iterator and the compute_preds_per_trial_from_crops function to train and evaluate the network.

In [8]:
from braindecode.torch_ext.util import np_to_var, var_to_np
import torch.nn.functional as F
from numpy.random import RandomState
import torch as th
from braindecode.experiments.monitors import compute_preds_per_trial_from_crops
rng = RandomState((2017,6,30))
for i_epoch in range(20):
    # Set model to training mode
    model.train()
    for batch_X, batch_y in iterator.get_batches(train_set, shuffle=True):
        net_in = np_to_var(batch_X)
        if cuda:
            net_in = net_in.cuda()
        net_target = np_to_var(batch_y)
        if cuda:
            net_target = net_target.cuda()
        # Remove gradients of last backward pass from all parameters
        optimizer.zero_grad()
        outputs = model(net_in)
        # Mean predictions across trial
        # Note that this will give identical gradients to computing
        # a per-prediction loss (at least for the combination of log softmax activation
        # and negative log likelihood loss which we are using here)
        outputs = th.mean(outputs, dim=2, keepdim=False)
        loss = F.nll_loss(outputs, net_target)
        loss.backward()
        optimizer.step()

    # Print some statistics each epoch
    model.eval()
    print("Epoch {:d}".format(i_epoch))
    for setname, dataset in (('Train', train_set),('Valid', valid_set)):
        # Collect all predictions and losses
        all_preds = []
        all_losses = []
        batch_sizes = []
        for batch_X, batch_y in iterator.get_batches(dataset, shuffle=False):
            net_in = np_to_var(batch_X)
            if cuda:
                net_in = net_in.cuda()
            net_target = np_to_var(batch_y)
            if cuda:
                net_target = net_target.cuda()
            outputs = model(net_in)
            all_preds.append(var_to_np(outputs))
            outputs = th.mean(outputs, dim=2, keepdim=False)
            loss = F.nll_loss(outputs, net_target)
            loss = float(var_to_np(loss))
            all_losses.append(loss)
            batch_sizes.append(len(batch_X))
        # Compute mean per-input loss
        loss = np.mean(np.array(all_losses) * np.array(batch_sizes) /
                       np.mean(batch_sizes))
        print("{:6s} Loss: {:.5f}".format(setname, loss))
        # Assign the predictions to the trials
        preds_per_trial = compute_preds_per_trial_from_crops(all_preds,
                                                          input_time_length,
                                                          dataset.X)
        # preds per trial are now trials x classes x timesteps/predictions
        # Now mean across timesteps for each trial to get per-trial predictions
        meaned_preds_per_trial = np.array([np.mean(p, axis=1) for p in preds_per_trial])
        predicted_labels = np.argmax(meaned_preds_per_trial, axis=1)
        accuracy = np.mean(predicted_labels == dataset.y)
        print("{:6s} Accuracy: {:.1f}%".format(
            setname, accuracy * 100))
Epoch 0
Train  Loss: 3.79010
Train  Accuracy: 50.0%
Valid  Loss: 3.12765
Valid  Accuracy: 46.7%
Epoch 1
Train  Loss: 1.91778
Train  Accuracy: 50.0%
Valid  Loss: 1.52058
Valid  Accuracy: 46.7%
Epoch 2
Train  Loss: 1.09943
Train  Accuracy: 60.0%
Valid  Loss: 0.99602
Valid  Accuracy: 56.7%
Epoch 3
Train  Loss: 0.73015
Train  Accuracy: 67.5%
Valid  Loss: 0.82321
Valid  Accuracy: 56.7%
Epoch 4
Train  Loss: 0.55965
Train  Accuracy: 75.0%
Valid  Loss: 0.77549
Valid  Accuracy: 63.3%
Epoch 5
Train  Loss: 0.38075
Train  Accuracy: 82.5%
Valid  Loss: 0.65490
Valid  Accuracy: 73.3%
Epoch 6
Train  Loss: 0.28886
Train  Accuracy: 90.0%
Valid  Loss: 0.59481
Valid  Accuracy: 73.3%
Epoch 7
Train  Loss: 0.21871
Train  Accuracy: 97.5%
Valid  Loss: 0.52357
Valid  Accuracy: 83.3%
Epoch 8
Train  Loss: 0.16713
Train  Accuracy: 97.5%
Valid  Loss: 0.44890
Valid  Accuracy: 86.7%
Epoch 9
Train  Loss: 0.13149
Train  Accuracy: 97.5%
Valid  Loss: 0.40492
Valid  Accuracy: 83.3%
Epoch 10
Train  Loss: 0.09581
Train  Accuracy: 100.0%
Valid  Loss: 0.36021
Valid  Accuracy: 90.0%
Epoch 11
Train  Loss: 0.07818
Train  Accuracy: 100.0%
Valid  Loss: 0.34625
Valid  Accuracy: 90.0%
Epoch 12
Train  Loss: 0.07454
Train  Accuracy: 100.0%
Valid  Loss: 0.34489
Valid  Accuracy: 90.0%
Epoch 13
Train  Loss: 0.06694
Train  Accuracy: 100.0%
Valid  Loss: 0.33878
Valid  Accuracy: 90.0%
Epoch 14
Train  Loss: 0.05971
Train  Accuracy: 100.0%
Valid  Loss: 0.33128
Valid  Accuracy: 90.0%
Epoch 15
Train  Loss: 0.05269
Train  Accuracy: 100.0%
Valid  Loss: 0.32202
Valid  Accuracy: 90.0%
Epoch 16
Train  Loss: 0.04354
Train  Accuracy: 100.0%
Valid  Loss: 0.31063
Valid  Accuracy: 90.0%
Epoch 17
Train  Loss: 0.03759
Train  Accuracy: 100.0%
Valid  Loss: 0.30314
Valid  Accuracy: 90.0%
Epoch 18
Train  Loss: 0.03401
Train  Accuracy: 100.0%
Valid  Loss: 0.29997
Valid  Accuracy: 90.0%
Epoch 19
Train  Loss: 0.03145
Train  Accuracy: 100.0%
Valid  Loss: 0.29764
Valid  Accuracy: 90.0%

Eventually, we arrive at 90.0% accuracy, so 27 from 30 trials are correctly predicted, 5 more than for the trialwise decoding method.

Evaluation

Once we have all our hyperparameters and architectural choices done, we can evaluate the accuracies to report in our publication by evaluating on the test set:

In [9]:
test_set = SignalAndTarget(X[70:], y=y[70:])

model.eval()
# Collect all predictions and losses
all_preds = []
all_losses = []
batch_sizes = []
for batch_X, batch_y in iterator.get_batches(test_set, shuffle=False):
    net_in = np_to_var(batch_X)
    if cuda:
        net_in = net_in.cuda()
    net_target = np_to_var(batch_y)
    if cuda:
        net_target = net_target.cuda()
    outputs = model(net_in)
    all_preds.append(var_to_np(outputs))
    outputs = th.mean(outputs, dim=2, keepdim=False)
    loss = F.nll_loss(outputs, net_target)
    loss = float(var_to_np(loss))
    all_losses.append(loss)
    batch_sizes.append(len(batch_X))
# Compute mean per-input loss
loss = np.mean(np.array(all_losses) * np.array(batch_sizes) /
               np.mean(batch_sizes))
print("Test Loss: {:.5f}".format(loss))
# Assign the predictions to the trials
preds_per_trial = compute_preds_per_trial_from_crops(all_preds,
                                                  input_time_length,
                                                  test_set.X)
# preds per trial are now trials x classes x timesteps/predictions
# Now mean across timesteps for each trial to get per-trial predictions
meaned_preds_per_trial = np.array([np.mean(p, axis=1) for p in preds_per_trial])
predicted_labels = np.argmax(meaned_preds_per_trial, axis=1)
accuracy = np.mean(predicted_labels == test_set.y)
print("Test Accuracy: {:.1f}%".format(accuracy * 100))
Test Loss: 0.42105
Test Accuracy: 85.0%

Dataset references

This dataset was created and contributed to PhysioNet by the developers of the BCI2000 instrumentation system, which they used in making these recordings. The system is described in:

Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.

PhysioBank is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community and further described in:

Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220.