Lazy Loading (#38)
* Fixed memory problems: they were caused by successively appending predictions to a list; fixed by allocating memory beforehand. * Created lazy_iterators.py: added LoadCropsFromTrialsIterator. * Created lazy_dataset.py: added an abstract class as a parent for lazy datasets. * Update lazy_iterators.py: now saving the random state before creating the data loader and resetting to that random state after every batch; data loader behavior changed with the PyTorch upgrade 0.4.0 -> 1.0.0 and broke tests. * Update experiment.py * Update experiment.py * Update lazy_iterators.py * Update lazy_dataset.py * Update iterators.py * Update lazy_iterators.py * Update iterators.py * Update lazy_iterators.py * Update lazy_iterators.py: added an option to toggle resetting of the RNG state. * Update iterators.py: deactivated resetting the RNG state after every batch by default. * Update iterators.py: accidentally changed the traditional iterators; reverting changes. * Update lazy_iterators.py: deactivated resetting the RNG after every batch by default. * Update experiment.py: removed unused variables.
Esse commit está contido em:
@@ -0,0 +1,45 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class LazyDataset(ABC):
    """Abstract base class for data sets whose signals are read on demand.

    Concrete lazy data sets have to replace the ``file_paths``, ``X`` and
    ``y`` attributes and implement ``load_lazy``, which reads a trial or a
    crop of a trial.
    """

    def __init__(self):
        # Placeholder descriptions only; subclasses overwrite all three.
        self.file_paths = "Not implemented: a list of all file paths"
        self.X = ("Not implemented: a list of empty ndarrays with number of "
                  "samples as second dimension")
        self.y = "Not implemented: a list of all targets"

    @abstractmethod
    def load_lazy(self, path, start_i, stop_i):
        """Load the signal stored at ``path`` between two sample indices.

        ``__getitem__`` treats the returned object as an array (it reads
        ``ndim`` and indexes it) and pairs it with the target from ``self.y``.

        Parameters
        ----------
        path: str
            file path
        start_i: int
            start index of signal crop
        stop_i: int
            stop index of signal crop; ``__getitem__`` passes ``None`` when
            it is called with a plain index
        """
        raise NotImplementedError

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return a two-tuple ``(example, label)``.

        ``idx`` is either a plain index into ``file_paths`` / ``y`` or a
        triple ``(index, start_i, stop_i)`` selecting a crop.
        """
        try:
            trial_i, start_i, stop_i = idx
        except (TypeError, ValueError):
            # Not a triple: use start 0 and an open-ended stop.
            trial_i, start_i, stop_i = idx, 0, None

        signal = self.load_lazy(self.file_paths[trial_i], start_i, stop_i)
        if signal.ndim == 2:
            # Append a trailing singleton axis to two-dimensional signals.
            signal = signal[:, :, None]
        return signal, self.y[trial_i]
|
||||
@@ -0,0 +1,123 @@
|
||||
from torch.utils.data import DataLoader
|
||||
from numpy.random import RandomState
|
||||
from functools import partial
|
||||
import torch as th
|
||||
import numpy as np
|
||||
|
||||
from braindecode.datautil.iterators import _compute_start_stop_block_inds, \
|
||||
get_balanced_batches
|
||||
|
||||
|
||||
def custom_collate(batch, rng_state=None):
    """Merge a list of samples into ndarrays with the batch as outer dimension.

    Taken and adapted from pytorch to return ndarrays instead of tensors:
    https://pytorch.org/docs/0.4.1/_modules/torch/utils/data/dataloader.html

    This function is needed, since tensors require more system RAM which we
    want to decrease using lazy loading.

    Parameters
    ----------
    batch: sequence
        Samples of one mini-batch. Each sample is a numpy array / numpy
        scalar, a Python number, or a tuple / list of such fields.
    rng_state: torch.ByteTensor or None
        If given, the torch random number generator is set to this state
        before the batch is assembled.

    Returns
    -------
    numpy.ndarray or list
        One stacked array for flat samples, or a list with one collated entry
        per field when the samples are tuples / lists.

    Raises
    ------
    TypeError
        If the samples are of a type that cannot be collated. The previous
        implementation silently returned ``None`` in that case.
    """
    if rng_state is not None:
        # Resetting once per call is enough; the recursive calls below do
        # not consume any torch randomness.
        th.random.set_rng_state(rng_state)

    first = batch[0]
    if type(first).__module__ == 'numpy':
        return np.stack(list(batch), 0)
    if isinstance(first, (int, float)):
        # Plain Python targets, e.g. integer class labels.
        return np.array(batch)
    if isinstance(first, (tuple, list)):
        # Regroup [(x0, y0), (x1, y1), ...] into [(x0, x1, ...), (y0, ...)]
        # and collate every field on its own.
        return [custom_collate(samples) for samples in zip(*batch)]
    raise TypeError(
        "batch must contain numpy arrays, numbers, tuples or lists; "
        "found {}".format(type(first)))
|
||||
|
||||
|
||||
class LazyCropsFromTrialsIterator(object):
    """Crop iterator for lazy data sets, modelled on CropsFromTrialsIterator.

    Instead of slicing arrays held in memory, the crops are requested from
    the data set through a pytorch DataLoader, so recordings are only loaded
    when they are actually needed. Reduces overall RAM requirements.

    Parameters
    ----------
    input_time_length: int
        Input time length of the ConvNet, determines size of batches in
        3rd dimension.
    n_preds_per_input: int
        Number of predictions ConvNet makes per one input. Can be computed
        by making a forward pass with the given input time length, the
        output length in 3rd dimension is n_preds_per_input.
    batch_size: int
    seed: int
        Random seed for initialization of `numpy.RandomState` random generator
        that shuffles the batches.
    num_workers: int
        The number of DataLoader workers that load crops in parallel
    collate_fn: func
        Merges a list of samples to form a mini-batch. It is always called
        with an additional ``rng_state`` keyword argument.
    check_preds_smaller_trial_len: bool
        Checking validity of predictions and trial lengths. Disable to decrease
        runtime.
    reset_rng_after_each_batch: bool
        If True, the torch random state captured when ``get_batches`` is
        called is handed to ``collate_fn`` as ``rng_state`` for every batch;
        otherwise ``rng_state`` is None.
    """

    def __init__(self, input_time_length, n_preds_per_input, batch_size,
                 seed=328774, num_workers=0, collate_fn=custom_collate,
                 check_preds_smaller_trial_len=True,
                 reset_rng_after_each_batch=False):
        self.input_time_length = input_time_length
        self.n_preds_per_input = n_preds_per_input
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.check_preds_smaller_trial_len = check_preds_smaller_trial_len
        self.reset_rng_after_each_batch = reset_rng_after_each_batch
        self.seed = seed
        self.rng = RandomState(self.seed)

    def reset_rng(self):
        """Re-create the numpy random generator from the stored seed."""
        self.rng = RandomState(self.seed)

    def get_batches(self, dataset, shuffle):
        """Return a DataLoader that yields the crop batches of ``dataset``."""
        # In pytorch 1.0.0 the internal random state changes when a
        # DataLoader is used, even with num_workers=0. That was not the case
        # in torch 0.4.0 and breaks the equality tests of traditional versus
        # lazy loading. Optionally remember the state here so that the
        # collate function can restore it for every batch.
        if self.reset_rng_after_each_batch:
            rng_state = th.random.get_rng_state()
        else:
            rng_state = None
        batch_collate = partial(self.collate_fn, rng_state=rng_state)

        batch_sampler = self._get_batch_indeces(dataset=dataset,
                                                shuffle=shuffle)
        return DataLoader(dataset=dataset, batch_sampler=batch_sampler,
                          num_workers=self.num_workers, pin_memory=False,
                          collate_fn=batch_collate)

    def _get_batch_indeces(self, dataset, shuffle):
        """Build the list of batches; each batch is an array whose rows are
        ``(i_trial, start, stop)`` triples."""
        # Always start at the first predictable sample, i.e. at the end of
        # the receptive field.
        receptive_field = self.input_time_length - self.n_preds_per_input + 1
        trial_starts = [receptive_field - 1] * len(dataset.X)
        trial_stops = [trial.shape[1] for trial in dataset.X]

        # Every trial has to be at least as long as one network input.
        for i_trial, input_len in enumerate(trial_stops):
            assert input_len >= self.input_time_length, (
                "Input length {:d} of trial {:d} is smaller than the "
                "input time length {:d}".format(input_len, i_trial,
                                                self.input_time_length))

        blocks_per_trial = _compute_start_stop_block_inds(
            trial_starts, trial_stops, self.input_time_length,
            self.n_preds_per_input,
            check_preds_smaller_trial_len=self.check_preds_smaller_trial_len)

        # The blocks of a trial must begin at sample 0 and end at its stop.
        for i_trial, blocks in enumerate(blocks_per_trial):
            assert blocks[0][0] == 0
            assert blocks[-1][1] == trial_stops[i_trial]

        # Flatten to one (i_trial, start, stop) row per crop.
        crop_rows = []
        for i_trial, blocks in enumerate(blocks_per_trial):
            for start, stop in blocks:
                crop_rows.append((i_trial, start, stop))
        crop_rows = np.array(crop_rows)

        batches = get_balanced_batches(
            n_trials=len(crop_rows), rng=self.rng,
            shuffle=shuffle, batch_size=self.batch_size)

        return [crop_rows[batch_inds] for batch_inds in batches]
|
||||
@@ -5,11 +5,12 @@ import time
|
||||
|
||||
import pandas as pd
|
||||
import torch as th
|
||||
import numpy as np
|
||||
|
||||
from braindecode.datautil.splitters import concatenate_sets
|
||||
from braindecode.experiments.loggers import Printer
|
||||
from braindecode.experiments.stopcriteria import MaxEpochs, ColumnBelow, Or
|
||||
from braindecode.torch_ext.util import np_to_var, set_random_seeds
|
||||
from braindecode.torch_ext.util import np_to_var
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -214,15 +215,15 @@ class Experiment(object):
|
||||
log.info("Run until second stop...")
|
||||
loss_to_reach = float(self.epochs_df['train_loss'].iloc[-1])
|
||||
self.run_until_second_stop()
|
||||
if self.reset_after_second_run:
|
||||
if (float(self.epochs_df['valid_loss'].iloc[-1]) > loss_to_reach
|
||||
and self.reset_after_second_run):
|
||||
# if no valid loss was found below the best train loss on 1st
|
||||
# run, reset model to the epoch with lowest valid_misclass
|
||||
if float(self.epochs_df['valid_loss'].iloc[-1]) > loss_to_reach:
|
||||
log.info("Resetting to best epoch {:d}".format(
|
||||
self.rememberer.best_epoch))
|
||||
self.rememberer.reset_to_best_model(self.epochs_df,
|
||||
self.model,
|
||||
self.optimizer)
|
||||
log.info("Resetting to best epoch {:d}".format(
|
||||
self.rememberer.best_epoch))
|
||||
self.rememberer.reset_to_best_model(self.epochs_df,
|
||||
self.model,
|
||||
self.optimizer)
|
||||
|
||||
def setup_training(self):
|
||||
"""
|
||||
@@ -367,11 +368,11 @@ class Experiment(object):
|
||||
outputs = self.model(input_vars)
|
||||
loss = self.loss_function(outputs, target_vars)
|
||||
if hasattr(outputs, 'cpu'):
|
||||
outputs = outputs.cpu().data.numpy()
|
||||
outputs = outputs.cpu().detach().numpy()
|
||||
else:
|
||||
# assume it is iterable
|
||||
outputs = [o.cpu().data.numpy() for o in outputs]
|
||||
loss = loss.cpu().data.numpy()
|
||||
outputs = [o.cpu().detach().numpy() for o in outputs]
|
||||
loss = loss.cpu().detach().numpy()
|
||||
return outputs, loss
|
||||
|
||||
def monitor_epoch(self, datasets):
|
||||
@@ -390,23 +391,66 @@ class Experiment(object):
|
||||
result_dicts_per_monitor = OrderedDict()
|
||||
for m in self.monitors:
|
||||
result_dicts_per_monitor[m] = OrderedDict()
|
||||
for m in self.monitors:
|
||||
result_dict = m.monitor_epoch()
|
||||
if result_dict is not None:
|
||||
result_dicts_per_monitor[m].update(result_dict)
|
||||
for setname in datasets:
|
||||
assert setname in ['train', 'valid', 'test']
|
||||
dataset = datasets[setname]
|
||||
all_preds = []
|
||||
all_losses = []
|
||||
all_batch_sizes = []
|
||||
all_targets = []
|
||||
for batch in self.iterator.get_batches(dataset, shuffle=False):
|
||||
preds, loss = self.eval_on_batch(batch[0], batch[1])
|
||||
all_preds.append(preds)
|
||||
batch_generator = self.iterator.get_batches(dataset, shuffle=False)
|
||||
if hasattr(batch_generator, "__len__"):
|
||||
# prevent loading of data to estimate number of batches when
|
||||
# using lazy iterators
|
||||
n_batches = len(batch_generator)
|
||||
else:
|
||||
# iterating through traditional iterators is cheap, since
|
||||
# nothing is loaded, recreate generator afterwards
|
||||
n_batches = sum(1 for i in batch_generator)
|
||||
batch_generator = self.iterator.get_batches(dataset,
|
||||
shuffle=False)
|
||||
all_preds, all_targets = None, None
|
||||
all_losses, all_batch_sizes = [], []
|
||||
for inputs, targets in batch_generator:
|
||||
preds, loss = self.eval_on_batch(inputs, targets)
|
||||
all_losses.append(loss)
|
||||
all_batch_sizes.append(len(batch[0]))
|
||||
all_targets.append(batch[1])
|
||||
all_batch_sizes.append(len(targets))
|
||||
if all_preds is None:
|
||||
assert all_targets is None
|
||||
# first batch size is largest
|
||||
max_size, n_classes, n_preds_per_input = preds.shape
|
||||
# pre-allocate memory for all predictions and targets
|
||||
all_preds = np.nan * np.ones(
|
||||
(n_batches * max_size, n_classes, n_preds_per_input),
|
||||
dtype=np.float32)
|
||||
all_preds[:len(preds)] = preds
|
||||
all_targets = np.nan * np.ones((n_batches * max_size))
|
||||
all_targets[:len(targets)] = targets
|
||||
else:
|
||||
start_i = sum(all_batch_sizes[:-1])
|
||||
stop_i = sum(all_batch_sizes)
|
||||
all_preds[start_i:stop_i] = preds
|
||||
all_targets[start_i:stop_i] = targets
|
||||
|
||||
# check for unequal batches
|
||||
unequal_batches = len(set(all_batch_sizes)) > 1
|
||||
all_batch_sizes = sum(all_batch_sizes)
|
||||
# remove nan rows in case of unequal batch sizes
|
||||
if unequal_batches:
|
||||
assert np.sum(np.isnan(all_preds[:all_batch_sizes - 1])) == 0
|
||||
assert np.sum(np.isnan(all_preds[all_batch_sizes:])) > 0
|
||||
range_to_delete = range(all_batch_sizes, len(all_preds))
|
||||
all_preds = np.delete(all_preds, range_to_delete, axis=0)
|
||||
all_targets = np.delete(all_targets, range_to_delete, axis=0)
|
||||
assert np.sum(np.isnan(all_preds)) == 0, (
|
||||
"There are still nans in predictions")
|
||||
assert np.sum(np.isnan(all_targets)) == 0, (
|
||||
"There are still nans in targets")
|
||||
# add empty dimension
|
||||
# monitors expect n_batches x ...
|
||||
all_preds = all_preds[np.newaxis, :]
|
||||
all_targets = all_targets[np.newaxis, :]
|
||||
all_batch_sizes = [all_batch_sizes]
|
||||
all_losses = [all_losses]
|
||||
|
||||
for m in self.monitors:
|
||||
result_dict = m.monitor_set(setname, all_preds, all_losses,
|
||||
@@ -418,7 +462,9 @@ class Experiment(object):
|
||||
for m in self.monitors:
|
||||
row_dict.update(result_dicts_per_monitor[m])
|
||||
self.epochs_df = self.epochs_df.append(row_dict, ignore_index=True)
|
||||
assert set(self.epochs_df.columns) == set(row_dict.keys())
|
||||
assert set(self.epochs_df.columns) == set(row_dict.keys()), (
|
||||
"Columns of dataframe: {:s}\n and keys of dict {:s} not same")\
|
||||
.format(str(set(self.epochs_df.columns)), str(set(row_dict.keys())))
|
||||
self.epochs_df = self.epochs_df[list(row_dict.keys())]
|
||||
|
||||
def log_epoch(self):
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário