adapt base model for already dense model

Esse commit está contido em:
Robin Tibor Schirrmeister
2018-09-17 17:23:13 +02:00
commit 5ea159b190
+14 -2
Ver Arquivo
@@ -15,6 +15,9 @@ from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.models.util import to_dense_prediction_model
from braindecode.torch_ext.schedulers import CosineAnnealing, ScheduledOptimizer
from braindecode.torch_ext.util import np_to_var, var_to_np
import logging
log = logging.getLogger(__name__)
def find_optimizer(optimizer_name):
@@ -79,7 +82,14 @@ class BaseModel(object):
self.loss = loss
self._ensure_network_exists()
if cropped:
to_dense_prediction_model(self.network)
model_already_dense = np.any([
hasattr(m, 'dilation') and (m.dilation != 1) and
(m.dilation) != (1, 1) for m in self.network.modules()
])
if not model_already_dense:
to_dense_prediction_model(self.network)
else:
log.info("Seems model was already converted to dense model...")
if not hasattr(optimizer, 'step'):
optimizer_class = find_optimizer(optimizer)
optimizer = optimizer_class(self.network.parameters())
@@ -141,7 +151,9 @@ class BaseModel(object):
"the network in a single pass.")
if self.cropped:
self.network.eval()
test_input = np_to_var(train_X[0:1], dtype=np.float32)
test_input = np_to_var(np.ones(
(1, train_X.shape[1], input_time_length,) + train_X.shape[3:],
dtype=np.float32))
while len(test_input.size()) < 4:
test_input = test_input.unsqueeze(-1)
if self.cuda: