Read and Decode BBCI Data¶
This tutorial shows how to read and decode BBCI data.
Set up logging to see outputs¶
In [2]:
import logging
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.DEBUG, stream=sys.stdout)
log = logging.getLogger()
Load and preprocess data¶
First set the filename and the sensors you want to load. If you set
load_sensor_names=None
or just remove the parameter from the function call, all sensors will be loaded.
In [3]:
from braindecode.datasets.bbci import BBCIDataset
train_filename = '/home/schirrmr/data/BBCI-without-last-runs/BhNoMoSc1S001R01_ds10_1-12.BBCI.mat'
cnt = BBCIDataset(train_filename, load_sensor_names=['C3', 'CPz', 'C4']).load()
Creating RawArray with float64 data, n_channels=3, n_times=3451320
Range : 0 ... 3451319 = 0.000 ... 6902.638 secs
Ready.
Preprocessing on continuous data¶
First remove the stimulus channel, then apply any preprocessing you
like. Braindecode itself offers only a few preprocessing functions,
such as resample_cnt. But you can apply any function to the chan x time
matrix of the mne raw object (cnt in the code) by calling
mne_apply with two arguments:
- Your function (2d-array -> 2d-array), which transforms the channel x timesteps data array
- The Raw data object from mne itself
In [4]:
from braindecode.mne_ext.signalproc import resample_cnt, mne_apply
from braindecode.datautil.signalproc import exponential_running_standardize
# Remove stimulus channel
cnt = cnt.drop_channels(['STI 014'])
cnt = resample_cnt(cnt, 250)
# mne_apply will apply the function to the data (a 2d-numpy-array).
# We have to transpose the data back and forth, since
# exponential_running_standardize expects time x chans order
# while the mne object has chans x time order.
cnt = mne_apply(lambda a: exponential_running_standardize(
    a.T, init_block_size=1000, factor_new=0.001, eps=1e-4).T,
    cnt)
2017-11-03 17:56:32,890 WARNING : This is not causal, uses future data....
2017-11-03 17:56:32,891 INFO : Resampling from 500.000000 to 250.000000 Hz.
Creating RawArray with float64 data, n_channels=3, n_times=1725660
Range : 0 ... 1725659 = 0.000 ... 6902.636 secs
Ready.
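To make the transposition in the lambda above more concrete, here is a simplified NumPy sketch of an exponential running standardization on time x chans data. It is not Braindecode's implementation: the helper name running_standardize_sketch is hypothetical, the init_block_size handling is left out, and the recursive mean and variance updates are an assumption about what such a standardization typically looks like.

```python
import numpy as np


def running_standardize_sketch(data, factor_new=0.001, eps=1e-4):
    """Simplified exponential running standardization (hypothetical helper).

    data: 2d array in time x chans order.
    """
    data = np.asarray(data, dtype=np.float64)
    standardized = np.empty_like(data)
    mean = data[0].copy()            # running mean per channel
    var = np.zeros(data.shape[1])    # running variance per channel
    for t, x in enumerate(data):
        # new samples get weight factor_new, history gets the rest
        mean = factor_new * x + (1 - factor_new) * mean
        var = factor_new * (x - mean) ** 2 + (1 - factor_new) * var
        # eps guards against division by a (near) zero standard deviation
        standardized[t] = (x - mean) / np.maximum(eps, np.sqrt(var))
    return standardized


# Synthetic chans x time array, laid out like the data in the mne raw object
rng = np.random.RandomState(0)
chan_x_time = rng.randn(3, 2000) * 5.0 + 10.0
# Transpose to time x chans for the helper, then transpose back
result = running_standardize_sketch(chan_x_time.T, factor_new=0.01).T
print(result.shape)
```

The two transposes play the same role as the a.T and .T in the lambda passed to mne_apply: the helper sees time x chans data, while the caller keeps chans x time order.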
Transform to epoched dataset¶
Braindecode supplies the create_signal_target_from_raw_mne function,
which will transform the mne raw object into a SignalAndTarget
object for use in Braindecode. name_to_code should be an
OrderedDict that maps class names to either one or a list of marker
codes for that class.
In [5]:
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne
from collections import OrderedDict
# can also give lists of marker codes in case a class has multiple marker codes...
name_to_code = OrderedDict([('Right', 1), ('Left', 2), ('Rest', 3), ('Feet', 4)])
segment_ival_ms = [-500,4000]
train_set = create_signal_target_from_raw_mne(cnt, name_to_code, segment_ival_ms)
2017-11-03 17:56:34,795 INFO : Trial per class:
Counter({'Feet': 225, 'Right': 224, 'Rest': 224, 'Left': 224})
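As a rough illustration of what the epoching step does, the sketch below cuts a synthetic chans x time array into trials around marker onsets. The helper segment_trials_sketch, the event list, and the rule of skipping trials that do not fit into the recording are assumptions made for this example; the actual behaviour of create_signal_target_from_raw_mne may differ in its details.

```python
import numpy as np
from collections import OrderedDict


def segment_trials_sketch(data, events, name_to_code, segment_ival_ms, fs):
    """Cut chans x time data into trials around markers (hypothetical helper).

    events: list of (sample_index, marker_code) pairs.
    Returns X with shape trials x chans x time and integer labels y.
    """
    # Map every marker code to the index of its class; a class may
    # have a single code or a list of codes.
    code_to_label = {}
    for label, codes in enumerate(name_to_code.values()):
        for code in (codes if isinstance(codes, (list, tuple)) else [codes]):
            code_to_label[code] = label
    # Convert the interval from milliseconds to samples
    start_offset = int(round(segment_ival_ms[0] * fs / 1000.0))
    stop_offset = int(round(segment_ival_ms[1] * fs / 1000.0))
    X, y = [], []
    for i_sample, code in events:
        start, stop = i_sample + start_offset, i_sample + stop_offset
        if code not in code_to_label or start < 0 or stop > data.shape[1]:
            continue  # skip unknown markers and trials that do not fit
        X.append(data[:, start:stop])
        y.append(code_to_label[code])
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.int64)


name_to_code = OrderedDict([('Right', 1), ('Left', 2), ('Rest', 3), ('Feet', 4)])
data = np.random.RandomState(0).randn(3, 5000)        # synthetic signal
events = [(500, 1), (2000, 4), (3500, 2), (4900, 3)]  # last trial does not fit
X, y = segment_trials_sketch(data, events, name_to_code, [-500, 4000], fs=250)
print(X.shape, y)
```

With segment_ival_ms = [-500, 4000] at 250 Hz, each trial spans 125 samples before and 1000 samples after the marker, so 1125 samples in total under the conversion assumed here.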
Same for test set¶
In [6]:
test_filename = '/home/schirrmr/data/BBCI-only-last-runs/BhNoMoSc1S001R13_ds10_1-2BBCI.mat'
cnt = BBCIDataset(test_filename, load_sensor_names=['C3', 'CPz', 'C4']).load()
# Remove stimulus channel
cnt = cnt.drop_channels(['STI 014'])
cnt = resample_cnt(cnt, 250)
cnt = mne_apply(lambda a: exponential_running_standardize(
    a.T, init_block_size=1000, factor_new=0.001, eps=1e-4).T,
    cnt)
test_set = create_signal_target_from_raw_mne(cnt, name_to_code, segment_ival_ms)
Creating RawArray with float64 data, n_channels=3, n_times=617090
Range : 0 ... 617089 = 0.000 ... 1234.178 secs
Ready.
2017-11-03 17:56:35,707 WARNING : This is not causal, uses future data....
2017-11-03 17:56:35,708 INFO : Resampling from 500.000000 to 250.000000 Hz.
Creating RawArray with float64 data, n_channels=3, n_times=308545
Range : 0 ... 308544 = 0.000 ... 1234.176 secs
Ready.
2017-11-03 17:56:36,026 INFO : Trial per class:
Counter({'Feet': 40, 'Left': 40, 'Rest': 40, 'Right': 40})
Note: if your data has start and stop markers, pass a name_to_stop_codes
dictionary (same as for the start codes in this example) as a final
argument to create_signal_target_from_raw_mne. See the Read and Decode
BBCI Data with Start-Stop-Markers Tutorial.
Split off a validation set.
In [7]:
from braindecode.datautil.splitters import split_into_two_sets
train_set, valid_set = split_into_two_sets(train_set, first_set_fraction=0.8)
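The following sketch shows the arithmetic of such a split on trial indices only. It assumes that the first 80% of the trials, in their recorded order and without shuffling, form the new training set and that the split point is rounded to the nearest integer; split_first_fraction is a hypothetical helper, not part of Braindecode.

```python
def split_first_fraction(n_trials, first_set_fraction):
    """Return index lists for a chronological split (hypothetical helper)."""
    n_first = int(round(n_trials * first_set_fraction))
    indices = list(range(n_trials))
    return indices[:n_first], indices[n_first:]


# Trial counts per class as logged for the training file above
n_train_trials = 225 + 224 + 224 + 224
train_inds, valid_inds = split_first_fraction(n_train_trials, 0.8)
print(len(train_inds), len(valid_inds))
```

Under these assumptions, the 897 trials are divided into 718 training and 179 validation trials. This matches the valid_misclass of 0.748603 logged for epoch 0 of the run below, which corresponds to 134 of 179 trials.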
Create the model¶
In [8]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model
# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = True
set_random_seeds(seed=20170629, cuda=cuda)
# This will determine how many crops are processed in parallel
input_time_length = 800
in_chans = 3
n_classes = 4
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,
                        final_conv_length=30).create_network()
to_dense_prediction_model(model)
if cuda:
    model.cuda()
Set up optimizer and iterator¶
In [9]:
from torch import optim
import numpy as np
optimizer = optim.Adam(model.parameters())
from braindecode.torch_ext.util import np_to_var
# determine output size
test_input = np_to_var(np.ones((2, 3, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))
from braindecode.datautil.iterators import CropsFromTrialsIterator
iterator = CropsFromTrialsIterator(batch_size=32, input_time_length=input_time_length,
                                   n_preds_per_input=n_preds_per_input)
267 predictions per input/trial
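To see how 800-sample inputs with 267 predictions each can cover a whole trial, the sketch below computes one possible set of crop positions. The helper crop_start_stops and its placement rule (advance by n_preds_per_input, then align the last crop with the trial end) are assumptions for illustration; CropsFromTrialsIterator may place its crops differently. The trial length of 1125 samples follows from the [-500, 4000] ms interval at 250 Hz.

```python
def crop_start_stops(trial_len, input_time_length, n_preds_per_input):
    """Return (start, stop) sample indices of crops covering one trial.

    Hypothetical helper: each crop yields predictions for its last
    n_preds_per_input samples, so consecutive crops are shifted by
    that amount and the final crop is aligned with the trial end.
    """
    if input_time_length > trial_len:
        raise ValueError("input_time_length must not exceed the trial length")
    crops = []
    stop = input_time_length
    while stop < trial_len:
        crops.append((stop - input_time_length, stop))
        stop += n_preds_per_input
    # the last crop ends exactly at the end of the trial
    crops.append((trial_len - input_time_length, trial_len))
    return crops


input_time_length = 800
n_preds_per_input = 267
# samples a single prediction depends on, if predictions are one sample apart
receptive_field = input_time_length - n_preds_per_input + 1
crops = crop_start_stops(1125, input_time_length, n_preds_per_input)
print(receptive_field, crops)
```

In this sketch three overlapping crops are enough for one trial, and the last two crops produce some predictions for the same samples.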
Set up Monitors, Loss function, Stop Criteria¶
In [10]:
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, CroppedTrialMisclassMonitor, MisclassMonitor
from braindecode.experiments.stopcriteria import MaxEpochs
import torch.nn.functional as F
import torch as th
from braindecode.torch_ext.modules import Expression
loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2).squeeze(), targets)
model_constraint = None
monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
            CroppedTrialMisclassMonitor(input_time_length), RuntimeMonitor(),]
stop_criterion = MaxEpochs(20)
exp = Experiment(model, train_set, valid_set, test_set, iterator, loss_function, optimizer, model_constraint,
                 monitors, stop_criterion, remember_best_column='valid_misclass',
                 run_after_early_stop=True, batch_modifier=None, cuda=cuda)
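The loss function in the cell above averages the predictions of each input over the prediction axis before computing the negative log-likelihood. A minimal NumPy sketch of that computation follows; it assumes the predictions are log-probabilities of shape batch x classes x n_preds and drops the trailing singleton dimension that the network output appears to have. The name cropped_nll_sketch is hypothetical.

```python
import numpy as np


def cropped_nll_sketch(log_preds, targets):
    """NLL of log-probabilities averaged over the prediction axis.

    log_preds: batch x classes x n_preds array of log-probabilities.
    targets: integer class labels, one per batch entry.
    """
    mean_log_preds = log_preds.mean(axis=2)  # batch x classes
    # pick the averaged log-probability of the true class for each input
    picked = mean_log_preds[np.arange(len(targets)), targets]
    return -picked.mean()


# Two inputs, four classes, three predictions each, all uniform
log_preds = np.full((2, 4, 3), np.log(0.25))
targets = np.array([0, 3])
loss = cropped_nll_sketch(log_preds, targets)
print(loss)
```

For uniform predictions over four classes the result is log(4), which is the loss expected from a model that has not learned anything yet.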
Run experiment¶
In [11]:
exp.run()
2017-11-03 17:56:39,458 INFO : Run until first stop...
2017-11-03 17:56:40,298 INFO : Epoch 0
2017-11-03 17:56:40,299 INFO : train_loss 7.89184
2017-11-03 17:56:40,300 INFO : valid_loss 7.72731
2017-11-03 17:56:40,301 INFO : test_loss 7.75617
2017-11-03 17:56:40,303 INFO : train_sample_misclass 0.75013
2017-11-03 17:56:40,304 INFO : valid_sample_misclass 0.74856
2017-11-03 17:56:40,305 INFO : test_sample_misclass 0.75115
2017-11-03 17:56:40,306 INFO : train_misclass 0.75070
2017-11-03 17:56:40,308 INFO : valid_misclass 0.74860
2017-11-03 17:56:40,309 INFO : test_misclass 0.75000
2017-11-03 17:56:40,310 INFO : runtime 0.00000
2017-11-03 17:56:40,311 INFO :
2017-11-03 17:56:40,313 INFO : New best valid_misclass: 0.748603
2017-11-03 17:56:40,314 INFO :
2017-11-03 17:56:41,475 INFO : Time only for training updates: 0.94s
2017-11-03 17:56:42,262 INFO : Epoch 1
2017-11-03 17:56:42,263 INFO : train_loss 0.79550
2017-11-03 17:56:42,264 INFO : valid_loss 0.79273
2017-11-03 17:56:42,265 INFO : test_loss 0.83673
2017-11-03 17:56:42,266 INFO : train_sample_misclass 0.36961
2017-11-03 17:56:42,267 INFO : valid_sample_misclass 0.37374
2017-11-03 17:56:42,268 INFO : test_sample_misclass 0.42969
2017-11-03 17:56:42,268 INFO : train_misclass 0.30223
2017-11-03 17:56:42,269 INFO : valid_misclass 0.26257
2017-11-03 17:56:42,270 INFO : test_misclass 0.33750
2017-11-03 17:56:42,271 INFO : runtime 2.01650
2017-11-03 17:56:42,271 INFO :
2017-11-03 17:56:42,274 INFO : New best valid_misclass: 0.262570
2017-11-03 17:56:42,274 INFO :
2017-11-03 17:56:43,423 INFO : Time only for training updates: 0.92s
2017-11-03 17:56:44,209 INFO : Epoch 2
2017-11-03 17:56:44,211 INFO : train_loss 0.68180
2017-11-03 17:56:44,212 INFO : valid_loss 0.65888
2017-11-03 17:56:44,213 INFO : test_loss 0.74155
2017-11-03 17:56:44,214 INFO : train_sample_misclass 0.29325
2017-11-03 17:56:44,215 INFO : valid_sample_misclass 0.30236
2017-11-03 17:56:44,216 INFO : test_sample_misclass 0.37317
2017-11-03 17:56:44,217 INFO : train_misclass 0.22563
2017-11-03 17:56:44,218 INFO : valid_misclass 0.21229
2017-11-03 17:56:44,219 INFO : test_misclass 0.29375
2017-11-03 17:56:44,220 INFO : runtime 1.94802
2017-11-03 17:56:44,221 INFO :
2017-11-03 17:56:44,223 INFO : New best valid_misclass: 0.212291
2017-11-03 17:56:44,224 INFO :
2017-11-03 17:56:45,377 INFO : Time only for training updates: 0.92s
2017-11-03 17:56:46,152 INFO : Epoch 3
2017-11-03 17:56:46,153 INFO : train_loss 0.65244
2017-11-03 17:56:46,154 INFO : valid_loss 0.68508
2017-11-03 17:56:46,155 INFO : test_loss 0.81607
2017-11-03 17:56:46,156 INFO : train_sample_misclass 0.27737
2017-11-03 17:56:46,157 INFO : valid_sample_misclass 0.32108
2017-11-03 17:56:46,157 INFO : test_sample_misclass 0.43695
2017-11-03 17:56:46,158 INFO : train_misclass 0.20195
2017-11-03 17:56:46,159 INFO : valid_misclass 0.25140
2017-11-03 17:56:46,160 INFO : test_misclass 0.36250
2017-11-03 17:56:46,160 INFO : runtime 1.95389
2017-11-03 17:56:46,161 INFO :
2017-11-03 17:56:47,319 INFO : Time only for training updates: 0.92s
2017-11-03 17:56:48,106 INFO : Epoch 4
2017-11-03 17:56:48,107 INFO : train_loss 0.59123
2017-11-03 17:56:48,108 INFO : valid_loss 0.57887
2017-11-03 17:56:48,109 INFO : test_loss 0.71750
2017-11-03 17:56:48,110 INFO : train_sample_misclass 0.25470
2017-11-03 17:56:48,110 INFO : valid_sample_misclass 0.25364
2017-11-03 17:56:48,111 INFO : test_sample_misclass 0.35189
2017-11-03 17:56:48,112 INFO : train_misclass 0.20474
2017-11-03 17:56:48,113 INFO : valid_misclass 0.20670
2017-11-03 17:56:48,113 INFO : test_misclass 0.30625
2017-11-03 17:56:48,114 INFO : runtime 1.94214
2017-11-03 17:56:48,115 INFO :
2017-11-03 17:56:48,117 INFO : New best valid_misclass: 0.206704
2017-11-03 17:56:48,118 INFO :
2017-11-03 17:56:49,274 INFO : Time only for training updates: 0.92s
2017-11-03 17:56:50,058 INFO : Epoch 5
2017-11-03 17:56:50,060 INFO : train_loss 0.76997
2017-11-03 17:56:50,060 INFO : valid_loss 0.72850
2017-11-03 17:56:50,061 INFO : test_loss 0.83681
2017-11-03 17:56:50,062 INFO : train_sample_misclass 0.33041
2017-11-03 17:56:50,063 INFO : valid_sample_misclass 0.32293
2017-11-03 17:56:50,063 INFO : test_sample_misclass 0.39756
2017-11-03 17:56:50,064 INFO : train_misclass 0.25487
2017-11-03 17:56:50,065 INFO : valid_misclass 0.29050
2017-11-03 17:56:50,066 INFO : test_misclass 0.34375
2017-11-03 17:56:50,066 INFO : runtime 1.95470
2017-11-03 17:56:50,067 INFO :
2017-11-03 17:56:51,213 INFO : Time only for training updates: 0.92s
2017-11-03 17:56:51,992 INFO : Epoch 6
2017-11-03 17:56:51,994 INFO : train_loss 0.61221
2017-11-03 17:56:51,995 INFO : valid_loss 0.62389
2017-11-03 17:56:51,996 INFO : test_loss 0.72417
2017-11-03 17:56:51,996 INFO : train_sample_misclass 0.26388
2017-11-03 17:56:51,997 INFO : valid_sample_misclass 0.28059
2017-11-03 17:56:51,998 INFO : test_sample_misclass 0.34581
2017-11-03 17:56:51,999 INFO : train_misclass 0.21031
2017-11-03 17:56:51,999 INFO : valid_misclass 0.22346
2017-11-03 17:56:52,000 INFO : test_misclass 0.30625
2017-11-03 17:56:52,001 INFO : runtime 1.93918
2017-11-03 17:56:52,002 INFO :
2017-11-03 17:56:53,152 INFO : Time only for training updates: 0.92s
2017-11-03 17:56:53,932 INFO : Epoch 7
2017-11-03 17:56:53,933 INFO : train_loss 0.57982
2017-11-03 17:56:53,934 INFO : valid_loss 0.60830
2017-11-03 17:56:53,936 INFO : test_loss 0.73615
2017-11-03 17:56:53,937 INFO : train_sample_misclass 0.25238
2017-11-03 17:56:53,938 INFO : valid_sample_misclass 0.30217
2017-11-03 17:56:53,939 INFO : test_sample_misclass 0.36385
2017-11-03 17:56:53,941 INFO : train_misclass 0.17827
2017-11-03 17:56:53,942 INFO : valid_misclass 0.21788
2017-11-03 17:56:53,943 INFO : test_misclass 0.31875
2017-11-03 17:56:53,944 INFO : runtime 1.93903
2017-11-03 17:56:53,946 INFO :
2017-11-03 17:56:55,088 INFO : Time only for training updates: 0.92s
2017-11-03 17:56:55,866 INFO : Epoch 8
2017-11-03 17:56:55,867 INFO : train_loss 0.53394
2017-11-03 17:56:55,868 INFO : valid_loss 0.54075
2017-11-03 17:56:55,868 INFO : test_loss 0.69350
2017-11-03 17:56:55,869 INFO : train_sample_misclass 0.22455
2017-11-03 17:56:55,870 INFO : valid_sample_misclass 0.24399
2017-11-03 17:56:55,871 INFO : test_sample_misclass 0.35034
2017-11-03 17:56:55,872 INFO : train_misclass 0.15460
2017-11-03 17:56:55,872 INFO : valid_misclass 0.14525
2017-11-03 17:56:55,873 INFO : test_misclass 0.21250
2017-11-03 17:56:55,874 INFO : runtime 1.93581
2017-11-03 17:56:55,875 INFO :
2017-11-03 17:56:55,877 INFO : New best valid_misclass: 0.145251
2017-11-03 17:56:55,878 INFO :
2017-11-03 17:56:57,025 INFO : Time only for training updates: 0.92s
2017-11-03 17:56:57,808 INFO : Epoch 9
2017-11-03 17:56:57,809 INFO : train_loss 0.52253
2017-11-03 17:56:57,810 INFO : valid_loss 0.55157
2017-11-03 17:56:57,811 INFO : test_loss 0.66259
2017-11-03 17:56:57,812 INFO : train_sample_misclass 0.20662
2017-11-03 17:56:57,813 INFO : valid_sample_misclass 0.23934
2017-11-03 17:56:57,813 INFO : test_sample_misclass 0.31630
2017-11-03 17:56:57,814 INFO : train_misclass 0.12535
2017-11-03 17:56:57,815 INFO : valid_misclass 0.15084
2017-11-03 17:56:57,816 INFO : test_misclass 0.19375
2017-11-03 17:56:57,816 INFO : runtime 1.93780
2017-11-03 17:56:57,817 INFO :
2017-11-03 17:56:58,966 INFO : Time only for training updates: 0.92s
2017-11-03 17:56:59,747 INFO : Epoch 10
2017-11-03 17:56:59,749 INFO : train_loss 0.63435
2017-11-03 17:56:59,749 INFO : valid_loss 0.54363
2017-11-03 17:56:59,750 INFO : test_loss 0.66703
2017-11-03 17:56:59,751 INFO : train_sample_misclass 0.23284
2017-11-03 17:56:59,752 INFO : valid_sample_misclass 0.25216
2017-11-03 17:56:59,752 INFO : test_sample_misclass 0.31337
2017-11-03 17:56:59,753 INFO : train_misclass 0.19081
2017-11-03 17:56:59,754 INFO : valid_misclass 0.16760
2017-11-03 17:56:59,755 INFO : test_misclass 0.21875
2017-11-03 17:56:59,755 INFO : runtime 1.94030
2017-11-03 17:56:59,756 INFO :
2017-11-03 17:57:00,903 INFO : Time only for training updates: 0.92s
2017-11-03 17:57:01,683 INFO : Epoch 11
2017-11-03 17:57:01,684 INFO : train_loss 0.50823
2017-11-03 17:57:01,685 INFO : valid_loss 0.57940
2017-11-03 17:57:01,686 INFO : test_loss 0.65863
2017-11-03 17:57:01,686 INFO : train_sample_misclass 0.21697
2017-11-03 17:57:01,687 INFO : valid_sample_misclass 0.26513
2017-11-03 17:57:01,688 INFO : test_sample_misclass 0.33127
2017-11-03 17:57:01,689 INFO : train_misclass 0.13510
2017-11-03 17:57:01,689 INFO : valid_misclass 0.18994
2017-11-03 17:57:01,690 INFO : test_misclass 0.24375
2017-11-03 17:57:01,691 INFO : runtime 1.93741
2017-11-03 17:57:01,692 INFO :
2017-11-03 17:57:02,839 INFO : Time only for training updates: 0.92s
2017-11-03 17:57:03,619 INFO : Epoch 12
2017-11-03 17:57:03,621 INFO : train_loss 0.51300
2017-11-03 17:57:03,622 INFO : valid_loss 0.54631
2017-11-03 17:57:03,623 INFO : test_loss 0.64827
2017-11-03 17:57:03,624 INFO : train_sample_misclass 0.20628
2017-11-03 17:57:03,624 INFO : valid_sample_misclass 0.23820
2017-11-03 17:57:03,625 INFO : test_sample_misclass 0.32021
2017-11-03 17:57:03,626 INFO : train_misclass 0.13649
2017-11-03 17:57:03,627 INFO : valid_misclass 0.15642
2017-11-03 17:57:03,628 INFO : test_misclass 0.24375
2017-11-03 17:57:03,628 INFO : runtime 1.93566
2017-11-03 17:57:03,629 INFO :
2017-11-03 17:57:04,777 INFO : Time only for training updates: 0.92s
2017-11-03 17:57:05,568 INFO : Epoch 13
2017-11-03 17:57:05,569 INFO : train_loss 0.53439
2017-11-03 17:57:05,570 INFO : valid_loss 0.57364
2017-11-03 17:57:05,571 INFO : test_loss 0.75146
2017-11-03 17:57:05,572 INFO : train_sample_misclass 0.22174
2017-11-03 17:57:05,573 INFO : valid_sample_misclass 0.25399
2017-11-03 17:57:05,573 INFO : test_sample_misclass 0.36589
2017-11-03 17:57:05,574 INFO : train_misclass 0.14903
2017-11-03 17:57:05,575 INFO : valid_misclass 0.16201
2017-11-03 17:57:05,576 INFO : test_misclass 0.27500
2017-11-03 17:57:05,576 INFO : runtime 1.93826
2017-11-03 17:57:05,577 INFO :
2017-11-03 17:57:06,724 INFO : Time only for training updates: 0.92s
2017-11-03 17:57:07,510 INFO : Epoch 14
2017-11-03 17:57:07,511 INFO : train_loss 0.49450
2017-11-03 17:57:07,512 INFO : valid_loss 0.50904
2017-11-03 17:57:07,513 INFO : test_loss 0.59711
2017-11-03 17:57:07,514 INFO : train_sample_misclass 0.19409
2017-11-03 17:57:07,514 INFO : valid_sample_misclass 0.23406
2017-11-03 17:57:07,515 INFO : test_sample_misclass 0.29099
2017-11-03 17:57:07,516 INFO : train_misclass 0.11421
2017-11-03 17:57:07,517 INFO : valid_misclass 0.15084
2017-11-03 17:57:07,517 INFO : test_misclass 0.21875
2017-11-03 17:57:07,518 INFO : runtime 1.94703
2017-11-03 17:57:07,519 INFO :
2017-11-03 17:57:08,667 INFO : Time only for training updates: 0.92s
2017-11-03 17:57:09,447 INFO : Epoch 15
2017-11-03 17:57:09,448 INFO : train_loss 0.55814
2017-11-03 17:57:09,449 INFO : valid_loss 0.52381
2017-11-03 17:57:09,450 INFO : test_loss 0.63100
2017-11-03 17:57:09,450 INFO : train_sample_misclass 0.19992
2017-11-03 17:57:09,451 INFO : valid_sample_misclass 0.22669
2017-11-03 17:57:09,452 INFO : test_sample_misclass 0.30698
2017-11-03 17:57:09,453 INFO : train_misclass 0.14624
2017-11-03 17:57:09,453 INFO : valid_misclass 0.16201
2017-11-03 17:57:09,454 INFO : test_misclass 0.23750
2017-11-03 17:57:09,455 INFO : runtime 1.94305
2017-11-03 17:57:09,456 INFO :
2017-11-03 17:57:10,604 INFO : Time only for training updates: 0.92s
2017-11-03 17:57:11,388 INFO : Epoch 16
2017-11-03 17:57:11,389 INFO : train_loss 0.54798
2017-11-03 17:57:11,390 INFO : valid_loss 0.60483
2017-11-03 17:57:11,391 INFO : test_loss 0.70932
2017-11-03 17:57:11,392 INFO : train_sample_misclass 0.23531
2017-11-03 17:57:11,393 INFO : valid_sample_misclass 0.26090
2017-11-03 17:57:11,394 INFO : test_sample_misclass 0.35440
2017-11-03 17:57:11,396 INFO : train_misclass 0.19359
2017-11-03 17:57:11,397 INFO : valid_misclass 0.22905
2017-11-03 17:57:11,398 INFO : test_misclass 0.32500
2017-11-03 17:57:11,399 INFO : runtime 1.93673
2017-11-03 17:57:11,399 INFO :
2017-11-03 17:57:12,550 INFO : Time only for training updates: 0.92s
2017-11-03 17:57:13,342 INFO : Epoch 17
2017-11-03 17:57:13,343 INFO : train_loss 0.57735
2017-11-03 17:57:13,344 INFO : valid_loss 0.60624
2017-11-03 17:57:13,344 INFO : test_loss 0.71293
2017-11-03 17:57:13,345 INFO : train_sample_misclass 0.23263
2017-11-03 17:57:13,346 INFO : valid_sample_misclass 0.25866
2017-11-03 17:57:13,347 INFO : test_sample_misclass 0.35750
2017-11-03 17:57:13,347 INFO : train_misclass 0.18524
2017-11-03 17:57:13,348 INFO : valid_misclass 0.19553
2017-11-03 17:57:13,349 INFO : test_misclass 0.29375
2017-11-03 17:57:13,350 INFO : runtime 1.94563
2017-11-03 17:57:13,350 INFO :
2017-11-03 17:57:14,500 INFO : Time only for training updates: 0.92s
2017-11-03 17:57:15,288 INFO : Epoch 18
2017-11-03 17:57:15,289 INFO : train_loss 0.46534
2017-11-03 17:57:15,290 INFO : valid_loss 0.52169
2017-11-03 17:57:15,291 INFO : test_loss 0.64741
2017-11-03 17:57:15,292 INFO : train_sample_misclass 0.19295
2017-11-03 17:57:15,292 INFO : valid_sample_misclass 0.23086
2017-11-03 17:57:15,293 INFO : test_sample_misclass 0.33224
2017-11-03 17:57:15,294 INFO : train_misclass 0.11978
2017-11-03 17:57:15,295 INFO : valid_misclass 0.18436
2017-11-03 17:57:15,295 INFO : test_misclass 0.24375
2017-11-03 17:57:15,296 INFO : runtime 1.95093
2017-11-03 17:57:15,297 INFO :
2017-11-03 17:57:16,445 INFO : Time only for training updates: 0.92s
2017-11-03 17:57:17,230 INFO : Epoch 19
2017-11-03 17:57:17,232 INFO : train_loss 0.47752
2017-11-03 17:57:17,232 INFO : valid_loss 0.55626
2017-11-03 17:57:17,233 INFO : test_loss 0.68759
2017-11-03 17:57:17,234 INFO : train_sample_misclass 0.20194
2017-11-03 17:57:17,235 INFO : valid_sample_misclass 0.24273
2017-11-03 17:57:17,236 INFO : test_sample_misclass 0.35215
2017-11-03 17:57:17,236 INFO : train_misclass 0.13928
2017-11-03 17:57:17,237 INFO : valid_misclass 0.16201
2017-11-03 17:57:17,238 INFO : test_misclass 0.23750
2017-11-03 17:57:17,239 INFO : runtime 1.94504
2017-11-03 17:57:17,240 INFO :
2017-11-03 17:57:18,387 INFO : Time only for training updates: 0.92s
2017-11-03 17:57:19,174 INFO : Epoch 20
2017-11-03 17:57:19,175 INFO : train_loss 0.46220
2017-11-03 17:57:19,176 INFO : valid_loss 0.54947
2017-11-03 17:57:19,177 INFO : test_loss 0.62923
2017-11-03 17:57:19,178 INFO : train_sample_misclass 0.18707
2017-11-03 17:57:19,179 INFO : valid_sample_misclass 0.24050
2017-11-03 17:57:19,179 INFO : test_sample_misclass 0.29137
2017-11-03 17:57:19,180 INFO : train_misclass 0.10864
2017-11-03 17:57:19,181 INFO : valid_misclass 0.18436
2017-11-03 17:57:19,182 INFO : test_misclass 0.23750
2017-11-03 17:57:19,182 INFO : runtime 1.94147
2017-11-03 17:57:19,183 INFO :
2017-11-03 17:57:19,184 INFO : Setup for second stop...
2017-11-03 17:57:19,188 INFO : Train loss to reach 0.53394
2017-11-03 17:57:19,189 INFO : Run until second stop...
2017-11-03 17:57:20,151 INFO : Epoch 9
2017-11-03 17:57:20,153 INFO : train_loss 0.53530
2017-11-03 17:57:20,154 INFO : valid_loss 0.54075
2017-11-03 17:57:20,154 INFO : test_loss 0.69350
2017-11-03 17:57:20,155 INFO : train_sample_misclass 0.22843
2017-11-03 17:57:20,156 INFO : valid_sample_misclass 0.24399
2017-11-03 17:57:20,157 INFO : test_sample_misclass 0.35034
2017-11-03 17:57:20,157 INFO : train_misclass 0.15273
2017-11-03 17:57:20,158 INFO : valid_misclass 0.14525
2017-11-03 17:57:20,159 INFO : test_misclass 0.21250
2017-11-03 17:57:20,160 INFO : runtime 0.80763
2017-11-03 17:57:20,161 INFO :
2017-11-03 17:57:21,603 INFO : Time only for training updates: 1.16s
2017-11-03 17:57:22,533 INFO : Epoch 10
2017-11-03 17:57:22,534 INFO : train_loss 0.54183
2017-11-03 17:57:22,535 INFO : valid_loss 0.50682
2017-11-03 17:57:22,536 INFO : test_loss 0.64068
2017-11-03 17:57:22,537 INFO : train_sample_misclass 0.22782
2017-11-03 17:57:22,537 INFO : valid_sample_misclass 0.23420
2017-11-03 17:57:22,538 INFO : test_sample_misclass 0.31642
2017-11-03 17:57:22,539 INFO : train_misclass 0.18283
2017-11-03 17:57:22,540 INFO : valid_misclass 0.15642
2017-11-03 17:57:22,540 INFO : test_misclass 0.24375
2017-11-03 17:57:22,541 INFO : runtime 2.40818
2017-11-03 17:57:22,542 INFO :
We arrive at ca. 76% accuracy on the test set (final test_misclass of 0.24375).
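The log above can be read as two phases: a first run of 20 epochs, after which the epoch with the lowest valid_misclass is remembered, and a second run that appears to resume from that epoch and continue until the validation loss drops to the train loss of that epoch (0.53394). The sketch below only re-derives the best epoch and the final accuracy from the logged numbers; it does not call Braindecode.

```python
# valid_misclass values logged for epochs 0..20 of the first run
valid_misclass = [
    0.74860, 0.26257, 0.21229, 0.25140, 0.20670, 0.29050, 0.22346,
    0.21788, 0.14525, 0.15084, 0.16760, 0.18994, 0.15642, 0.16201,
    0.15084, 0.16201, 0.22905, 0.19553, 0.18436, 0.16201, 0.18436,
]

# index of the epoch with the lowest validation misclassification
best_epoch = min(range(len(valid_misclass)), key=valid_misclass.__getitem__)

# test_misclass logged for the last epoch of the second run
final_test_misclass = 0.24375
final_test_accuracy = 1.0 - final_test_misclass

print(best_epoch, valid_misclass[best_epoch], final_test_accuracy)
```

The best epoch found this way is the one whose train loss appears in the "Train loss to reach" line of the log.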
If you want to do trialwise decoding instead of cropped decoding, perform the following changes:
Change:
# This will determine how many crops are processed in parallel
input_time_length = 800
in_chans = 3
n_classes = 4
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,
                        final_conv_length=30).create_network()
to:
# For trialwise decoding, the input is the whole trial
input_time_length = train_set.X.shape[2]
in_chans = 3
n_classes = 4
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,
                        final_conv_length='auto').create_network()
Remove:
to_dense_prediction_model(model)
Remove:
from braindecode.torch_ext.util import np_to_var
# determine output size
test_input = np_to_var(np.ones((2, 3, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))
Change:
from braindecode.datautil.iterators import CropsFromTrialsIterator
iterator = CropsFromTrialsIterator(batch_size=32, input_time_length=input_time_length,
                                   n_preds_per_input=n_preds_per_input)
to:
from braindecode.datautil.iterators import BalancedBatchSizeIterator
iterator = BalancedBatchSizeIterator(batch_size=32)
Change:
loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2).squeeze(), targets)
to:
loss_function = F.nll_loss
Change:
monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
            CroppedTrialMisclassMonitor(input_time_length), RuntimeMonitor(),]
to:
monitors = [LossMonitor(), MisclassMonitor(col_suffix='misclass'),
            RuntimeMonitor(),]
Resulting code can be seen at BBCI Data Epoched.