Merge pull request #22 from gemeinl/second-run-reset

Added resetting to best model if second run was not successful
Esse commit está contido em:
robintibor
2018-07-19 11:26:14 +02:00
commit de GitHub
+20 -8
Ver Arquivo
@@ -137,6 +137,9 @@ class Experiment(object):
do_early_stop: bool
Whether to do an early stop at all. If true, reset to the best model
even if the experiment does not run after the early stop.
reset_after_second_run: bool
If true, reset to the best model when the second run did not find a valid
loss below or equal to the best train loss of the first run.
seed: int
Random seed for the Python random module, numpy.random and torch.
@@ -151,10 +154,11 @@ class Experiment(object):
run_after_early_stop,
model_loss_function=None,
batch_modifier=None, cuda=True, pin_memory=False,
do_early_stop=True, seed=2382938):
if run_after_early_stop:
assert do_early_stop == True, ("Can only run after early stop if "
"doing an early stop")
do_early_stop=True, reset_after_second_run=False,
seed=2382938):
if run_after_early_stop or reset_after_second_run:
assert do_early_stop == True, ("Can only run after early stop or "
"reset after second run if doing an early stop")
if do_early_stop:
assert valid_set is not None
assert remember_best_column is not None
@@ -181,9 +185,9 @@ class Experiment(object):
self.rememberer = None
self.pin_memory = pin_memory
self.do_early_stop = do_early_stop
self.reset_after_second_run = reset_after_second_run
self.seed = seed
def run(self):
"""
Run complete training.
@@ -198,7 +202,17 @@ class Experiment(object):
self.setup_after_stop_training()
if self.run_after_early_stop:
log.info("Run until second stop...")
loss_to_reach = float(self.epochs_df['train_loss'].iloc[-1])
self.run_until_second_stop()
if self.reset_after_second_run:
# if no valid loss was found below the best train loss on 1st
# run, reset model to the epoch with lowest valid_misclass
if float(self.epochs_df['valid_loss'].iloc[-1]) > loss_to_reach:
log.info("Resetting to best epoch {:d}".format(
self.rememberer.best_epoch))
self.rememberer.reset_to_best_model(self.epochs_df,
self.model,
self.optimizer)
def setup_training(self):
"""
@@ -234,9 +248,7 @@ class Experiment(object):
datasets['train'] = concatenate_sets([datasets['train'],
datasets['valid']])
# Todo: actually keep remembering and in case of twice number of epochs
# reset to best model again (check if valid loss not below train loss)
self.run_until_stop(datasets, remember_best=False)
self.run_until_stop(datasets, remember_best=True)
def run_until_stop(self, datasets, remember_best):
"""