cosine annealing, cropped for basemodel, hybrid corrected
Esse commit está contido em:
@@ -1,10 +1,16 @@
|
||||
import torch as th
|
||||
import numpy as np
|
||||
from numpy.random import RandomState
|
||||
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, RuntimeMonitor
|
||||
import torch as th
|
||||
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
|
||||
RuntimeMonitor, CroppedTrialMisclassMonitor
|
||||
from braindecode.experiments.stopcriteria import MaxEpochs
|
||||
from braindecode.datautil.iterators import BalancedBatchSizeIterator
|
||||
from braindecode.datautil.iterators import BalancedBatchSizeIterator, \
|
||||
CropsFromTrialsIterator
|
||||
from braindecode.experiments.experiment import Experiment
|
||||
from braindecode.datautil.signal_target import SignalAndTarget
|
||||
from braindecode.models.util import to_dense_prediction_model
|
||||
from braindecode.torch_ext.schedulers import CosineAnnealing, ScheduledOptimizer
|
||||
from braindecode.torch_ext.util import np_to_var
|
||||
|
||||
|
||||
def find_optimizer(optimizer_name):
|
||||
@@ -21,45 +27,106 @@ def find_optimizer(optimizer_name):
|
||||
|
||||
|
||||
class BaseModel(object):
|
||||
def __init__(self):
|
||||
self.compiled = False
|
||||
def cuda(self):
|
||||
self._ensure_network_exists()
|
||||
assert not self.compiled,\
|
||||
("Call cuda before compiling model, otherwise optimization will not work")
|
||||
self.network = self.network.cuda()
|
||||
self.cuda = True
|
||||
return self
|
||||
|
||||
def compile(self, loss, optimizer, metrics, seed=0):
|
||||
def parameters(self):
|
||||
self._ensure_network_exists()
|
||||
return self.network.parameters()
|
||||
|
||||
def _ensure_network_exists(self):
|
||||
if not hasattr(self, 'network'):
|
||||
self.network = self.create_network()
|
||||
self.cuda = False
|
||||
self.compiled = False
|
||||
|
||||
def compile(self, loss, optimizer, metrics, cropped=False, seed=0):
|
||||
self.loss = loss
|
||||
self.network = self.create_network()
|
||||
self._ensure_network_exists()
|
||||
if cropped:
|
||||
to_dense_prediction_model(self.network)
|
||||
if not hasattr(optimizer, 'step'):
|
||||
optimizer_class = find_optimizer(optimizer)
|
||||
optimizer = optimizer_class(self.network.parameters())
|
||||
self.optimizer = optimizer
|
||||
monitor_dict = {'acc': MisclassMonitor}
|
||||
self.monitors = [LossMonitor()]
|
||||
extra_monitors = [monitor_dict[m]() for m in metrics]
|
||||
self.monitors += extra_monitors
|
||||
self.monitors += [RuntimeMonitor()]
|
||||
self.metrics = metrics
|
||||
self.seed_rng = RandomState(seed)
|
||||
self.cropped = cropped
|
||||
self.compiled = True
|
||||
|
||||
def fit(self, train_X, train_y, epochs, batch_size, validation_data=None, model_constraint=None,
|
||||
remember_best_column=None):
|
||||
def fit(self, train_X, train_y, epochs, batch_size, input_time_length=None,
|
||||
validation_data=None, model_constraint=None,
|
||||
remember_best_column=None, scheduler=None):
|
||||
if not self.compiled:
|
||||
raise ValueError("Compile the model first by calling model.compile(loss, optimizer, metrics")
|
||||
iterator = BalancedBatchSizeIterator(batch_size=batch_size, seed=self.seed_rng.randint(0, 4294967295))
|
||||
raise ValueError("Compile the model first by calling model.compile(loss, optimizer, metrics)")
|
||||
|
||||
|
||||
if self.cropped and input_time_length is None:
|
||||
raise ValueError("In cropped mode, need to specify input_time_length,"
|
||||
"which is the number of timesteps that will be pushed through"
|
||||
"the network in a single pass.")
|
||||
if self.cropped:
|
||||
test_input = np_to_var(train_X[0:1], dtype=np.float32)
|
||||
while len(test_input.size()) < 4:
|
||||
test_input = test_input.unsqueeze(-1)
|
||||
if self.cuda:
|
||||
test_input = test_input.cuda()
|
||||
out = self.network(test_input)
|
||||
n_preds_per_input = out.cpu().data.numpy().shape[2]
|
||||
iterator = CropsFromTrialsIterator(
|
||||
batch_size=batch_size, input_time_length=input_time_length,
|
||||
n_preds_per_input=n_preds_per_input,
|
||||
seed=self.seed_rng.randint(0, 4294967295))
|
||||
else:
|
||||
iterator = BalancedBatchSizeIterator(batch_size=batch_size, seed=self.seed_rng.randint(0, 4294967295))
|
||||
stop_criterion = MaxEpochs(epochs - 1)# -1 since we dont print 0 epoch, which matters for this stop criterion
|
||||
train_set = SignalAndTarget(train_X, train_y)
|
||||
optimizer = self.optimizer
|
||||
if scheduler is not None:
|
||||
assert scheduler == 'cosine'
|
||||
n_updates_per_epoch = sum(
|
||||
[1 for _ in iterator.get_batches(train_set, shuffle=True)])
|
||||
n_updates_per_period = n_updates_per_epoch * epochs
|
||||
if scheduler == 'cosine':
|
||||
scheduler = CosineAnnealing(n_updates_per_period)
|
||||
schedule_weight_decay = False
|
||||
if optimizer.__class__.__name__ == 'AdamW':
|
||||
schedule_weight_decay = True
|
||||
optimizer = ScheduledOptimizer(scheduler, self.optimizer,
|
||||
schedule_weight_decay=schedule_weight_decay)
|
||||
loss_function = self.loss
|
||||
if self.cropped:
|
||||
loss_function = lambda outputs, targets:\
|
||||
self.loss(th.mean(outputs, dim=2), targets)
|
||||
if validation_data is not None:
|
||||
valid_set = SignalAndTarget(validation_data[0], validation_data[1])
|
||||
else:
|
||||
valid_set = None
|
||||
test_set = None
|
||||
if self.cropped:
|
||||
monitor_dict = {'acc': lambda :
|
||||
CroppedTrialMisclassMonitor(input_time_length)}
|
||||
else:
|
||||
monitor_dict = {'acc': MisclassMonitor}
|
||||
self.monitors = [LossMonitor()]
|
||||
extra_monitors = [monitor_dict[m]() for m in self.metrics]
|
||||
self.monitors += extra_monitors
|
||||
self.monitors += [RuntimeMonitor()]
|
||||
exp = Experiment(self.network, train_set, valid_set, test_set, iterator=iterator,
|
||||
loss_function=self.loss, optimizer=self.optimizer,
|
||||
loss_function=loss_function, optimizer=optimizer,
|
||||
model_constraint=model_constraint,
|
||||
monitors=self.monitors,
|
||||
stop_criterion=stop_criterion,
|
||||
remember_best_column=remember_best_column,
|
||||
run_after_early_stop=False, cuda=True, print_0_epoch=False,
|
||||
run_after_early_stop=False, cuda=self.cuda, print_0_epoch=False,
|
||||
do_early_stop=(remember_best_column is not None))
|
||||
exp.run()
|
||||
self.epochs_df = exp.epochs_df
|
||||
return exp
|
||||
|
||||
def evaluate(self, X,y, batch_size=32):
|
||||
@@ -71,15 +138,18 @@ class BaseModel(object):
|
||||
model_constraint = None
|
||||
valid_set = None
|
||||
test_set = None
|
||||
|
||||
loss_function = self.loss
|
||||
if self.cropped:
|
||||
loss_function = lambda outputs, targets: \
|
||||
self.loss(th.mean(outputs, dim=2), targets)
|
||||
exp = Experiment(self.network, train_set, valid_set, test_set,
|
||||
iterator=iterator,
|
||||
loss_function=self.loss, optimizer=self.optimizer,
|
||||
loss_function=loss_function, optimizer=self.optimizer,
|
||||
model_constraint=model_constraint,
|
||||
monitors=self.monitors,
|
||||
stop_criterion=stop_criterion,
|
||||
remember_best_column=None,
|
||||
run_after_early_stop=False, cuda=True,
|
||||
run_after_early_stop=False, cuda=self.cuda,
|
||||
print_0_epoch=False,
|
||||
do_early_stop=False)
|
||||
|
||||
@@ -88,4 +158,4 @@ class BaseModel(object):
|
||||
result_dict = dict([(key.replace('train_', ''), val)
|
||||
for key, val in
|
||||
dict(exp.epochs_df.iloc[0]).items()])
|
||||
return result_dict
|
||||
return result_dict
|
||||
|
||||
@@ -11,11 +11,17 @@ from braindecode.models.util import to_dense_prediction_model
|
||||
class HybridNet(nn.Module, BaseModel):
|
||||
def __init__(self, n_chans, n_classes, input_time_length):
|
||||
super(HybridNet, self).__init__()
|
||||
deep_model = Deep4Net(n_chans, n_classes,
|
||||
deep_model = Deep4Net(n_chans, n_classes, n_filters_time=20,
|
||||
n_filters_spat=30,
|
||||
n_filters_2=40,
|
||||
n_filters_3=50,
|
||||
n_filters_4=60,
|
||||
input_time_length=input_time_length,
|
||||
final_conv_length=2).create_network()
|
||||
shallow_model = ShallowFBCSPNet(n_chans, n_classes,
|
||||
input_time_length=input_time_length,
|
||||
n_filters_time=30,
|
||||
n_filters_spat=40,
|
||||
filter_time_length=28,
|
||||
final_conv_length=29,
|
||||
).create_network()
|
||||
@@ -43,8 +49,8 @@ class HybridNet(nn.Module, BaseModel):
|
||||
|
||||
to_dense_prediction_model(reduced_deep_model)
|
||||
to_dense_prediction_model(reduced_shallow_model)
|
||||
self.reduced_shallow_model = reduced_shallow_model
|
||||
self.reduced_deep_model = reduced_deep_model
|
||||
self.reduced_shallow_model = reduced_shallow_model
|
||||
self.final_conv = nn.Conv2d(100, n_classes, kernel_size=(1, 1),
|
||||
stride=1)
|
||||
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import math
|
||||
import torch
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
|
||||
class AdamW(Optimizer):
|
||||
"""Implements Adam algorithm with fixed as TODO
|
||||
|
||||
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
|
||||
.. _Adam\: A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay)
|
||||
super(AdamW, self).__init__(params, defaults)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = grad.new().resize_as_(grad).zero_()
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
|
||||
state['step'] += 1
|
||||
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
|
||||
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
p.data.addcdiv_(-step_size, exp_avg, denom)
|
||||
if group['weight_decay'] != 0:
|
||||
p.data.add_(-group['weight_decay'], p.data)
|
||||
|
||||
return loss
|
||||
@@ -0,0 +1,64 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ScheduledOptimizer(object):
|
||||
def __init__(self, scheduler, optimizer,
|
||||
schedule_weight_decay):
|
||||
self.scheduler = scheduler
|
||||
self.optimizer = optimizer
|
||||
self.schedule_weight_decay = schedule_weight_decay
|
||||
self.initial_lrs = list(map(
|
||||
lambda group: group['lr'], optimizer.param_groups))
|
||||
self.initial_weight_decays = list(map(
|
||||
lambda group: group['weight_decay'], optimizer.param_groups))
|
||||
self.i_update = 0
|
||||
|
||||
def step(self):
|
||||
for group, initial_lr, initial_wd in zip(
|
||||
self.optimizer.param_groups,
|
||||
self.initial_lrs,
|
||||
self.initial_weight_decays):
|
||||
group['lr'] = self.scheduler.get_lr(initial_lr, self.i_update)
|
||||
if self.schedule_weight_decay:
|
||||
group['weight_decay'] = self.scheduler.get_weight_decay(
|
||||
initial_wd, self.i_update)
|
||||
self.optimizer.step()
|
||||
self.i_update += 1
|
||||
|
||||
def state_dict(self):
|
||||
return self.optimizer.state_dict()
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.optimizer.load_state_dict(state_dict)
|
||||
|
||||
def zero_grad(self):
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
|
||||
class CosineAnnealing(object):
|
||||
def __init__(self, n_updates_per_period,):
|
||||
if not hasattr(n_updates_per_period, '__len__'):
|
||||
n_updates_per_period = [n_updates_per_period]
|
||||
assert np.all(np.array(n_updates_per_period) > 0)
|
||||
self.update_period_boundaries = np.cumsum(n_updates_per_period)
|
||||
self.update_period_boundaries = np.concatenate((
|
||||
[0], self.update_period_boundaries))
|
||||
|
||||
def get_lr(self, initial_val, i_update):
|
||||
assert i_update < self.update_period_boundaries[-1], (
|
||||
"More updates ({:d}) than expected ({:d})".format(
|
||||
i_update, self.update_period_boundaries[-1] - 1))
|
||||
i_end_period = np.searchsorted(self.update_period_boundaries,
|
||||
i_update, side='right')
|
||||
assert i_end_period > 0
|
||||
i_start_update = self.update_period_boundaries[i_end_period - 1]
|
||||
i_end_update = self.update_period_boundaries[i_end_period]
|
||||
i_update = i_update - i_start_update
|
||||
assert i_update >= 0
|
||||
n_updates_this_period = i_end_update - i_start_update
|
||||
fraction_period = i_update / np.float64(n_updates_this_period)
|
||||
return initial_val * (0.5 * np.cos(np.pi * fraction_period) + 0.5)
|
||||
|
||||
def get_weight_decay(self, initial_val, i_update):
|
||||
return self.get_lr(initial_val, i_update)
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário