hybrid model, eegnetv4

Esse commit está contido em:
Robin Tibor Schirrmeister
2018-07-19 17:18:15 +02:00
commit a3bac03277
8 arquivos alterados com 340 adições e 24 exclusões
+12 -6
Ver Arquivo
@@ -64,10 +64,13 @@ class BalancedBatchSizeIterator(object):
Resulting batches will not necessarily have the given batch size
but rather the next largest batch size that allows to split the set into
balanced batches (maximum size difference 1).
seed: int
Random seed for the the random generator that shuffles the batches.
"""
def __init__(self, batch_size):
def __init__(self, batch_size, seed=328774):
self.batch_size = batch_size
self.rng = RandomState(328774)
self.seed = seed
self.rng = RandomState(seed)
def get_batches(self, dataset, shuffle):
n_trials = dataset.X.shape[0]
@@ -85,7 +88,7 @@ class BalancedBatchSizeIterator(object):
yield (batch_X, batch_y)
def reset_rng(self):
self.rng = RandomState(328774)
self.rng = RandomState(self.seed)
class ClassBalancedBatchSizeIterator(object):
@@ -100,11 +103,14 @@ class ClassBalancedBatchSizeIterator(object):
Resulting batches will not necessarily have the given batch size
but rather the next largest batch size that allows to split the set into
balanced batches (maximum size difference 1).
seed: int
Random seed for the the random generator that shuffles the batches.
"""
def __init__(self, batch_size):
def __init__(self, batch_size, seed=328774):
self.batch_size = batch_size
self.rng = RandomState(328774)
self.seed = seed
self.rng = RandomState(seed)
def get_batches(self, dataset, shuffle):
n_trials = dataset.X.shape[0]
@@ -138,7 +144,7 @@ class ClassBalancedBatchSizeIterator(object):
yield (batch_X, batch_y)
def reset_rng(self):
self.rng = RandomState((4, 7, 2017))
self.rng = RandomState(self.seed)
class CropsFromTrialsIterator(object):
+15 -7
Ver Arquivo
@@ -137,6 +137,9 @@ class Experiment(object):
do_early_stop: bool
Whether to do an early stop at all. If true, reset to best model
even in case experiment does not run after early stop.
print_0_epoch: bool
Whether to compute monitor values and print them before the
start of training.
Attributes
----------
@@ -149,7 +152,8 @@ class Experiment(object):
run_after_early_stop,
model_loss_function=None,
batch_modifier=None, cuda=True, pin_memory=False,
do_early_stop=True):
do_early_stop=True,
print_0_epoch=True):
if run_after_early_stop:
assert do_early_stop == True, ("Can only run after early stop if "
"doing an early stop")
@@ -163,6 +167,9 @@ class Experiment(object):
self.datasets.pop('valid')
assert run_after_early_stop == False
assert do_early_stop == False
if test_set is None:
self.datasets.pop('test')
self.iterator = iterator
self.loss_function = loss_function
self.optimizer = optimizer
@@ -179,7 +186,7 @@ class Experiment(object):
self.rememberer = None
self.pin_memory = pin_memory
self.do_early_stop = do_early_stop
self.print_0_epoch = print_0_epoch
def run(self):
"""
@@ -248,11 +255,12 @@ class Experiment(object):
remember_best: bool
Whether to remember parameters at best epoch.
"""
self.monitor_epoch(datasets)
self.print_epoch()
if remember_best:
self.rememberer.remember_epoch(self.epochs_df, self.model,
self.optimizer)
if self.print_0_epoch:
self.monitor_epoch(datasets)
self.print_epoch()
if remember_best:
self.rememberer.remember_epoch(self.epochs_df, self.model,
self.optimizer)
self.iterator.reset_rng()
while not self.stop_criterion.should_stop(self.epochs_df):
+91
Ver Arquivo
@@ -0,0 +1,91 @@
import torch as th
from numpy.random import RandomState
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, RuntimeMonitor
from braindecode.experiments.stopcriteria import MaxEpochs
from braindecode.datautil.iterators import BalancedBatchSizeIterator
from braindecode.experiments.experiment import Experiment
from braindecode.datautil.signal_target import SignalAndTarget
def find_optimizer(optimizer_name):
optim_found = False
for name in th.optim.__dict__.keys():
if name.lower() == optimizer_name.lower():
optimizer = th.optim.__dict__[name]
optim_found = True
break
if not optim_found:
raise ValueError("Unknown optimizer {:s}".format(optimizer))
return \
optimizer
class BaseModel(object):
def __init__(self):
self.compiled = False
def compile(self, loss, optimizer, metrics, seed=0):
self.loss = loss
self.network = self.create_network()
if not hasattr(optimizer, 'step'):
optimizer_class = find_optimizer(optimizer)
optimizer = optimizer_class(self.network.parameters())
self.optimizer = optimizer
monitor_dict = {'acc': MisclassMonitor}
self.monitors = [LossMonitor()]
extra_monitors = [monitor_dict[m]() for m in metrics]
self.monitors += extra_monitors
self.monitors += [RuntimeMonitor()]
self.seed_rng = RandomState(seed)
self.compiled = True
def fit(self, train_X, train_y, epochs, batch_size, validation_data=None, model_constraint=None,
remember_best_column=None):
if not self.compiled:
raise ValueError("Compile the model first by calling model.compile(loss, optimizer, metrics")
iterator = BalancedBatchSizeIterator(batch_size=batch_size, seed=self.seed_rng.randint(0, 4294967295))
stop_criterion = MaxEpochs(epochs - 1)# -1 since we dont print 0 epoch, which matters for this stop criterion
train_set = SignalAndTarget(train_X, train_y)
if validation_data is not None:
valid_set = SignalAndTarget(validation_data[0], validation_data[1])
else:
valid_set = None
test_set = None
exp = Experiment(self.network, train_set, valid_set, test_set, iterator=iterator,
loss_function=self.loss, optimizer=self.optimizer,
model_constraint=model_constraint,
monitors=self.monitors,
stop_criterion=stop_criterion,
remember_best_column=remember_best_column,
run_after_early_stop=False, cuda=True, print_0_epoch=False,
do_early_stop=(remember_best_column is not None))
exp.run()
return exp
def evaluate(self, X,y, batch_size=32):
# Create a dummy experiment for the evaluation
iterator = BalancedBatchSizeIterator(batch_size=batch_size,
seed=0) # seed irrelevant
stop_criterion = MaxEpochs(0)
train_set = SignalAndTarget(X, y)
model_constraint = None
valid_set = None
test_set = None
exp = Experiment(self.network, train_set, valid_set, test_set,
iterator=iterator,
loss_function=self.loss, optimizer=self.optimizer,
model_constraint=model_constraint,
monitors=self.monitors,
stop_criterion=stop_criterion,
remember_best_column=None,
run_after_early_stop=False, cuda=True,
print_0_epoch=False,
do_early_stop=False)
exp.monitor_epoch({'train': train_set})
result_dict = dict([(key.replace('train_', ''), val)
for key, val in
dict(exp.epochs_df.iloc[0]).items()])
return result_dict
+3 -1
Ver Arquivo
@@ -2,12 +2,14 @@ import numpy as np
from torch import nn
from torch.nn import init
from torch.nn.functional import elu
from braindecode.models.base import BaseModel
from braindecode.torch_ext.modules import Expression, AvgPool2dWithConv
from braindecode.torch_ext.functions import identity
from braindecode.torch_ext.util import np_to_var
class Deep4Net(object):
class Deep4Net(BaseModel):
"""
Deep ConvNet model from [1]_.
+130 -9
Ver Arquivo
@@ -1,12 +1,141 @@
import numpy as np
import torch as th
from torch import nn
from torch.nn import init
from torch.nn.functional import elu
from braindecode.models.base import BaseModel
from braindecode.torch_ext.init import glorot_weight_zero_bias
from braindecode.torch_ext.modules import Expression
from braindecode.torch_ext.util import np_to_var
class Conv2dWithConstraint(nn.Conv2d):
def __init__(self, *args, max_norm=1, **kwargs):
self.max_norm = max_norm
super(Conv2dWithConstraint, self).__init__(*args, **kwargs)
def forward(self, x):
self.weight.data = th.renorm(self.weight.data, p=2, dim=0,
maxnorm=self.max_norm)
return super(Conv2dWithConstraint, self).forward(x)
class EEGNetv4(BaseModel):
"""
EEGNet v4 model from [EEGNet]_.
Notes
-----
This implementation is not guaranteed to be correct, has not been checked
by original authors, only reimplemented from the paper description.
References
----------
.. [EEGNet] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
S. M., Hung, C. P., & Lance, B. J. (2018).
EEGNet: A Compact Convolutional Network for EEG-based
Brain-Computer Interfaces.
arXiv preprint arXiv:1611.08024.
"""
def __init__(self, in_chans,
n_classes,
final_conv_length='auto',
input_time_length=None,
pool_mode='mean',
F1=8,
D=2,
F2=16,
kernel_length=64,
third_kernel_size=(8, 4),
drop_prob=0.25
):
if final_conv_length == 'auto':
assert input_time_length is not None
self.__dict__.update(locals())
del self.self
def create_network(self):
pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
model = nn.Sequential()
# b c 0 1
# now to b 1 0 c
model.add_module('dimshuffle', Expression(_transpose_to_b_1_c_0))
model.add_module('conv_temporal', nn.Conv2d(
1, self.F1, (1, self.kernel_length), stride=1, bias=False,
padding=(0, self.kernel_length // 2,)))
model.add_module('bnorm_temporal', nn.BatchNorm2d(
self.F1, momentum=0.01, affine=True, eps=1e-3), )
model.add_module('conv_spatial', Conv2dWithConstraint(
self.F1, self.F1 * self.D, (self.in_chans, 1), max_norm=1, stride=1, bias=False,
groups=self.F1,
padding=(0, 0)))
model.add_module('bnorm_1', nn.BatchNorm2d(
self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3), )
model.add_module('elu_1', Expression(elu))
model.add_module('pool_1', pool_class(
kernel_size=(1, 4), stride=(1, 4)))
model.add_module('drop_1', nn.Dropout(p=self.drop_prob))
# https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7
model.add_module('conv_separable_depth', nn.Conv2d(
self.F1 * self.D, self.F1 * self.D, (1, 16), stride=1, bias=False, groups=self.F1 * self.D,
padding=(0, 16 // 2)))
model.add_module('conv_separable_point', nn.Conv2d(
self.F1 * self.D, self.F2, (1, 1), stride=1, bias=False,
padding=(0, 0)))
model.add_module('bnorm_2', nn.BatchNorm2d(
self.F2, momentum=0.01, affine=True, eps=1e-3), )
model.add_module('elu_2', Expression(elu))
model.add_module('pool_2', pool_class(
kernel_size=(1, 8), stride=(1, 8)))
model.add_module('drop_2', nn.Dropout(p=self.drop_prob))
out = model(np_to_var(np.ones(
(1, self.in_chans, self.input_time_length, 1),
dtype=np.float32)))
n_out_virtual_chans = out.cpu().data.numpy().shape[2]
if self.final_conv_length == 'auto':
n_out_time = out.cpu().data.numpy().shape[3]
self.final_conv_length = n_out_time
model.add_module('conv_classifier', nn.Conv2d(
self.F2, self.n_classes,
(n_out_virtual_chans, self.final_conv_length,), bias=True))
model.add_module('softmax', nn.LogSoftmax())
# Transpose back to the the logic of braindecode,
# so time in third dimension (axis=2)
model.add_module('permute_back', Expression(_transpose_1_0))
model.add_module('squeeze', Expression(_squeeze_final_output))
glorot_weight_zero_bias(model)
return model
def _transpose_to_b_1_c_0(x):
return x.permute(0, 3, 1, 2)
def _transpose_1_0(x):
return x.permute(0, 1, 3, 2)
# remove empty dim at end and potentially remove empty time dim
# do not just use squeeze as we never want to remove first dim
def _squeeze_final_output(x):
assert x.size()[3] == 1
x = x[:, :, :, 0]
if x.size()[2] == 1:
x = x[:, :, 0]
return x
class EEGNet(object):
"""
@@ -102,14 +231,6 @@ class EEGNet(object):
# Transpose back to the the logic of braindecode,
# so time in third dimension (axis=2)
model.add_module('permute_2', Expression(lambda x: x.permute(0,1,3,2)))
# remove empty dim at end and potentially remove empty time dim
# do not just use squeeze as we never want to remove first dim
def squeeze_output(x):
assert x.size()[3] == 1
x = x[:,:,:,0]
if x.size()[2] == 1:
x = x[:,:,0]
return x
model.add_module('squeeze', Expression(squeeze_output))
model.add_module('squeeze', Expression(_squeeze_final_output))
glorot_weight_zero_bias(model)
return model
+71
Ver Arquivo
@@ -0,0 +1,71 @@
import torch as th
from torch import nn
from torch.nn import ConstantPad2d
from braindecode.models.base import BaseModel
from braindecode.models.deep4 import Deep4Net
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model
class HybridNet(nn.Module, BaseModel):
def __init__(self, n_chans, n_classes, input_time_length):
super(HybridNet, self).__init__()
deep_model = Deep4Net(n_chans, n_classes,
input_time_length=input_time_length,
final_conv_length=2).create_network()
shallow_model = ShallowFBCSPNet(n_chans, n_classes,
input_time_length=input_time_length,
filter_time_length=28,
final_conv_length=29,
).create_network()
reduced_deep_model = nn.Sequential()
for name, module in deep_model.named_children():
if name == 'conv_classifier':
new_conv_layer = nn.Conv2d(module.in_channels, 60,
kernel_size=module.kernel_size,
stride=module.stride)
reduced_deep_model.add_module('deep_final_conv', new_conv_layer)
break
reduced_deep_model.add_module(name, module)
reduced_shallow_model = nn.Sequential()
for name, module in shallow_model.named_children():
if name == 'conv_classifier':
new_conv_layer = nn.Conv2d(module.in_channels, 40,
kernel_size=module.kernel_size,
stride=module.stride)
reduced_shallow_model.add_module('shallow_final_conv',
new_conv_layer)
break
reduced_shallow_model.add_module(name, module)
to_dense_prediction_model(reduced_deep_model)
to_dense_prediction_model(reduced_shallow_model)
self.reduced_shallow_model = reduced_shallow_model
self.reduced_deep_model = reduced_deep_model
self.final_conv = nn.Conv2d(100, n_classes, kernel_size=(1, 1),
stride=1)
def create_network(self):
return self
def forward(self, x):
deep_out = self.reduced_deep_model(x)
shallow_out = self.reduced_shallow_model(x)
n_diff_deep_shallow = deep_out.size()[2] - shallow_out.size()[2]
if n_diff_deep_shallow < 0:
deep_out = ConstantPad2d((0, 0, -n_diff_deep_shallow, 0), 0)(
deep_out)
elif n_diff_deep_shallow > 0:
shallow_out = ConstantPad2d((0, 0, n_diff_deep_shallow, 0), 0)(
shallow_out)
merged_out = th.cat((deep_out, shallow_out), dim=1)
linear_out = self.final_conv(merged_out)
softmaxed = nn.LogSoftmax(dim=1)(linear_out)
squeezed = softmaxed.squeeze(3)
return squeezed
+2 -1
Ver Arquivo
@@ -2,12 +2,13 @@ import numpy as np
from torch import nn
from torch.nn import init
from braindecode.models.base import BaseModel
from braindecode.torch_ext.modules import Expression
from braindecode.torch_ext.functions import safe_log, square
from braindecode.torch_ext.util import np_to_var
class ShallowFBCSPNet(object):
class ShallowFBCSPNet(BaseModel):
"""
Shallow ConvNet model from [2]_.
+16
Ver Arquivo
@@ -9,6 +9,14 @@ braindecode\.visualization package
Submodules
----------
braindecode\.visualization\.input\_windows module
-------------------------------------------------
.. automodule:: braindecode.visualization.input_windows
:members:
:undoc-members:
:show-inheritance:
braindecode\.visualization\.perturbation module
-----------------------------------------------
@@ -25,4 +33,12 @@ braindecode\.visualization\.plot module
:undoc-members:
:show-inheritance:
braindecode\.visualization\.sinfit module
-----------------------------------------
.. automodule:: braindecode.visualization.sinfit
:members:
:undoc-members:
:show-inheritance: