fixed a bug causing the rng of ClassBalancedBatchSizeIterator to be reset to a state different from the initial random state. added seeding parameter to BalancedBatchSizeIterator and ClassBalancedBatchSizeIterator. added seeding parameter to experiment class to allow control of torch and cuda random state.
Esse commit está contido em:
@@ -64,10 +64,13 @@ class BalancedBatchSizeIterator(object):
|
||||
Resulting batches will not necessarily have the given batch size
|
||||
but rather the next largest batch size that allows to split the set into
|
||||
balanced batches (maximum size difference 1).
|
||||
seed: int
|
||||
Random seed for initialization of `numpy.RandomState`.
|
||||
"""
|
||||
def __init__(self, batch_size):
|
||||
def __init__(self, batch_size, seed=328774):
|
||||
self.batch_size = batch_size
|
||||
self.rng = RandomState(328774)
|
||||
self.seed = seed
|
||||
self.rng = RandomState(self.seed)
|
||||
|
||||
def get_batches(self, dataset, shuffle):
|
||||
n_trials = dataset.X.shape[0]
|
||||
@@ -85,7 +88,7 @@ class BalancedBatchSizeIterator(object):
|
||||
yield (batch_X, batch_y)
|
||||
|
||||
def reset_rng(self):
|
||||
self.rng = RandomState(328774)
|
||||
self.rng = RandomState(self.seed)
|
||||
|
||||
|
||||
class ClassBalancedBatchSizeIterator(object):
|
||||
@@ -100,11 +103,14 @@ class ClassBalancedBatchSizeIterator(object):
|
||||
Resulting batches will not necessarily have the given batch size
|
||||
but rather the next largest batch size that allows to split the set into
|
||||
balanced batches (maximum size difference 1).
|
||||
seed: int
|
||||
Random seed for initialization of `numpy.RandomState`.
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size):
|
||||
def __init__(self, batch_size, seed=328774):
|
||||
self.batch_size = batch_size
|
||||
self.rng = RandomState(328774)
|
||||
self.seed = seed
|
||||
self.rng = RandomState(self.seed)
|
||||
|
||||
def get_batches(self, dataset, shuffle):
|
||||
n_trials = dataset.X.shape[0]
|
||||
@@ -138,7 +144,7 @@ class ClassBalancedBatchSizeIterator(object):
|
||||
yield (batch_X, batch_y)
|
||||
|
||||
def reset_rng(self):
|
||||
self.rng = RandomState((4, 7, 2017))
|
||||
self.rng = RandomState(self.seed)
|
||||
|
||||
|
||||
class CropsFromTrialsIterator(object):
|
||||
|
||||
@@ -137,6 +137,8 @@ class Experiment(object):
|
||||
do_early_stop: bool
|
||||
Whether to do an early stop at all. If true, reset to best model
|
||||
even in case experiment does not run after early stop.
|
||||
seed: int
|
||||
Random seed for python random module numpy.random and torch.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -149,7 +151,7 @@ class Experiment(object):
|
||||
run_after_early_stop,
|
||||
model_loss_function=None,
|
||||
batch_modifier=None, cuda=True, pin_memory=False,
|
||||
do_early_stop=True):
|
||||
do_early_stop=True, seed=2382938):
|
||||
if run_after_early_stop:
|
||||
assert do_early_stop == True, ("Can only run after early stop if "
|
||||
"doing an early stop")
|
||||
@@ -179,6 +181,7 @@ class Experiment(object):
|
||||
self.rememberer = None
|
||||
self.pin_memory = pin_memory
|
||||
self.do_early_stop = do_early_stop
|
||||
self.seed = seed
|
||||
|
||||
|
||||
def run(self):
|
||||
@@ -206,7 +209,7 @@ class Experiment(object):
|
||||
if self.do_early_stop:
|
||||
self.rememberer = RememberBest(self.remember_best_column)
|
||||
self.epochs_df = pd.DataFrame()
|
||||
set_random_seeds(seed=2382938, cuda=self.cuda)
|
||||
set_random_seeds(seed=self.seed, cuda=self.cuda)
|
||||
if self.cuda:
|
||||
assert th.cuda.is_available(), "Cuda not available"
|
||||
self.model.cuda()
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário