Merge pull request #22 from gemeinl/second-run-reset
Added resetting to best model if second run was not successful
Esse commit está contido em:
@@ -137,6 +137,9 @@ class Experiment(object):
|
||||
do_early_stop: bool
|
||||
Whether to do an early stop at all. If true, reset to best model
|
||||
even in case experiment does not run after early stop.
|
||||
reset_after_second_run: bool
|
||||
If true, reset to best model when second run did not find a valid loss
|
||||
below or equal to the best train loss of first run.
|
||||
seed: int
|
||||
Random seed for python random module numpy.random and torch.
|
||||
|
||||
@@ -151,10 +154,11 @@ class Experiment(object):
|
||||
run_after_early_stop,
|
||||
model_loss_function=None,
|
||||
batch_modifier=None, cuda=True, pin_memory=False,
|
||||
do_early_stop=True, seed=2382938):
|
||||
if run_after_early_stop:
|
||||
assert do_early_stop == True, ("Can only run after early stop if "
|
||||
"doing an early stop")
|
||||
do_early_stop=True, reset_after_second_run=False,
|
||||
seed=2382938):
|
||||
if run_after_early_stop or reset_after_second_run:
|
||||
assert do_early_stop == True, ("Can only run after early stop or "
|
||||
"reset after second run if doing an early stop")
|
||||
if do_early_stop:
|
||||
assert valid_set is not None
|
||||
assert remember_best_column is not None
|
||||
@@ -181,9 +185,9 @@ class Experiment(object):
|
||||
self.rememberer = None
|
||||
self.pin_memory = pin_memory
|
||||
self.do_early_stop = do_early_stop
|
||||
self.reset_after_second_run = reset_after_second_run
|
||||
self.seed = seed
|
||||
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run complete training.
|
||||
@@ -198,7 +202,17 @@ class Experiment(object):
|
||||
self.setup_after_stop_training()
|
||||
if self.run_after_early_stop:
|
||||
log.info("Run until second stop...")
|
||||
loss_to_reach = float(self.epochs_df['train_loss'].iloc[-1])
|
||||
self.run_until_second_stop()
|
||||
if self.reset_after_second_run:
|
||||
# if no valid loss was found below the best train loss on 1st
|
||||
# run, reset model to the epoch with lowest valid_misclass
|
||||
if float(self.epochs_df['valid_loss'].iloc[-1]) > loss_to_reach:
|
||||
log.info("Resetting to best epoch {:d}".format(
|
||||
self.rememberer.best_epoch))
|
||||
self.rememberer.reset_to_best_model(self.epochs_df,
|
||||
self.model,
|
||||
self.optimizer)
|
||||
|
||||
def setup_training(self):
|
||||
"""
|
||||
@@ -234,9 +248,7 @@ class Experiment(object):
|
||||
datasets['train'] = concatenate_sets([datasets['train'],
|
||||
datasets['valid']])
|
||||
|
||||
# Todo: actually keep remembering and in case of twice number of epochs
|
||||
# reset to best model again (check if valid loss not below train loss)
|
||||
self.run_until_stop(datasets, remember_best=False)
|
||||
self.run_until_stop(datasets, remember_best=True)
|
||||
|
||||
def run_until_stop(self, datasets, remember_best):
|
||||
"""
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário