Lazy Loading (#38)

* fixing memory problems

memory problems were caused by successively appending predictions to list.
fixed by allocating memory beforehand.

* created lazy_iterators.py

added LoadCropsFromTrialsIterator

* created lazy_dataset.py

added abstract class as a parent for lazy datasets

* Update lazy_iterators.py

now saving random state before creating data loader and resetting to random state after every batch. data loader behavior was changed with pytorch upgrade 0.4.0 -> 1.0.0 and broke tests.

* Update experiment.py

* Update experiment.py

* Update lazy_iterators.py

* Update lazy_dataset.py

* Update iterators.py

* Update lazy_iterators.py

* Update iterators.py

* Update lazy_iterators.py

* Update lazy_iterators.py

added option to toggle resetting of rng state

* Update iterators.py

deactivated reset rng state after every batch by default

* Update iterators.py

accidentally changed traditional iterators. reverting changes

* Update lazy_iterators.py

deactivated resetting rng after every batch by default

* Update experiment.py

removed unused variables
Esse commit está contido em:
gemeinl
2019-03-14 14:41:23 +01:00
commit de robintibor
commit 6398803efe
3 arquivos alterados com 236 adições e 22 exclusões
+45
Ver Arquivo
@@ -0,0 +1,45 @@
from abc import ABC, abstractmethod
class LazyDataset(ABC):
""" Class implementing an abstract lazy data set. Custom lazy data sets
have to override file_paths, X and y as well as the load_lazy function to
load trials or crops. """
def __init__(self):
self.file_paths = "Not implemented: a list of all file paths"
self.X = ("Not implemented: a list of empty ndarrays with number of "
"samples as second dimension")
self.y = "Not implemented: a list of all targets"
@abstractmethod
def load_lazy(self, path, start_i, stop_i):
""" Loading procedure that gets a file path, start and stop indices.
Is supposed to return a trial / crop together with its target
Parameters
----------
path: str
file path
start_i: int
start index of signal crop
stop_i: int
stop index of signal crop
"""
raise NotImplementedError
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
""" Returns a two-tuple of example, label """
try:
idx, start_i, stop_i = idx
except (TypeError, ValueError):
start_i = 0
stop_i = None
file_path = self.file_paths[idx]
x = self.load_lazy(file_path, start_i, stop_i)
if x.ndim == 2:
x = x[:, :, None]
return x, self.y[idx]
+123
Ver Arquivo
@@ -0,0 +1,123 @@
from torch.utils.data import DataLoader
from numpy.random import RandomState
from functools import partial
import torch as th
import numpy as np
from braindecode.datautil.iterators import _compute_start_stop_block_inds, \
get_balanced_batches
def custom_collate(batch, rng_state=None):
""" Puts each data field into a ndarray with outer dimension batch size.
Taken and adapted from pytorch to return ndarrays instead of tensors:
https://pytorch.org/docs/0.4.1/_modules/torch/utils/data/dataloader.html
this function is needed, since tensors require more system RAM which we
want to decrease using lazy loading
"""
elem_type = type(batch[0])
if elem_type.__module__ == 'numpy':
if rng_state is not None:
th.random.set_rng_state(rng_state)
return np.stack([b for b in batch], 0)
elif isinstance(batch[0], tuple):
transposed = zip(*batch)
return [custom_collate(samples, rng_state) for samples in transposed]
class LazyCropsFromTrialsIterator(object):
""" This is basically the same code as CropsFromTrialsIterator adapted to
work with lazy datasets. It uses pytorch DataLoader to load recordings
from hdd with multiple threads when the data is actually needed. Reduces
overall RAM requirements.
Parameters
----------
input_time_length: int
Input time length of the ConvNet, determines size of batches in
3rd dimension.
n_preds_per_input: int
Number of predictions ConvNet makes per one input. Can be computed
by making a forward pass with the given input time length, the
output length in 3rd dimension is n_preds_per_input.
batch_size: int
seed: int
Random seed for initialization of `numpy.RandomState` random generator
that shuffles the batches.
num_workers: int
The number of workers to load crops in parallel
collate_fn: func
Merges a list of samples to form a mini-batch
check_preds_smaller_trial_len: bool
Checking validity of predictions and trial lengths. Disable to decrease
runtime.
"""
def __init__(self, input_time_length, n_preds_per_input, batch_size,
seed=328774, num_workers=0, collate_fn=custom_collate,
check_preds_smaller_trial_len=True,
reset_rng_after_each_batch=False):
self.batch_size = batch_size
self.seed = seed
self.rng = RandomState(self.seed)
self.input_time_length = input_time_length
self.n_preds_per_input = n_preds_per_input
self.num_workers = num_workers
self.collate_fn = collate_fn
self.check_preds_smaller_trial_len = check_preds_smaller_trial_len
self.reset_rng_after_each_batch = reset_rng_after_each_batch
def reset_rng(self):
self.rng = RandomState(self.seed)
def get_batches(self, dataset, shuffle):
# in pytorch 1.0.0, internal random state is changed when using a
# DataLoader, even if num_workers is 0. this did not happen in torch
# 0.4.0 and breaks our equality tests of traditional and lazy loading
# therefore, in the collate function of every batch, reset to the
# random state before iterating through batches.
if self.reset_rng_after_each_batch:
random_state = th.random.get_rng_state()
collate_fn = partial(self.collate_fn, rng_state=random_state)
else:
collate_fn = partial(self.collate_fn, rng_state=None)
batch_indeces = self._get_batch_indeces(dataset=dataset,
shuffle=shuffle)
data_loader = DataLoader(dataset=dataset, batch_sampler=batch_indeces,
num_workers=self.num_workers,
pin_memory=False, collate_fn=collate_fn)
return data_loader
def _get_batch_indeces(self, dataset, shuffle):
# start always at first predictable sample, so
# start at end of receptive field
n_receptive_field = self.input_time_length - self.n_preds_per_input + 1
i_trial_starts = [n_receptive_field - 1] * len(dataset.X)
i_trial_stops = [trial.shape[1] for trial in dataset.X]
# Check whether input lengths ok
input_lens = i_trial_stops
for i_trial, input_len in enumerate(input_lens):
assert input_len >= self.input_time_length, (
"Input length {:d} of trial {:d} is smaller than the "
"input time length {:d}".format(input_len, i_trial,
self.input_time_length))
start_stop_blocks_per_trial = _compute_start_stop_block_inds(
i_trial_starts, i_trial_stops, self.input_time_length,
self.n_preds_per_input,
check_preds_smaller_trial_len=self.check_preds_smaller_trial_len)
for i_trial, trial_blocks in enumerate(start_stop_blocks_per_trial):
assert trial_blocks[0][0] == 0
assert trial_blocks[-1][1] == i_trial_stops[i_trial]
i_trial_start_stop_block = np.array([
(i_trial, start, stop) for i_trial, block in
enumerate(start_stop_blocks_per_trial) for start, stop in block])
batches = get_balanced_batches(
n_trials=len(i_trial_start_stop_block), rng=self.rng,
shuffle=shuffle, batch_size=self.batch_size)
return [i_trial_start_stop_block[batch_ind] for batch_ind in batches]
+63 -17
Ver Arquivo
@@ -5,11 +5,12 @@ import time
import pandas as pd
import torch as th
import numpy as np
from braindecode.datautil.splitters import concatenate_sets
from braindecode.experiments.loggers import Printer
from braindecode.experiments.stopcriteria import MaxEpochs, ColumnBelow, Or
from braindecode.torch_ext.util import np_to_var, set_random_seeds
from braindecode.torch_ext.util import np_to_var
log = logging.getLogger(__name__)
@@ -214,10 +215,10 @@ class Experiment(object):
log.info("Run until second stop...")
loss_to_reach = float(self.epochs_df['train_loss'].iloc[-1])
self.run_until_second_stop()
if self.reset_after_second_run:
if (float(self.epochs_df['valid_loss'].iloc[-1]) > loss_to_reach
and self.reset_after_second_run):
# if no valid loss was found below the best train loss on 1st
# run, reset model to the epoch with lowest valid_misclass
if float(self.epochs_df['valid_loss'].iloc[-1]) > loss_to_reach:
log.info("Resetting to best epoch {:d}".format(
self.rememberer.best_epoch))
self.rememberer.reset_to_best_model(self.epochs_df,
@@ -367,11 +368,11 @@ class Experiment(object):
outputs = self.model(input_vars)
loss = self.loss_function(outputs, target_vars)
if hasattr(outputs, 'cpu'):
outputs = outputs.cpu().data.numpy()
outputs = outputs.cpu().detach().numpy()
else:
# assume it is iterable
outputs = [o.cpu().data.numpy() for o in outputs]
loss = loss.cpu().data.numpy()
outputs = [o.cpu().detach().numpy() for o in outputs]
loss = loss.cpu().detach().numpy()
return outputs, loss
def monitor_epoch(self, datasets):
@@ -390,23 +391,66 @@ class Experiment(object):
result_dicts_per_monitor = OrderedDict()
for m in self.monitors:
result_dicts_per_monitor[m] = OrderedDict()
for m in self.monitors:
result_dict = m.monitor_epoch()
if result_dict is not None:
result_dicts_per_monitor[m].update(result_dict)
for setname in datasets:
assert setname in ['train', 'valid', 'test']
dataset = datasets[setname]
all_preds = []
all_losses = []
all_batch_sizes = []
all_targets = []
for batch in self.iterator.get_batches(dataset, shuffle=False):
preds, loss = self.eval_on_batch(batch[0], batch[1])
all_preds.append(preds)
batch_generator = self.iterator.get_batches(dataset, shuffle=False)
if hasattr(batch_generator, "__len__"):
# prevent loading of data to estimate number of batches when
# using lazy iterators
n_batches = len(batch_generator)
else:
# iterating through traditional iterators is cheap, since
# nothing is loaded, recreate generator afterwards
n_batches = sum(1 for i in batch_generator)
batch_generator = self.iterator.get_batches(dataset,
shuffle=False)
all_preds, all_targets = None, None
all_losses, all_batch_sizes = [], []
for inputs, targets in batch_generator:
preds, loss = self.eval_on_batch(inputs, targets)
all_losses.append(loss)
all_batch_sizes.append(len(batch[0]))
all_targets.append(batch[1])
all_batch_sizes.append(len(targets))
if all_preds is None:
assert all_targets is None
# first batch size is largest
max_size, n_classes, n_preds_per_input = preds.shape
# pre-allocate memory for all predictions and targets
all_preds = np.nan * np.ones(
(n_batches * max_size, n_classes, n_preds_per_input),
dtype=np.float32)
all_preds[:len(preds)] = preds
all_targets = np.nan * np.ones((n_batches * max_size))
all_targets[:len(targets)] = targets
else:
start_i = sum(all_batch_sizes[:-1])
stop_i = sum(all_batch_sizes)
all_preds[start_i:stop_i] = preds
all_targets[start_i:stop_i] = targets
# check for unequal batches
unequal_batches = len(set(all_batch_sizes)) > 1
all_batch_sizes = sum(all_batch_sizes)
# remove nan rows in case of unequal batch sizes
if unequal_batches:
assert np.sum(np.isnan(all_preds[:all_batch_sizes - 1])) == 0
assert np.sum(np.isnan(all_preds[all_batch_sizes:])) > 0
range_to_delete = range(all_batch_sizes, len(all_preds))
all_preds = np.delete(all_preds, range_to_delete, axis=0)
all_targets = np.delete(all_targets, range_to_delete, axis=0)
assert np.sum(np.isnan(all_preds)) == 0, (
"There are still nans in predictions")
assert np.sum(np.isnan(all_targets)) == 0, (
"There are still nans in targets")
# add empty dimension
# monitors expect n_batches x ...
all_preds = all_preds[np.newaxis, :]
all_targets = all_targets[np.newaxis, :]
all_batch_sizes = [all_batch_sizes]
all_losses = [all_losses]
for m in self.monitors:
result_dict = m.monitor_set(setname, all_preds, all_losses,
@@ -418,7 +462,9 @@ class Experiment(object):
for m in self.monitors:
row_dict.update(result_dicts_per_monitor[m])
self.epochs_df = self.epochs_df.append(row_dict, ignore_index=True)
assert set(self.epochs_df.columns) == set(row_dict.keys())
assert set(self.epochs_df.columns) == set(row_dict.keys()), (
"Columns of dataframe: {:s}\n and keys of dict {:s} not same")\
.format(str(set(self.epochs_df.columns)), str(set(row_dict.keys())))
self.epochs_df = self.epochs_df[list(row_dict.keys())]
def log_epoch(self):