new simple model class
Esse commit está contido em:
@@ -7,6 +7,7 @@ import pandas as pd
|
||||
import torch as th
|
||||
|
||||
from braindecode.datautil.splitters import concatenate_sets
|
||||
from braindecode.experiments.loggers import Printer
|
||||
from braindecode.experiments.stopcriteria import MaxEpochs, ColumnBelow, Or
|
||||
from braindecode.torch_ext.util import np_to_var, set_random_seeds
|
||||
|
||||
@@ -140,9 +141,11 @@ class Experiment(object):
|
||||
reset_after_second_run: bool
|
||||
If true, reset to best model when second run did not find a valid loss
|
||||
below or equal to the best train loss of first run.
|
||||
print_0_epoch: bool
|
||||
Whether to compute monitor values and print them before the
|
||||
log_0_epoch: bool
|
||||
Whether to compute monitor values and log them before the
|
||||
start of training.
|
||||
loggers: list of :class:`.Logger`
|
||||
How to show computed metrics.
|
||||
seed: int
|
||||
Random seed for the python random module, numpy.random and torch.
|
||||
|
||||
@@ -159,7 +162,8 @@ class Experiment(object):
|
||||
batch_modifier=None, cuda=True, pin_memory=False,
|
||||
do_early_stop=True,
|
||||
reset_after_second_run=False,
|
||||
print_0_epoch=True,
|
||||
log_0_epoch=True,
|
||||
loggers=('print',),
|
||||
seed=2382938):
|
||||
if run_after_early_stop or reset_after_second_run:
|
||||
assert do_early_stop == True, ("Can only run after early stop or "
|
||||
@@ -194,8 +198,9 @@ class Experiment(object):
|
||||
self.pin_memory = pin_memory
|
||||
self.do_early_stop = do_early_stop
|
||||
self.reset_after_second_run = reset_after_second_run
|
||||
self.print_0_epoch = print_0_epoch
|
||||
self.log_0_epoch = log_0_epoch
|
||||
self.seed = seed
|
||||
self.loggers = loggers
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
@@ -231,6 +236,8 @@ class Experiment(object):
|
||||
# reset remember best extension in case you rerun some experiment
|
||||
if self.do_early_stop:
|
||||
self.rememberer = RememberBest(self.remember_best_column)
|
||||
if self.loggers == ('print',):
|
||||
self.loggers = [Printer()]
|
||||
self.epochs_df = pd.DataFrame()
|
||||
set_random_seeds(seed=self.seed, cuda=self.cuda)
|
||||
if self.cuda:
|
||||
@@ -272,9 +279,9 @@ class Experiment(object):
|
||||
remember_best: bool
|
||||
Whether to remember parameters at best epoch.
|
||||
"""
|
||||
if self.print_0_epoch:
|
||||
if self.log_0_epoch:
|
||||
self.monitor_epoch(datasets)
|
||||
self.print_epoch()
|
||||
self.log_epoch()
|
||||
if remember_best:
|
||||
self.rememberer.remember_epoch(self.epochs_df, self.model,
|
||||
self.optimizer)
|
||||
@@ -310,7 +317,7 @@ class Experiment(object):
|
||||
end_train_epoch_time - start_train_epoch_time))
|
||||
|
||||
self.monitor_epoch(datasets)
|
||||
self.print_epoch()
|
||||
self.log_epoch()
|
||||
if remember_best:
|
||||
self.rememberer.remember_epoch(self.epochs_df, self.model,
|
||||
self.optimizer)
|
||||
@@ -356,21 +363,20 @@ class Experiment(object):
|
||||
|
||||
"""
|
||||
self.model.eval()
|
||||
input_vars = np_to_var(inputs, pin_memory=self.pin_memory,
|
||||
volatile=True)
|
||||
target_vars = np_to_var(targets, pin_memory=self.pin_memory,
|
||||
volatile=True)
|
||||
if self.cuda:
|
||||
input_vars = input_vars.cuda()
|
||||
target_vars = target_vars.cuda()
|
||||
outputs = self.model(input_vars)
|
||||
loss = self.loss_function(outputs, target_vars)
|
||||
if hasattr(outputs, 'cpu'):
|
||||
outputs = outputs.cpu().data.numpy()
|
||||
else:
|
||||
# assume it is iterable
|
||||
outputs = [o.cpu().data.numpy() for o in outputs]
|
||||
loss = loss.cpu().data.numpy()
|
||||
with th.no_grad():
|
||||
input_vars = np_to_var(inputs, pin_memory=self.pin_memory)
|
||||
target_vars = np_to_var(targets, pin_memory=self.pin_memory)
|
||||
if self.cuda:
|
||||
input_vars = input_vars.cuda()
|
||||
target_vars = target_vars.cuda()
|
||||
outputs = self.model(input_vars)
|
||||
loss = self.loss_function(outputs, target_vars)
|
||||
if hasattr(outputs, 'cpu'):
|
||||
outputs = outputs.cpu().data.numpy()
|
||||
else:
|
||||
# assume it is iterable
|
||||
outputs = [o.cpu().data.numpy() for o in outputs]
|
||||
loss = loss.cpu().data.numpy()
|
||||
return outputs, loss
|
||||
|
||||
def monitor_epoch(self, datasets):
|
||||
@@ -420,17 +426,12 @@ class Experiment(object):
|
||||
assert set(self.epochs_df.columns) == set(row_dict.keys())
|
||||
self.epochs_df = self.epochs_df[list(row_dict.keys())]
|
||||
|
||||
def print_epoch(self):
|
||||
def log_epoch(self):
|
||||
"""
|
||||
Print monitoring values for this epoch.
|
||||
"""
|
||||
# -1 due to doing one monitor at start of training
|
||||
i_epoch = len(self.epochs_df) - 1
|
||||
log.info("Epoch {:d}".format(i_epoch))
|
||||
last_row = self.epochs_df.iloc[-1]
|
||||
for key, val in last_row.iteritems():
|
||||
log.info("{:25s} {:.5f}".format(key, val))
|
||||
log.info("")
|
||||
for logger in self.loggers:
|
||||
logger.log_epoch(self.epochs_df)
|
||||
|
||||
def setup_after_stop_training(self):
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Logger(ABC):
    """
    Abstract interface for objects that report per-epoch monitoring values.

    Subclasses must implement :meth:`log_epoch`; the class itself cannot be
    instantiated because ``log_epoch`` is abstract.
    """

    @abstractmethod
    def log_epoch(self, epochs_df):
        """
        Report the monitoring values collected so far.

        Parameters
        ----------
        epochs_df:
            Table of monitoring values handed over by the caller. The
            concrete loggers in this module read its last row via ``.iloc``.

        Raises
        ------
        NotImplementedError
            Always, if a subclass delegates to this base implementation.
        """
        raise NotImplementedError("Need to implement the log_epoch function!")
|
||||
|
||||
|
||||
class Printer(Logger):
    """
    Prints output to the terminal using Python's logging module.
    """

    def log_epoch(self, epochs_df):
        """
        Log the epoch index and every value of the most recent row.

        Parameters
        ----------
        epochs_df:
            Table of monitoring values; only its length and its last row
            (``epochs_df.iloc[-1]``) are read.
        """
        # -1 due to doing one monitor at start of training
        i_epoch = len(epochs_df) - 1
        log.info("Epoch {:d}".format(i_epoch))
        last_row = epochs_df.iloc[-1]
        # Series.items() is used instead of iteritems(), which was removed
        # in pandas 2.0; both yield the same (label, value) pairs.
        for key, val in last_row.items():
            log.info("{:25s} {:.5f}".format(key, val))
        log.info("")
|
||||
|
||||
|
||||
class TensorboardWriter(Logger):
    """
    Logs all values for tensorboard visualization using tensorboardX.

    Parameters
    ----------
    log_dir: string
        Directory path to log the output to
    """

    def __init__(self, log_dir):
        # import inside to prevent dependency of braindecode onto tensorboardX
        from tensorboardX import SummaryWriter

        self.writer = SummaryWriter(log_dir)

    def log_epoch(self, epochs_df):
        """
        Write every value of the most recent row as a scalar.

        Parameters
        ----------
        epochs_df:
            Table of monitoring values; only its length and its last row
            (``epochs_df.iloc[-1]``) are read. Each column label is used as
            the scalar tag and the epoch index as the step.
        """
        # -1 due to doing one monitor at start of training
        i_epoch = len(epochs_df) - 1
        last_row = epochs_df.iloc[-1]
        # Series.items() is used instead of iteritems(), which was removed
        # in pandas 2.0; the value it yields is already last_row[key].
        for key, val in last_row.items():
            self.writer.add_scalar(key, val, i_epoch)
|
||||
@@ -70,6 +70,34 @@ class MisclassMonitor(object):
|
||||
return {column_name: float(misclass)}
|
||||
|
||||
|
||||
def compute_pred_labels_from_trial_preds(
        all_preds, threshold_for_binary_case=None):
    """
    Compute predicted labels from batches of trial predictions.

    Parameters
    ----------
    all_preds: list of ndarrays
        One array of predictions per batch. Arrays with more than one
        dimension are reduced with an argmax over axis 1; one-dimensional
        arrays are treated as single-output (binary) predictions.
    threshold_for_binary_case: float, optional
        Threshold applied to one-dimensional predictions. Required as soon
        as any batch is one-dimensional.

    Returns
    -------
    all_pred_labels: ndarray
        Predicted labels of all batches, concatenated along the first axis.

    Raises
    ------
    AssertionError
        If a batch is one-dimensional and no threshold was supplied.
    """
    all_pred_labels = []
    for preds in all_preds:
        # preds could be examples x classes x time
        # or just
        # examples x classes
        if preds.ndim > 1:
            # make sure not to remove first dimension if it only has size one
            only_one_row = preds.shape[0] == 1
            pred_labels = np.argmax(preds, axis=1).squeeze()
            # add first dimension again if needed
            if only_one_row:
                pred_labels = pred_labels[None]
        else:
            assert threshold_for_binary_case is not None, (
                "In case of only one output, please supply the "
                "threshold_for_binary_case parameter")
            # binary classification case... assume logits
            pred_labels = np.int32(preds > threshold_for_binary_case)
        # now examples x time or examples
        all_pred_labels.extend(pred_labels)
    return np.array(all_pred_labels)
|
||||
|
||||
|
||||
class AveragePerClassMisclassMonitor(object):
|
||||
"""
|
||||
Compute average of misclasses per class,
|
||||
@@ -148,6 +176,7 @@ class LossMonitor(object):
|
||||
return {column_name: mean_loss}
|
||||
|
||||
|
||||
|
||||
class CroppedTrialMisclassMonitor(object):
|
||||
"""
|
||||
Compute trialwise misclasses from predictions for crops.
|
||||
@@ -169,9 +198,10 @@ class CroppedTrialMisclassMonitor(object):
|
||||
assert self.input_time_length is not None, "Need to know input time length..."
|
||||
# First case that each trial only has a single label
|
||||
if not hasattr(dataset.y[0], '__len__'):
|
||||
all_pred_labels = self._compute_pred_labels(dataset, all_preds)
|
||||
all_trial_labels = dataset.y
|
||||
all_pred_labels = compute_trial_labels_from_crop_preds(
|
||||
all_preds, self.input_time_length, dataset.X)
|
||||
assert all_pred_labels.shape == dataset.y.shape
|
||||
all_trial_labels = dataset.y
|
||||
else:
|
||||
all_trial_labels, all_pred_labels = (
|
||||
self._compute_trial_pred_labels_from_cnt_y(dataset, all_preds))
|
||||
@@ -181,8 +211,8 @@ class CroppedTrialMisclassMonitor(object):
|
||||
return {column_name: float(misclass)}
|
||||
|
||||
def _compute_pred_labels(self, dataset, all_preds, ):
|
||||
preds_per_trial = compute_preds_per_trial_for_set(
|
||||
all_preds, self.input_time_length, dataset)
|
||||
preds_per_trial = compute_preds_per_trial_from_crops(
|
||||
all_preds, self.input_time_length, dataset.X)
|
||||
all_pred_labels = [np.argmax(np.mean(p, axis=1))
|
||||
for p in preds_per_trial]
|
||||
all_pred_labels = np.array(all_pred_labels)
|
||||
@@ -194,8 +224,8 @@ class CroppedTrialMisclassMonitor(object):
|
||||
# we only want the preds that are for the same labels as the last label in y
|
||||
# (there might be parts of other class-data at start, for trialwise misclass we assume
|
||||
# they are contained in other trials at the end...)
|
||||
preds_per_trial = compute_preds_per_trial_for_set(
|
||||
all_preds, self.input_time_length, dataset)
|
||||
preds_per_trial = compute_preds_per_trial_from_crops(
|
||||
all_preds, self.input_time_length, dataset.X)
|
||||
trial_labels = []
|
||||
trial_pred_labels = []
|
||||
for trial_pred, trial_y in zip(preds_per_trial, dataset.y):
|
||||
@@ -217,8 +247,36 @@ class CroppedTrialMisclassMonitor(object):
|
||||
return trial_labels, trial_pred_labels
|
||||
|
||||
|
||||
def compute_preds_per_trial_for_set(all_preds, input_time_length,
|
||||
dataset, ):
|
||||
def compute_trial_labels_from_crop_preds(all_preds, input_time_length, X):
    """
    Compute predicted trial labels from arrays of crop predictions

    Parameters
    ----------
    all_preds: list of ndarrays
        All predictions for the crops.
        NOTE(review): the sibling helper reads ``all_preds[0].shape[2]``, so
        each entry presumably has three dimensions -- confirm.
    input_time_length: int
        Temporal length of one input to the model.
    X: ndarray
        Input tensor the crops were taken from.

    Returns
    -------
    pred_labels_per_trial: 1darray
        Predicted label for each trial.

    """
    preds_per_trial = compute_preds_per_trial_from_crops(
        all_preds, input_time_length, X)
    # Average each trial's predictions over axis 1, then take the index of
    # the largest mean as that trial's label.
    return np.array([np.argmax(np.mean(trial_preds, axis=1))
                     for trial_preds in preds_per_trial])
|
||||
|
||||
|
||||
def compute_preds_per_trial_from_crops(all_preds, input_time_length,
|
||||
X, ):
|
||||
"""
|
||||
Compute predictions per trial from predictions for crops.
|
||||
|
||||
@@ -228,8 +286,8 @@ def compute_preds_per_trial_for_set(all_preds, input_time_length,
|
||||
All predictions for the crops.
|
||||
input_time_length: int
|
||||
Temporal length of one input to the model.
|
||||
dataset: :class:`.SignalAndTarget`
|
||||
Dataset the crops were taken from.
|
||||
X: ndarray
|
||||
Input tensor the crops were taken from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -240,7 +298,7 @@ def compute_preds_per_trial_for_set(all_preds, input_time_length,
|
||||
n_preds_per_input = all_preds[0].shape[2]
|
||||
n_receptive_field = input_time_length - n_preds_per_input + 1
|
||||
n_preds_per_trial = [trial.shape[1] - n_receptive_field + 1
|
||||
for trial in dataset.X]
|
||||
for trial in X]
|
||||
preds_per_trial = compute_preds_per_trial_from_n_preds_per_trial(
|
||||
all_preds, n_preds_per_trial)
|
||||
return preds_per_trial
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from numpy.random import RandomState
|
||||
import torch as th
|
||||
|
||||
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
|
||||
RuntimeMonitor, CroppedTrialMisclassMonitor
|
||||
RuntimeMonitor, CroppedTrialMisclassMonitor, \
|
||||
compute_trial_labels_from_crop_preds, compute_pred_labels_from_trial_preds
|
||||
from braindecode.experiments.stopcriteria import MaxEpochs
|
||||
from braindecode.datautil.iterators import BalancedBatchSizeIterator, \
|
||||
CropsFromTrialsIterator
|
||||
@@ -10,7 +14,7 @@ from braindecode.experiments.experiment import Experiment
|
||||
from braindecode.datautil.signal_target import SignalAndTarget
|
||||
from braindecode.models.util import to_dense_prediction_model
|
||||
from braindecode.torch_ext.schedulers import CosineAnnealing, ScheduledOptimizer
|
||||
from braindecode.torch_ext.util import np_to_var
|
||||
from braindecode.torch_ext.util import np_to_var, var_to_np
|
||||
|
||||
|
||||
def find_optimizer(optimizer_name):
|
||||
@@ -62,6 +66,25 @@ class BaseModel(object):
|
||||
def fit(self, train_X, train_y, epochs, batch_size, input_time_length=None,
|
||||
validation_data=None, model_constraint=None,
|
||||
remember_best_column=None, scheduler=None):
|
||||
"""
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
train_X
|
||||
train_y
|
||||
epochs
|
||||
batch_size
|
||||
input_time_length
|
||||
validation_data
|
||||
model_constraint
|
||||
remember_best_column
|
||||
scheduler
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
if not self.compiled:
|
||||
raise ValueError("Compile the model first by calling model.compile(loss, optimizer, metrics)")
|
||||
|
||||
@@ -78,19 +101,19 @@ class BaseModel(object):
|
||||
test_input = test_input.cuda()
|
||||
out = self.network(test_input)
|
||||
n_preds_per_input = out.cpu().data.numpy().shape[2]
|
||||
iterator = CropsFromTrialsIterator(
|
||||
self.iterator = CropsFromTrialsIterator(
|
||||
batch_size=batch_size, input_time_length=input_time_length,
|
||||
n_preds_per_input=n_preds_per_input,
|
||||
seed=self.seed_rng.randint(0, 4294967295))
|
||||
else:
|
||||
iterator = BalancedBatchSizeIterator(batch_size=batch_size, seed=self.seed_rng.randint(0, 4294967295))
|
||||
self.iterator = BalancedBatchSizeIterator(batch_size=batch_size, seed=self.seed_rng.randint(0, 4294967295))
|
||||
stop_criterion = MaxEpochs(epochs - 1)# -1 since we dont print 0 epoch, which matters for this stop criterion
|
||||
train_set = SignalAndTarget(train_X, train_y)
|
||||
optimizer = self.optimizer
|
||||
if scheduler is not None:
|
||||
assert scheduler == 'cosine'
|
||||
n_updates_per_epoch = sum(
|
||||
[1 for _ in iterator.get_batches(train_set, shuffle=True)])
|
||||
[1 for _ in self.iterator.get_batches(train_set, shuffle=True)])
|
||||
n_updates_per_period = n_updates_per_epoch * epochs
|
||||
if scheduler == 'cosine':
|
||||
scheduler = CosineAnnealing(n_updates_per_period)
|
||||
@@ -117,22 +140,20 @@ class BaseModel(object):
|
||||
extra_monitors = [monitor_dict[m]() for m in self.metrics]
|
||||
self.monitors += extra_monitors
|
||||
self.monitors += [RuntimeMonitor()]
|
||||
exp = Experiment(self.network, train_set, valid_set, test_set, iterator=iterator,
|
||||
exp = Experiment(self.network, train_set, valid_set, test_set,
|
||||
iterator=self.iterator,
|
||||
loss_function=loss_function, optimizer=optimizer,
|
||||
model_constraint=model_constraint,
|
||||
monitors=self.monitors,
|
||||
stop_criterion=stop_criterion,
|
||||
remember_best_column=remember_best_column,
|
||||
run_after_early_stop=False, cuda=self.cuda, print_0_epoch=False,
|
||||
run_after_early_stop=False, cuda=self.cuda, log_0_epoch=False,
|
||||
do_early_stop=(remember_best_column is not None))
|
||||
exp.run()
|
||||
self.epochs_df = exp.epochs_df
|
||||
return exp
|
||||
|
||||
def evaluate(self, X,y, batch_size=32):
|
||||
# Create a dummy experiment for the evaluation
|
||||
iterator = BalancedBatchSizeIterator(batch_size=batch_size,
|
||||
seed=0) # seed irrelevant
|
||||
def evaluate(self, X,y):
|
||||
stop_criterion = MaxEpochs(0)
|
||||
train_set = SignalAndTarget(X, y)
|
||||
model_constraint = None
|
||||
@@ -142,15 +163,20 @@ class BaseModel(object):
|
||||
if self.cropped:
|
||||
loss_function = lambda outputs, targets: \
|
||||
self.loss(th.mean(outputs, dim=2), targets)
|
||||
|
||||
# reset runtime monitor if exists...
|
||||
for monitor in self.monitors:
|
||||
if hasattr(monitor, 'last_call_time'):
|
||||
monitor.last_call_time = time.time()
|
||||
exp = Experiment(self.network, train_set, valid_set, test_set,
|
||||
iterator=iterator,
|
||||
iterator=self.iterator,
|
||||
loss_function=loss_function, optimizer=self.optimizer,
|
||||
model_constraint=model_constraint,
|
||||
monitors=self.monitors,
|
||||
stop_criterion=stop_criterion,
|
||||
remember_best_column=None,
|
||||
run_after_early_stop=False, cuda=self.cuda,
|
||||
print_0_epoch=False,
|
||||
log_0_epoch=False,
|
||||
do_early_stop=False)
|
||||
|
||||
exp.monitor_epoch({'train': train_set})
|
||||
@@ -159,3 +185,15 @@ class BaseModel(object):
|
||||
for key, val in
|
||||
dict(exp.epochs_df.iloc[0]).items()])
|
||||
return result_dict
|
||||
|
||||
def predict(self, X, threshold_for_binary_case=None):
    """
    Predict labels for the given inputs.

    Batches are drawn from ``self.iterator``, which is assigned in ``fit``,
    so ``fit`` has to be called before ``predict``.

    Parameters
    ----------
    X: ndarray
        Inputs to predict labels for.
    threshold_for_binary_case: float, optional
        Forwarded to ``compute_pred_labels_from_trial_preds``; only used
        when the model is not cropped.

    Returns
    -------
    pred_labels:
        Labels computed by ``compute_trial_labels_from_crop_preds`` for a
        cropped model, otherwise by ``compute_pred_labels_from_trial_preds``.
    """
    # Inference only: use evaluation mode and skip building autograd graphs.
    self.network.eval()
    all_preds = []
    with th.no_grad():
        # Targets are irrelevant here, X merely fills the target slot.
        for b_X, _ in self.iterator.get_batches(SignalAndTarget(X, X), False):
            b_X_var = np_to_var(b_X)
            # Inputs must live on the same device as the network.
            if self.cuda:
                b_X_var = b_X_var.cuda()
            # NOTE(review): assumes var_to_np copies the output back to host
            # memory before converting to numpy -- confirm.
            all_preds.append(var_to_np(self.network(b_X_var)))
    if self.cropped:
        pred_labels = compute_trial_labels_from_crop_preds(
            all_preds, self.iterator.input_time_length, X)
    else:
        pred_labels = compute_pred_labels_from_trial_preds(
            all_preds, threshold_for_binary_case)
    return pred_labels
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário