Merge pull request #20 from gemeinl/seeds
Adding seeding parameters to iterators and experiment class, fixing bug in resetting random state of ClassBalancedBatchSizeIterator
Esse commit está contido em:
@@ -64,10 +64,13 @@ class BalancedBatchSizeIterator(object):
|
||||
Resulting batches will not necessarily have the given batch size
|
||||
but rather the next largest batch size that allows splitting the set into
|
||||
balanced batches (maximum size difference 1).
|
||||
seed: int
|
||||
Random seed for initialization of `numpy.random.RandomState`.
|
||||
"""
|
||||
def __init__(self, batch_size):
|
||||
def __init__(self, batch_size, seed=328774):
|
||||
self.batch_size = batch_size
|
||||
self.rng = RandomState(328774)
|
||||
self.seed = seed
|
||||
self.rng = RandomState(self.seed)
|
||||
|
||||
def get_batches(self, dataset, shuffle):
|
||||
n_trials = dataset.X.shape[0]
|
||||
@@ -85,7 +88,7 @@ class BalancedBatchSizeIterator(object):
|
||||
yield (batch_X, batch_y)
|
||||
|
||||
def reset_rng(self):
|
||||
self.rng = RandomState(328774)
|
||||
self.rng = RandomState(self.seed)
|
||||
|
||||
|
||||
class ClassBalancedBatchSizeIterator(object):
|
||||
@@ -100,11 +103,14 @@ class ClassBalancedBatchSizeIterator(object):
|
||||
Resulting batches will not necessarily have the given batch size
|
||||
but rather the next largest batch size that allows splitting the set into
|
||||
balanced batches (maximum size difference 1).
|
||||
seed: int
|
||||
Random seed for initialization of `numpy.random.RandomState`.
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size):
|
||||
def __init__(self, batch_size, seed=328774):
|
||||
self.batch_size = batch_size
|
||||
self.rng = RandomState(328774)
|
||||
self.seed = seed
|
||||
self.rng = RandomState(self.seed)
|
||||
|
||||
def get_batches(self, dataset, shuffle):
|
||||
n_trials = dataset.X.shape[0]
|
||||
@@ -138,7 +144,7 @@ class ClassBalancedBatchSizeIterator(object):
|
||||
yield (batch_X, batch_y)
|
||||
|
||||
def reset_rng(self):
|
||||
self.rng = RandomState((4, 7, 2017))
|
||||
self.rng = RandomState(self.seed)
|
||||
|
||||
|
||||
class CropsFromTrialsIterator(object):
|
||||
|
||||
@@ -137,6 +137,8 @@ class Experiment(object):
|
||||
do_early_stop: bool
|
||||
Whether to do an early stop at all. If true, reset to best model
|
||||
even in case experiment does not run after early stop.
|
||||
seed: int
|
||||
Random seed for the Python random module, numpy.random and torch.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -149,7 +151,7 @@ class Experiment(object):
|
||||
run_after_early_stop,
|
||||
model_loss_function=None,
|
||||
batch_modifier=None, cuda=True, pin_memory=False,
|
||||
do_early_stop=True):
|
||||
do_early_stop=True, seed=2382938):
|
||||
if run_after_early_stop:
|
||||
assert do_early_stop == True, ("Can only run after early stop if "
|
||||
"doing an early stop")
|
||||
@@ -179,6 +181,7 @@ class Experiment(object):
|
||||
self.rememberer = None
|
||||
self.pin_memory = pin_memory
|
||||
self.do_early_stop = do_early_stop
|
||||
self.seed = seed
|
||||
|
||||
|
||||
def run(self):
|
||||
@@ -206,7 +209,7 @@ class Experiment(object):
|
||||
if self.do_early_stop:
|
||||
self.rememberer = RememberBest(self.remember_best_column)
|
||||
self.epochs_df = pd.DataFrame()
|
||||
set_random_seeds(seed=2382938, cuda=self.cuda)
|
||||
set_random_seeds(seed=self.seed, cuda=self.cuda)
|
||||
if self.cuda:
|
||||
assert th.cuda.is_available(), "Cuda not available"
|
||||
self.model.cuda()
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário