new simple model class

Esse commit está contido em:
Robin Tibor Schirrmeister
2018-08-06 23:26:31 +02:00
commit b560b91dad
4 arquivos alterados com 199 adições e 54 exclusões
+31 -30
Ver Arquivo
@@ -7,6 +7,7 @@ import pandas as pd
import torch as th
from braindecode.datautil.splitters import concatenate_sets
from braindecode.experiments.loggers import Printer
from braindecode.experiments.stopcriteria import MaxEpochs, ColumnBelow, Or
from braindecode.torch_ext.util import np_to_var, set_random_seeds
@@ -140,9 +141,11 @@ class Experiment(object):
reset_after_second_run: bool
If true, reset to best model when second run did not find a valid loss
below or equal to the best train loss of first run.
print_0_epoch: bool
Whether to compute monitor values and print them before the
log_0_epoch: bool
Whether to compute monitor values and log them before the
start of training.
loggers: list of :class:`.Logger`
How to show computed metrics.
seed: int
Random seed for python random module numpy.random and torch.
@@ -159,7 +162,8 @@ class Experiment(object):
batch_modifier=None, cuda=True, pin_memory=False,
do_early_stop=True,
reset_after_second_run=False,
print_0_epoch=True,
log_0_epoch=True,
loggers=('print',),
seed=2382938):
if run_after_early_stop or reset_after_second_run:
assert do_early_stop == True, ("Can only run after early stop or "
@@ -194,8 +198,9 @@ class Experiment(object):
self.pin_memory = pin_memory
self.do_early_stop = do_early_stop
self.reset_after_second_run = reset_after_second_run
self.print_0_epoch = print_0_epoch
self.log_0_epoch = log_0_epoch
self.seed = seed
self.loggers = loggers
def run(self):
"""
@@ -231,6 +236,8 @@ class Experiment(object):
# reset remember best extension in case you rerun some experiment
if self.do_early_stop:
self.rememberer = RememberBest(self.remember_best_column)
if self.loggers == ('print',):
self.loggers = [Printer()]
self.epochs_df = pd.DataFrame()
set_random_seeds(seed=self.seed, cuda=self.cuda)
if self.cuda:
@@ -272,9 +279,9 @@ class Experiment(object):
remember_best: bool
Whether to remember parameters at best epoch.
"""
if self.print_0_epoch:
if self.log_0_epoch:
self.monitor_epoch(datasets)
self.print_epoch()
self.log_epoch()
if remember_best:
self.rememberer.remember_epoch(self.epochs_df, self.model,
self.optimizer)
@@ -310,7 +317,7 @@ class Experiment(object):
end_train_epoch_time - start_train_epoch_time))
self.monitor_epoch(datasets)
self.print_epoch()
self.log_epoch()
if remember_best:
self.rememberer.remember_epoch(self.epochs_df, self.model,
self.optimizer)
@@ -356,21 +363,20 @@ class Experiment(object):
"""
self.model.eval()
input_vars = np_to_var(inputs, pin_memory=self.pin_memory,
volatile=True)
target_vars = np_to_var(targets, pin_memory=self.pin_memory,
volatile=True)
if self.cuda:
input_vars = input_vars.cuda()
target_vars = target_vars.cuda()
outputs = self.model(input_vars)
loss = self.loss_function(outputs, target_vars)
if hasattr(outputs, 'cpu'):
outputs = outputs.cpu().data.numpy()
else:
# assume it is iterable
outputs = [o.cpu().data.numpy() for o in outputs]
loss = loss.cpu().data.numpy()
with th.no_grad():
input_vars = np_to_var(inputs, pin_memory=self.pin_memory)
target_vars = np_to_var(targets, pin_memory=self.pin_memory)
if self.cuda:
input_vars = input_vars.cuda()
target_vars = target_vars.cuda()
outputs = self.model(input_vars)
loss = self.loss_function(outputs, target_vars)
if hasattr(outputs, 'cpu'):
outputs = outputs.cpu().data.numpy()
else:
# assume it is iterable
outputs = [o.cpu().data.numpy() for o in outputs]
loss = loss.cpu().data.numpy()
return outputs, loss
def monitor_epoch(self, datasets):
@@ -420,17 +426,12 @@ class Experiment(object):
assert set(self.epochs_df.columns) == set(row_dict.keys())
self.epochs_df = self.epochs_df[list(row_dict.keys())]
def print_epoch(self):
def log_epoch(self):
"""
Print monitoring values for this epoch.
"""
# -1 due to doing one monitor at start of training
i_epoch = len(self.epochs_df) - 1
log.info("Epoch {:d}".format(i_epoch))
last_row = self.epochs_df.iloc[-1]
for key, val in last_row.iteritems():
log.info("{:25s} {:.5f}".format(key, val))
log.info("")
for logger in self.loggers:
logger.log_epoch(self.epochs_df)
def setup_after_stop_training(self):
"""
+48
Ver Arquivo
@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod
import logging
log = logging.getLogger(__name__)
class Logger(ABC):
    """
    Abstract interface for objects that report the results of an epoch.

    Concrete loggers have to override :meth:`log_epoch`.
    """

    @abstractmethod
    def log_epoch(self, epochs_df):
        """
        Report the monitored values contained in ``epochs_df``.

        Parameters
        ----------
        epochs_df:
            Table of monitored values handed over by the caller.

        Raises
        ------
        NotImplementedError
            Whenever this base implementation itself is executed.
        """
        message = "Need to implement the log_epoch function!"
        raise NotImplementedError(message)
class Printer(Logger):
    """
    Prints output to the terminal using Python's logging module.
    """

    def log_epoch(self, epochs_df):
        """
        Log every value of the most recent row of ``epochs_df`` at INFO level.

        Parameters
        ----------
        epochs_df: `pandas.DataFrame`
            Monitored values; only the last row is logged. Values are
            formatted with ``{:.5f}``, so they have to be numeric, and
            column names are formatted with ``{:25s}``.
        """
        # -1 due to doing one monitor at start of training
        i_epoch = len(epochs_df) - 1
        log.info("Epoch {:d}".format(i_epoch))
        last_row = epochs_df.iloc[-1]
        # Series.items() instead of Series.iteritems(): the latter was
        # deprecated in pandas 1.5 and removed in pandas 2.0.
        for key, val in last_row.items():
            log.info("{:25s} {:.5f}".format(key, val))
        log.info("")
class TensorboardWriter(Logger):
    """
    Logs all values for tensorboard visualization using tensorboardX.

    Parameters
    ----------
    log_dir: string
        Directory path to log the output to
    """

    def __init__(self, log_dir):
        # import inside to prevent dependency of braindecode onto tensorboardX
        from tensorboardX import SummaryWriter
        self.writer = SummaryWriter(log_dir)

    def log_epoch(self, epochs_df):
        """
        Write every value of the most recent row of ``epochs_df`` as a scalar.

        Parameters
        ----------
        epochs_df: `pandas.DataFrame`
            Monitored values; only the last row is written. The column name
            is used as the scalar tag and the epoch index as the step.
        """
        # -1 due to doing one monitor at start of training
        i_epoch = len(epochs_df) - 1
        last_row = epochs_df.iloc[-1]
        # Series.items() instead of Series.iteritems(): the latter was
        # deprecated in pandas 1.5 and removed in pandas 2.0.
        # The value yielded by the iteration is used directly; looking it up
        # again via last_row[key] is redundant and would return a Series
        # instead of a scalar if column names were duplicated.
        for key, val in last_row.items():
            self.writer.add_scalar(key, val, i_epoch)
+69 -11
Ver Arquivo
@@ -70,6 +70,34 @@ class MisclassMonitor(object):
return {column_name: float(misclass)}
def compute_pred_labels_from_trial_preds(
        all_preds, threshold_for_binary_case=None):
    """
    Compute predicted labels from batches of predictions.

    Parameters
    ----------
    all_preds: list of ndarrays
        One array per batch. Arrays with more than one dimension are
        treated as examples x classes (x time) and reduced by an argmax
        over the class axis. One-dimensional arrays are treated as a
        single output per example and thresholded.
    threshold_for_binary_case: float, optional
        Threshold applied to one-dimensional predictions; values above it
        become label 1, all others label 0. Required for that case.

    Returns
    -------
    all_pred_labels: ndarray
        Labels of all batches, concatenated along the first axis
        (examples x time or examples).
    """
    collected_labels = []
    for batch_preds in all_preds:
        if batch_preds.ndim <= 1:
            assert threshold_for_binary_case is not None, (
                "In case of only one output, please supply the "
                "threshold_for_binary_case parameter")
            # binary classification case... assume logits
            batch_labels = np.int32(batch_preds > threshold_for_binary_case)
        else:
            # batch_preds could be examples x classes x time
            # or just examples x classes
            batch_labels = np.argmax(batch_preds, axis=1).squeeze()
            # squeeze() also drops an example axis of size one,
            # so put it back for single-example batches
            if batch_preds.shape[0] == 1:
                batch_labels = batch_labels[None]
        # batch_labels is now examples x time or examples
        collected_labels.extend(batch_labels)
    return np.array(collected_labels)
class AveragePerClassMisclassMonitor(object):
"""
Compute average of misclasses per class,
@@ -148,6 +176,7 @@ class LossMonitor(object):
return {column_name: mean_loss}
class CroppedTrialMisclassMonitor(object):
"""
Compute trialwise misclasses from predictions for crops.
@@ -169,9 +198,10 @@ class CroppedTrialMisclassMonitor(object):
assert self.input_time_length is not None, "Need to know input time length..."
# First case that each trial only has a single label
if not hasattr(dataset.y[0], '__len__'):
all_pred_labels = self._compute_pred_labels(dataset, all_preds)
all_trial_labels = dataset.y
all_pred_labels = compute_trial_labels_from_crop_preds(
all_preds, self.input_time_length, dataset.X)
assert all_pred_labels.shape == dataset.y.shape
all_trial_labels = dataset.y
else:
all_trial_labels, all_pred_labels = (
self._compute_trial_pred_labels_from_cnt_y(dataset, all_preds))
@@ -181,8 +211,8 @@ class CroppedTrialMisclassMonitor(object):
return {column_name: float(misclass)}
def _compute_pred_labels(self, dataset, all_preds, ):
preds_per_trial = compute_preds_per_trial_for_set(
all_preds, self.input_time_length, dataset)
preds_per_trial = compute_preds_per_trial_from_crops(
all_preds, self.input_time_length, dataset.X)
all_pred_labels = [np.argmax(np.mean(p, axis=1))
for p in preds_per_trial]
all_pred_labels = np.array(all_pred_labels)
@@ -194,8 +224,8 @@ class CroppedTrialMisclassMonitor(object):
# we only want the preds that are for the same labels as the last label in y
# (there might be parts of other class-data at start, for trialwise misclass we assume
# they are contained in other trials at the end...)
preds_per_trial = compute_preds_per_trial_for_set(
all_preds, self.input_time_length, dataset)
preds_per_trial = compute_preds_per_trial_from_crops(
all_preds, self.input_time_length, dataset.X)
trial_labels = []
trial_pred_labels = []
for trial_pred, trial_y in zip(preds_per_trial, dataset.y):
@@ -217,8 +247,36 @@ class CroppedTrialMisclassMonitor(object):
return trial_labels, trial_pred_labels
def compute_preds_per_trial_for_set(all_preds, input_time_length,
dataset, ):
def compute_trial_labels_from_crop_preds(all_preds, input_time_length, X):
    """
    Compute predicted trial labels from arrays of crop predictions.

    The crop predictions are first regrouped per trial, then averaged
    over axis 1 of each trial's predictions; the index of the largest
    mean is the predicted label of that trial.

    Parameters
    ----------
    all_preds: list of ndarrays
        All predictions for the crops, passed on unchanged to
        `compute_preds_per_trial_from_crops`.
        NOTE(review): presumably batches of crops x classes x time,
        confirm against the iterator that produced them.
    input_time_length: int
        Temporal length of one input to the model.
    X: ndarray
        Input tensor the crops were taken from.

    Returns
    -------
    pred_labels_per_trial: 1darray
        Predicted label for each trial.
    """
    trial_preds = compute_preds_per_trial_from_crops(
        all_preds, input_time_length, X)
    labels = []
    for single_trial_preds in trial_preds:
        # mean over axis 1 (presumably time), leaving one value per class
        mean_per_class = np.mean(single_trial_preds, axis=1)
        labels.append(np.argmax(mean_per_class))
    return np.array(labels)
