hybrid model, eegnetv4
Esse commit está contido em:
@@ -64,10 +64,13 @@ class BalancedBatchSizeIterator(object):
|
||||
Resulting batches will not necessarily have the given batch size
|
||||
but rather the next largest batch size that allows to split the set into
|
||||
balanced batches (maximum size difference 1).
|
||||
seed: int
|
||||
Random seed for the the random generator that shuffles the batches.
|
||||
"""
|
||||
def __init__(self, batch_size):
|
||||
def __init__(self, batch_size, seed=328774):
|
||||
self.batch_size = batch_size
|
||||
self.rng = RandomState(328774)
|
||||
self.seed = seed
|
||||
self.rng = RandomState(seed)
|
||||
|
||||
def get_batches(self, dataset, shuffle):
|
||||
n_trials = dataset.X.shape[0]
|
||||
@@ -85,7 +88,7 @@ class BalancedBatchSizeIterator(object):
|
||||
yield (batch_X, batch_y)
|
||||
|
||||
def reset_rng(self):
|
||||
self.rng = RandomState(328774)
|
||||
self.rng = RandomState(self.seed)
|
||||
|
||||
|
||||
class ClassBalancedBatchSizeIterator(object):
|
||||
@@ -100,11 +103,14 @@ class ClassBalancedBatchSizeIterator(object):
|
||||
Resulting batches will not necessarily have the given batch size
|
||||
but rather the next largest batch size that allows to split the set into
|
||||
balanced batches (maximum size difference 1).
|
||||
seed: int
|
||||
Random seed for the the random generator that shuffles the batches.
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size):
|
||||
def __init__(self, batch_size, seed=328774):
|
||||
self.batch_size = batch_size
|
||||
self.rng = RandomState(328774)
|
||||
self.seed = seed
|
||||
self.rng = RandomState(seed)
|
||||
|
||||
def get_batches(self, dataset, shuffle):
|
||||
n_trials = dataset.X.shape[0]
|
||||
@@ -138,7 +144,7 @@ class ClassBalancedBatchSizeIterator(object):
|
||||
yield (batch_X, batch_y)
|
||||
|
||||
def reset_rng(self):
|
||||
self.rng = RandomState((4, 7, 2017))
|
||||
self.rng = RandomState(self.seed)
|
||||
|
||||
|
||||
class CropsFromTrialsIterator(object):
|
||||
|
||||
@@ -137,6 +137,9 @@ class Experiment(object):
|
||||
do_early_stop: bool
|
||||
Whether to do an early stop at all. If true, reset to best model
|
||||
even in case experiment does not run after early stop.
|
||||
print_0_epoch: bool
|
||||
Whether to compute monitor values and print them before the
|
||||
start of training.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -149,7 +152,8 @@ class Experiment(object):
|
||||
run_after_early_stop,
|
||||
model_loss_function=None,
|
||||
batch_modifier=None, cuda=True, pin_memory=False,
|
||||
do_early_stop=True):
|
||||
do_early_stop=True,
|
||||
print_0_epoch=True):
|
||||
if run_after_early_stop:
|
||||
assert do_early_stop == True, ("Can only run after early stop if "
|
||||
"doing an early stop")
|
||||
@@ -163,6 +167,9 @@ class Experiment(object):
|
||||
self.datasets.pop('valid')
|
||||
assert run_after_early_stop == False
|
||||
assert do_early_stop == False
|
||||
if test_set is None:
|
||||
self.datasets.pop('test')
|
||||
|
||||
self.iterator = iterator
|
||||
self.loss_function = loss_function
|
||||
self.optimizer = optimizer
|
||||
@@ -179,7 +186,7 @@ class Experiment(object):
|
||||
self.rememberer = None
|
||||
self.pin_memory = pin_memory
|
||||
self.do_early_stop = do_early_stop
|
||||
|
||||
self.print_0_epoch = print_0_epoch
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
@@ -248,11 +255,12 @@ class Experiment(object):
|
||||
remember_best: bool
|
||||
Whether to remember parameters at best epoch.
|
||||
"""
|
||||
self.monitor_epoch(datasets)
|
||||
self.print_epoch()
|
||||
if remember_best:
|
||||
self.rememberer.remember_epoch(self.epochs_df, self.model,
|
||||
self.optimizer)
|
||||
if self.print_0_epoch:
|
||||
self.monitor_epoch(datasets)
|
||||
self.print_epoch()
|
||||
if remember_best:
|
||||
self.rememberer.remember_epoch(self.epochs_df, self.model,
|
||||
self.optimizer)
|
||||
|
||||
self.iterator.reset_rng()
|
||||
while not self.stop_criterion.should_stop(self.epochs_df):
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
import torch as th
|
||||
from numpy.random import RandomState
|
||||
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, RuntimeMonitor
|
||||
from braindecode.experiments.stopcriteria import MaxEpochs
|
||||
from braindecode.datautil.iterators import BalancedBatchSizeIterator
|
||||
from braindecode.experiments.experiment import Experiment
|
||||
from braindecode.datautil.signal_target import SignalAndTarget
|
||||
|
||||
|
||||
def find_optimizer(optimizer_name):
|
||||
optim_found = False
|
||||
for name in th.optim.__dict__.keys():
|
||||
if name.lower() == optimizer_name.lower():
|
||||
optimizer = th.optim.__dict__[name]
|
||||
optim_found = True
|
||||
break
|
||||
if not optim_found:
|
||||
raise ValueError("Unknown optimizer {:s}".format(optimizer))
|
||||
return \
|
||||
optimizer
|
||||
|
||||
|
||||
class BaseModel(object):
|
||||
def __init__(self):
|
||||
self.compiled = False
|
||||
|
||||
def compile(self, loss, optimizer, metrics, seed=0):
|
||||
self.loss = loss
|
||||
self.network = self.create_network()
|
||||
if not hasattr(optimizer, 'step'):
|
||||
optimizer_class = find_optimizer(optimizer)
|
||||
optimizer = optimizer_class(self.network.parameters())
|
||||
self.optimizer = optimizer
|
||||
monitor_dict = {'acc': MisclassMonitor}
|
||||
self.monitors = [LossMonitor()]
|
||||
extra_monitors = [monitor_dict[m]() for m in metrics]
|
||||
self.monitors += extra_monitors
|
||||
self.monitors += [RuntimeMonitor()]
|
||||
self.seed_rng = RandomState(seed)
|
||||
self.compiled = True
|
||||
|
||||
def fit(self, train_X, train_y, epochs, batch_size, validation_data=None, model_constraint=None,
|
||||
remember_best_column=None):
|
||||
if not self.compiled:
|
||||
raise ValueError("Compile the model first by calling model.compile(loss, optimizer, metrics")
|
||||
iterator = BalancedBatchSizeIterator(batch_size=batch_size, seed=self.seed_rng.randint(0, 4294967295))
|
||||
stop_criterion = MaxEpochs(epochs - 1)# -1 since we dont print 0 epoch, which matters for this stop criterion
|
||||
train_set = SignalAndTarget(train_X, train_y)
|
||||
if validation_data is not None:
|
||||
valid_set = SignalAndTarget(validation_data[0], validation_data[1])
|
||||
else:
|
||||
valid_set = None
|
||||
test_set = None
|
||||
exp = Experiment(self.network, train_set, valid_set, test_set, iterator=iterator,
|
||||
loss_function=self.loss, optimizer=self.optimizer,
|
||||
model_constraint=model_constraint,
|
||||
monitors=self.monitors,
|
||||
stop_criterion=stop_criterion,
|
||||
remember_best_column=remember_best_column,
|
||||
run_after_early_stop=False, cuda=True, print_0_epoch=False,
|
||||
do_early_stop=(remember_best_column is not None))
|
||||
exp.run()
|
||||
return exp
|
||||
|
||||
def evaluate(self, X,y, batch_size=32):
|
||||
# Create a dummy experiment for the evaluation
|
||||
iterator = BalancedBatchSizeIterator(batch_size=batch_size,
|
||||
seed=0) # seed irrelevant
|
||||
stop_criterion = MaxEpochs(0)
|
||||
train_set = SignalAndTarget(X, y)
|
||||
model_constraint = None
|
||||
valid_set = None
|
||||
test_set = None
|
||||
|
||||
exp = Experiment(self.network, train_set, valid_set, test_set,
|
||||
iterator=iterator,
|
||||
loss_function=self.loss, optimizer=self.optimizer,
|
||||
model_constraint=model_constraint,
|
||||
monitors=self.monitors,
|
||||
stop_criterion=stop_criterion,
|
||||
remember_best_column=None,
|
||||
run_after_early_stop=False, cuda=True,
|
||||
print_0_epoch=False,
|
||||
do_early_stop=False)
|
||||
|
||||
exp.monitor_epoch({'train': train_set})
|
||||
|
||||
result_dict = dict([(key.replace('train_', ''), val)
|
||||
for key, val in
|
||||
dict(exp.epochs_df.iloc[0]).items()])
|
||||
return result_dict
|
||||
@@ -2,12 +2,14 @@ import numpy as np
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn.functional import elu
|
||||
|
||||
from braindecode.models.base import BaseModel
|
||||
from braindecode.torch_ext.modules import Expression, AvgPool2dWithConv
|
||||
from braindecode.torch_ext.functions import identity
|
||||
from braindecode.torch_ext.util import np_to_var
|
||||
|
||||
|
||||
class Deep4Net(object):
|
||||
class Deep4Net(BaseModel):
|
||||
"""
|
||||
Deep ConvNet model from [1]_.
|
||||
|
||||
|
||||
@@ -1,12 +1,141 @@
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn.functional import elu
|
||||
|
||||
from braindecode.models.base import BaseModel
|
||||
from braindecode.torch_ext.init import glorot_weight_zero_bias
|
||||
from braindecode.torch_ext.modules import Expression
|
||||
from braindecode.torch_ext.util import np_to_var
|
||||
|
||||
class Conv2dWithConstraint(nn.Conv2d):
|
||||
def __init__(self, *args, max_norm=1, **kwargs):
|
||||
self.max_norm = max_norm
|
||||
super(Conv2dWithConstraint, self).__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
self.weight.data = th.renorm(self.weight.data, p=2, dim=0,
|
||||
maxnorm=self.max_norm)
|
||||
return super(Conv2dWithConstraint, self).forward(x)
|
||||
|
||||
class EEGNetv4(BaseModel):
|
||||
"""
|
||||
EEGNet v4 model from [EEGNet]_.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This implementation is not guaranteed to be correct, has not been checked
|
||||
by original authors, only reimplemented from the paper description.
|
||||
|
||||
References
|
||||
----------
|
||||
|
||||
.. [EEGNet] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
|
||||
S. M., Hung, C. P., & Lance, B. J. (2018).
|
||||
EEGNet: A Compact Convolutional Network for EEG-based
|
||||
Brain-Computer Interfaces.
|
||||
arXiv preprint arXiv:1611.08024.
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans,
|
||||
n_classes,
|
||||
final_conv_length='auto',
|
||||
input_time_length=None,
|
||||
pool_mode='mean',
|
||||
F1=8,
|
||||
D=2,
|
||||
F2=16,
|
||||
kernel_length=64,
|
||||
third_kernel_size=(8, 4),
|
||||
drop_prob=0.25
|
||||
):
|
||||
|
||||
if final_conv_length == 'auto':
|
||||
assert input_time_length is not None
|
||||
self.__dict__.update(locals())
|
||||
del self.self
|
||||
|
||||
def create_network(self):
|
||||
pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode]
|
||||
model = nn.Sequential()
|
||||
# b c 0 1
|
||||
# now to b 1 0 c
|
||||
model.add_module('dimshuffle', Expression(_transpose_to_b_1_c_0))
|
||||
|
||||
model.add_module('conv_temporal', nn.Conv2d(
|
||||
1, self.F1, (1, self.kernel_length), stride=1, bias=False,
|
||||
padding=(0, self.kernel_length // 2,)))
|
||||
model.add_module('bnorm_temporal', nn.BatchNorm2d(
|
||||
self.F1, momentum=0.01, affine=True, eps=1e-3), )
|
||||
model.add_module('conv_spatial', Conv2dWithConstraint(
|
||||
self.F1, self.F1 * self.D, (self.in_chans, 1), max_norm=1, stride=1, bias=False,
|
||||
groups=self.F1,
|
||||
padding=(0, 0)))
|
||||
|
||||
model.add_module('bnorm_1', nn.BatchNorm2d(
|
||||
self.F1 * self.D, momentum=0.01, affine=True, eps=1e-3), )
|
||||
model.add_module('elu_1', Expression(elu))
|
||||
|
||||
model.add_module('pool_1', pool_class(
|
||||
kernel_size=(1, 4), stride=(1, 4)))
|
||||
model.add_module('drop_1', nn.Dropout(p=self.drop_prob))
|
||||
|
||||
# https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/7
|
||||
model.add_module('conv_separable_depth', nn.Conv2d(
|
||||
self.F1 * self.D, self.F1 * self.D, (1, 16), stride=1, bias=False, groups=self.F1 * self.D,
|
||||
padding=(0, 16 // 2)))
|
||||
model.add_module('conv_separable_point', nn.Conv2d(
|
||||
self.F1 * self.D, self.F2, (1, 1), stride=1, bias=False,
|
||||
padding=(0, 0)))
|
||||
|
||||
model.add_module('bnorm_2', nn.BatchNorm2d(
|
||||
self.F2, momentum=0.01, affine=True, eps=1e-3), )
|
||||
model.add_module('elu_2', Expression(elu))
|
||||
model.add_module('pool_2', pool_class(
|
||||
kernel_size=(1, 8), stride=(1, 8)))
|
||||
model.add_module('drop_2', nn.Dropout(p=self.drop_prob))
|
||||
|
||||
out = model(np_to_var(np.ones(
|
||||
(1, self.in_chans, self.input_time_length, 1),
|
||||
dtype=np.float32)))
|
||||
n_out_virtual_chans = out.cpu().data.numpy().shape[2]
|
||||
|
||||
if self.final_conv_length == 'auto':
|
||||
n_out_time = out.cpu().data.numpy().shape[3]
|
||||
self.final_conv_length = n_out_time
|
||||
|
||||
model.add_module('conv_classifier', nn.Conv2d(
|
||||
self.F2, self.n_classes,
|
||||
(n_out_virtual_chans, self.final_conv_length,), bias=True))
|
||||
model.add_module('softmax', nn.LogSoftmax())
|
||||
# Transpose back to the the logic of braindecode,
|
||||
# so time in third dimension (axis=2)
|
||||
model.add_module('permute_back', Expression(_transpose_1_0))
|
||||
model.add_module('squeeze', Expression(_squeeze_final_output))
|
||||
|
||||
glorot_weight_zero_bias(model)
|
||||
return model
|
||||
|
||||
|
||||
def _transpose_to_b_1_c_0(x):
|
||||
return x.permute(0, 3, 1, 2)
|
||||
|
||||
|
||||
def _transpose_1_0(x):
|
||||
return x.permute(0, 1, 3, 2)
|
||||
|
||||
|
||||
# remove empty dim at end and potentially remove empty time dim
|
||||
# do not just use squeeze as we never want to remove first dim
|
||||
def _squeeze_final_output(x):
|
||||
assert x.size()[3] == 1
|
||||
x = x[:, :, :, 0]
|
||||
if x.size()[2] == 1:
|
||||
x = x[:, :, 0]
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class EEGNet(object):
|
||||
"""
|
||||
@@ -102,14 +231,6 @@ class EEGNet(object):
|
||||
# Transpose back to the the logic of braindecode,
|
||||
# so time in third dimension (axis=2)
|
||||
model.add_module('permute_2', Expression(lambda x: x.permute(0,1,3,2)))
|
||||
# remove empty dim at end and potentially remove empty time dim
|
||||
# do not just use squeeze as we never want to remove first dim
|
||||
def squeeze_output(x):
|
||||
assert x.size()[3] == 1
|
||||
x = x[:,:,:,0]
|
||||
if x.size()[2] == 1:
|
||||
x = x[:,:,0]
|
||||
return x
|
||||
model.add_module('squeeze', Expression(squeeze_output))
|
||||
model.add_module('squeeze', Expression(_squeeze_final_output))
|
||||
glorot_weight_zero_bias(model)
|
||||
return model
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
import torch as th
|
||||
from torch import nn
|
||||
from torch.nn import ConstantPad2d
|
||||
|
||||
from braindecode.models.base import BaseModel
|
||||
from braindecode.models.deep4 import Deep4Net
|
||||
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
|
||||
from braindecode.models.util import to_dense_prediction_model
|
||||
|
||||
|
||||
class HybridNet(nn.Module, BaseModel):
|
||||
def __init__(self, n_chans, n_classes, input_time_length):
|
||||
super(HybridNet, self).__init__()
|
||||
deep_model = Deep4Net(n_chans, n_classes,
|
||||
input_time_length=input_time_length,
|
||||
final_conv_length=2).create_network()
|
||||
shallow_model = ShallowFBCSPNet(n_chans, n_classes,
|
||||
input_time_length=input_time_length,
|
||||
filter_time_length=28,
|
||||
final_conv_length=29,
|
||||
).create_network()
|
||||
|
||||
reduced_deep_model = nn.Sequential()
|
||||
for name, module in deep_model.named_children():
|
||||
if name == 'conv_classifier':
|
||||
new_conv_layer = nn.Conv2d(module.in_channels, 60,
|
||||
kernel_size=module.kernel_size,
|
||||
stride=module.stride)
|
||||
reduced_deep_model.add_module('deep_final_conv', new_conv_layer)
|
||||
break
|
||||
reduced_deep_model.add_module(name, module)
|
||||
|
||||
reduced_shallow_model = nn.Sequential()
|
||||
for name, module in shallow_model.named_children():
|
||||
if name == 'conv_classifier':
|
||||
new_conv_layer = nn.Conv2d(module.in_channels, 40,
|
||||
kernel_size=module.kernel_size,
|
||||
stride=module.stride)
|
||||
reduced_shallow_model.add_module('shallow_final_conv',
|
||||
new_conv_layer)
|
||||
break
|
||||
reduced_shallow_model.add_module(name, module)
|
||||
|
||||
to_dense_prediction_model(reduced_deep_model)
|
||||
to_dense_prediction_model(reduced_shallow_model)
|
||||
self.reduced_shallow_model = reduced_shallow_model
|
||||
self.reduced_deep_model = reduced_deep_model
|
||||
self.final_conv = nn.Conv2d(100, n_classes, kernel_size=(1, 1),
|
||||
stride=1)
|
||||
|
||||
def create_network(self):
|
||||
return self
|
||||
|
||||
def forward(self, x):
|
||||
deep_out = self.reduced_deep_model(x)
|
||||
shallow_out = self.reduced_shallow_model(x)
|
||||
|
||||
n_diff_deep_shallow = deep_out.size()[2] - shallow_out.size()[2]
|
||||
|
||||
if n_diff_deep_shallow < 0:
|
||||
deep_out = ConstantPad2d((0, 0, -n_diff_deep_shallow, 0), 0)(
|
||||
deep_out)
|
||||
elif n_diff_deep_shallow > 0:
|
||||
shallow_out = ConstantPad2d((0, 0, n_diff_deep_shallow, 0), 0)(
|
||||
shallow_out)
|
||||
|
||||
merged_out = th.cat((deep_out, shallow_out), dim=1)
|
||||
linear_out = self.final_conv(merged_out)
|
||||
softmaxed = nn.LogSoftmax(dim=1)(linear_out)
|
||||
squeezed = softmaxed.squeeze(3)
|
||||
return squeezed
|
||||
@@ -2,12 +2,13 @@ import numpy as np
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
|
||||
from braindecode.models.base import BaseModel
|
||||
from braindecode.torch_ext.modules import Expression
|
||||
from braindecode.torch_ext.functions import safe_log, square
|
||||
from braindecode.torch_ext.util import np_to_var
|
||||
|
||||
|
||||
class ShallowFBCSPNet(object):
|
||||
class ShallowFBCSPNet(BaseModel):
|
||||
"""
|
||||
Shallow ConvNet model from [2]_.
|
||||
|
||||
|
||||
@@ -9,6 +9,14 @@ braindecode\.visualization package
|
||||
Submodules
|
||||
----------
|
||||
|
||||
braindecode\.visualization\.input\_windows module
|
||||
-------------------------------------------------
|
||||
|
||||
.. automodule:: braindecode.visualization.input_windows
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
braindecode\.visualization\.perturbation module
|
||||
-----------------------------------------------
|
||||
|
||||
@@ -25,4 +33,12 @@ braindecode\.visualization\.plot module
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
braindecode\.visualization\.sinfit module
|
||||
-----------------------------------------
|
||||
|
||||
.. automodule:: braindecode.visualization.sinfit
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário