logging 0 epoch now as well
Esse commit está contido em:
@@ -57,7 +57,7 @@ class BaseModel(object):
|
||||
self.cuda = False
|
||||
self.compiled = False
|
||||
|
||||
def compile(self, loss, optimizer, monitors=None, cropped=False, iterator_seed=0):
|
||||
def compile(self, loss, optimizer, extra_monitors=None, cropped=False, iterator_seed=0):
|
||||
"""
|
||||
Setup training for this model.
|
||||
|
||||
@@ -66,7 +66,7 @@ class BaseModel(object):
|
||||
loss: function (predictions, targets) -> torch scalar
|
||||
optimizer: `torch.optim.Optimizer` or string
|
||||
Either supply an optimizer or the name of the class (e.g. 'adam')
|
||||
monitors: List of Braindecode monitors, optional
|
||||
extra_monitors: List of Braindecode monitors, optional
|
||||
In case you want to monitor additional values except for loss, misclass and runtime.
|
||||
cropped: bool
|
||||
Whether to perform cropped decoding, see cropped decoding tutorial.
|
||||
@@ -84,7 +84,7 @@ class BaseModel(object):
|
||||
optimizer_class = find_optimizer(optimizer)
|
||||
optimizer = optimizer_class(self.network.parameters())
|
||||
self.optimizer = optimizer
|
||||
self.extra_monitors = monitors
|
||||
self.extra_monitors = extra_monitors
|
||||
# Already setting it here, so multiple calls to fit
|
||||
# will lead to different batches being drawn
|
||||
self.seed_rng = RandomState(iterator_seed)
|
||||
@@ -193,7 +193,7 @@ class BaseModel(object):
|
||||
monitors=self.monitors,
|
||||
stop_criterion=stop_criterion,
|
||||
remember_best_column=remember_best_column,
|
||||
run_after_early_stop=False, cuda=self.cuda, log_0_epoch=False,
|
||||
run_after_early_stop=False, cuda=self.cuda, log_0_epoch=True,
|
||||
do_early_stop=(remember_best_column is not None))
|
||||
exp.run()
|
||||
self.epochs_df = exp.epochs_df
|
||||
@@ -238,7 +238,7 @@ class BaseModel(object):
|
||||
stop_criterion=stop_criterion,
|
||||
remember_best_column=None,
|
||||
run_after_early_stop=False, cuda=self.cuda,
|
||||
log_0_epoch=False,
|
||||
log_0_epoch=True,
|
||||
do_early_stop=False)
|
||||
|
||||
exp.monitor_epoch({'train': train_set})
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.4.1"
|
||||
__version__ = "0.4.3"
|
||||
Referência em uma Nova Issue
Bloquear um usuário