Fixed a bug that caused the RNG of ClassBalancedBatchSizeIterator to be reset to a state different from its initial random state. Added a `seed` parameter to BalancedBatchSizeIterator and ClassBalancedBatchSizeIterator. Added a `seed` parameter to the Experiment class to allow control of the torch and CUDA random state.

Esse commit está contido em:
Lukas gemein
2018-07-13 11:03:03 +02:00
commit f84b444547
2 arquivos alterados com 17 adições e 8 exclusões
+12 -6
Ver Arquivo
@@ -64,10 +64,13 @@ class BalancedBatchSizeIterator(object):
Resulting batches will not necessarily have the given batch size
but rather the next largest batch size that allows to split the set into
balanced batches (maximum size difference 1).
seed: int
Random seed for initialization of `numpy.random.RandomState`.
"""
def __init__(self, batch_size):
def __init__(self, batch_size, seed=328774):
self.batch_size = batch_size
self.rng = RandomState(328774)
self.seed = seed
self.rng = RandomState(self.seed)
def get_batches(self, dataset, shuffle):
n_trials = dataset.X.shape[0]
@@ -85,7 +88,7 @@ class BalancedBatchSizeIterator(object):
yield (batch_X, batch_y)
def reset_rng(self):
self.rng = RandomState(328774)
self.rng = RandomState(self.seed)
class ClassBalancedBatchSizeIterator(object):
@@ -100,11 +103,14 @@ class ClassBalancedBatchSizeIterator(object):
Resulting batches will not necessarily have the given batch size
but rather the next largest batch size that allows to split the set into
balanced batches (maximum size difference 1).
seed: int
Random seed for initialization of `numpy.random.RandomState`.
"""
def __init__(self, batch_size):
def __init__(self, batch_size, seed=328774):
self.batch_size = batch_size
self.rng = RandomState(328774)
self.seed = seed
self.rng = RandomState(self.seed)
def get_batches(self, dataset, shuffle):
n_trials = dataset.X.shape[0]
@@ -138,7 +144,7 @@ class ClassBalancedBatchSizeIterator(object):
yield (batch_X, batch_y)
def reset_rng(self):
self.rng = RandomState((4, 7, 2017))
self.rng = RandomState(self.seed)
class CropsFromTrialsIterator(object):
+5 -2
Ver Arquivo
@@ -137,6 +137,8 @@ class Experiment(object):
do_early_stop: bool
Whether to do an early stop at all. If true, reset to best model
even in case experiment does not run after early stop.
seed: int
Random seed for the Python random module, numpy.random, and torch.
Attributes
----------
@@ -149,7 +151,7 @@ class Experiment(object):
run_after_early_stop,
model_loss_function=None,
batch_modifier=None, cuda=True, pin_memory=False,
do_early_stop=True):
do_early_stop=True, seed=2382938):
if run_after_early_stop:
assert do_early_stop == True, ("Can only run after early stop if "
"doing an early stop")
@@ -179,6 +181,7 @@ class Experiment(object):
self.rememberer = None
self.pin_memory = pin_memory
self.do_early_stop = do_early_stop
self.seed = seed
def run(self):
@@ -206,7 +209,7 @@ class Experiment(object):
if self.do_early_stop:
self.rememberer = RememberBest(self.remember_best_column)
self.epochs_df = pd.DataFrame()
set_random_seeds(seed=2382938, cuda=self.cuda)
set_random_seeds(seed=self.seed, cuda=self.cuda)
if self.cuda:
assert th.cuda.is_available(), "Cuda not available"
self.model.cuda()