def compute_preds_per_trial_from_crops(all_preds, input_time_length,
X, ):
"""
Compute predictions per trial from predictions for crops.
@@ -228,8 +286,8 @@ def compute_preds_per_trial_for_set(all_preds, input_time_length,
All predictions for the crops.
input_time_length: int
Temporal length of one input to the model.
dataset: :class:`.SignalAndTarget`
Dataset the crops were taken from.
X: ndarray
Input tensor the crops were taken from.
Returns
-------
@@ -240,7 +298,7 @@ def compute_preds_per_trial_for_set(all_preds, input_time_length,
n_preds_per_input = all_preds[0].shape[2]
n_receptive_field = input_time_length - n_preds_per_input + 1
n_preds_per_trial = [trial.shape[1] - n_receptive_field + 1
for trial in dataset.X]
for trial in X]
preds_per_trial = compute_preds_per_trial_from_n_preds_per_trial(
all_preds, n_preds_per_trial)
return preds_per_trial
+51 -13
Ver Arquivo
@@ -1,8 +1,12 @@
import time
import numpy as np
from numpy.random import RandomState
import torch as th
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
RuntimeMonitor, CroppedTrialMisclassMonitor
RuntimeMonitor, CroppedTrialMisclassMonitor, \
compute_trial_labels_from_crop_preds, compute_pred_labels_from_trial_preds
from braindecode.experiments.stopcriteria import MaxEpochs
from braindecode.datautil.iterators import BalancedBatchSizeIterator, \
CropsFromTrialsIterator
@@ -10,7 +14,7 @@ from braindecode.experiments.experiment import Experiment
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.models.util import to_dense_prediction_model
from braindecode.torch_ext.schedulers import CosineAnnealing, ScheduledOptimizer
from braindecode.torch_ext.util import np_to_var
from braindecode.torch_ext.util import np_to_var, var_to_np
def find_optimizer(optimizer_name):
@@ -62,6 +66,25 @@ class BaseModel(object):
def fit(self, train_X, train_y, epochs, batch_size, input_time_length=None,
validation_data=None, model_constraint=None,
remember_best_column=None, scheduler=None):
"""
Parameters
----------
train_X
train_y
epochs
batch_size
input_time_length
validation_data
model_constraint
remember_best_column
scheduler
Returns
-------
"""
if not self.compiled:
raise ValueError("Compile the model first by calling model.compile(loss, optimizer, metrics)")
@@ -78,19 +101,19 @@ class BaseModel(object):
test_input = test_input.cuda()
out = self.network(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
iterator = CropsFromTrialsIterator(
self.iterator = CropsFromTrialsIterator(
batch_size=batch_size, input_time_length=input_time_length,
n_preds_per_input=n_preds_per_input,
seed=self.seed_rng.randint(0, 4294967295))
else:
iterator = BalancedBatchSizeIterator(batch_size=batch_size, seed=self.seed_rng.randint(0, 4294967295))
self.iterator = BalancedBatchSizeIterator(batch_size=batch_size, seed=self.seed_rng.randint(0, 4294967295))
stop_criterion = MaxEpochs(epochs - 1)# -1 since we dont print 0 epoch, which matters for this stop criterion
train_set = SignalAndTarget(train_X, train_y)
optimizer = self.optimizer
if scheduler is not None:
assert scheduler == 'cosine'
n_updates_per_epoch = sum(
[1 for _ in iterator.get_batches(train_set, shuffle=True)])
[1 for _ in self.iterator.get_batches(train_set, shuffle=True)])
n_updates_per_period = n_updates_per_epoch * epochs
if scheduler == 'cosine':
scheduler = CosineAnnealing(n_updates_per_period)
@@ -117,22 +140,20 @@ class BaseModel(object):
extra_monitors = [monitor_dict[m]() for m in self.metrics]
self.monitors += extra_monitors
self.monitors += [RuntimeMonitor()]
exp = Experiment(self.network, train_set, valid_set, test_set, iterator=iterator,
exp = Experiment(self.network, train_set, valid_set, test_set,
iterator=self.iterator,
loss_function=loss_function, optimizer=optimizer,
model_constraint=model_constraint,
monitors=self.monitors,
stop_criterion=stop_criterion,
remember_best_column=remember_best_column,
run_after_early_stop=False, cuda=self.cuda, print_0_epoch=False,
run_after_early_stop=False, cuda=self.cuda, log_0_epoch=False,
do_early_stop=(remember_best_column is not None))
exp.run()
self.epochs_df = exp.epochs_df
return exp
def evaluate(self, X,y, batch_size=32):
# Create a dummy experiment for the evaluation
iterator = BalancedBatchSizeIterator(batch_size=batch_size,
seed=0) # seed irrelevant
def evaluate(self, X,y):
stop_criterion = MaxEpochs(0)
train_set = SignalAndTarget(X, y)
model_constraint = None
@@ -142,15 +163,20 @@ class BaseModel(object):
if self.cropped:
loss_function = lambda outputs, targets: \
self.loss(th.mean(outputs, dim=2), targets)
# reset runtime monitor if exists...
for monitor in self.monitors:
if hasattr(monitor, 'last_call_time'):
monitor.last_call_time = time.time()
exp = Experiment(self.network, train_set, valid_set, test_set,
iterator=iterator,
iterator=self.iterator,
loss_function=loss_function, optimizer=self.optimizer,
model_constraint=model_constraint,
monitors=self.monitors,
stop_criterion=stop_criterion,
remember_best_column=None,
run_after_early_stop=False, cuda=self.cuda,
print_0_epoch=False,
log_0_epoch=False,
do_early_stop=False)
exp.monitor_epoch({'train': train_set})
@@ -159,3 +185,15 @@ class BaseModel(object):
for key, val in
dict(exp.epochs_df.iloc[0]).items()])
return result_dict
def predict(self, X, threshold_for_binary_case=None):
    """
    Predict the labels for the given inputs.

    Batches are produced by ``self.iterator``, so this requires that the
    iterator has already been created.

    Parameters
    ----------
    X: ndarray
        Input data.
    threshold_for_binary_case: float, optional
        Threshold used to assign label 0 or 1 when the network has only a
        single output. Only used for non-cropped models.

    Returns
    -------
    pred_labels: ndarray
        Predicted labels, computed by
        `compute_trial_labels_from_crop_preds` for cropped models and by
        `compute_pred_labels_from_trial_preds` otherwise.
    """
    # Evaluation mode, so that layers behaving differently during training
    # (e.g. dropout, batch norm) give deterministic predictions.
    self.network.eval()
    all_preds = []
    # No gradients are needed for prediction; skipping the autograd graph
    # saves memory.
    with th.no_grad():
        # The iterator expects a set with targets; X is passed as a
        # placeholder for y, the yielded targets are ignored.
        for b_X, _ in self.iterator.get_batches(SignalAndTarget(X, X),
                                                shuffle=False):
            b_X_var = np_to_var(b_X)
            if self.cuda:
                # inputs must be on the same device as the network
                b_X_var = b_X_var.cuda()
            # NOTE(review): assumes var_to_np moves the output back to the
            # CPU before converting to numpy - confirm.
            all_preds.append(var_to_np(self.network(b_X_var)))
    if self.cropped:
        pred_labels = compute_trial_labels_from_crop_preds(
            all_preds, self.iterator.input_time_length, X)
    else:
        pred_labels = compute_pred_labels_from_trial_preds(
            all_preds, threshold_for_binary_case)
    return pred_labels