From 5ea159b190fd8e3b16f705bc4274064476dee4ec Mon Sep 17 00:00:00 2001 From: Robin Tibor Schirrmeister Date: Mon, 17 Sep 2018 17:23:13 +0200 Subject: [PATCH] adapt base model for already dense model --- braindecode/models/base.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/braindecode/models/base.py b/braindecode/models/base.py index 828b38e..2a680da 100644 --- a/braindecode/models/base.py +++ b/braindecode/models/base.py @@ -15,6 +15,9 @@ from braindecode.datautil.signal_target import SignalAndTarget from braindecode.models.util import to_dense_prediction_model from braindecode.torch_ext.schedulers import CosineAnnealing, ScheduledOptimizer from braindecode.torch_ext.util import np_to_var, var_to_np +import logging + +log = logging.getLogger(__name__) def find_optimizer(optimizer_name): @@ -79,7 +82,14 @@ class BaseModel(object): self.loss = loss self._ensure_network_exists() if cropped: - to_dense_prediction_model(self.network) + model_already_dense = np.any([ + hasattr(m, 'dilation') and (m.dilation != 1) and + (m.dilation) != (1, 1) for m in self.network.modules() + ]) + if not model_already_dense: + to_dense_prediction_model(self.network) + else: + log.info("Seems model was already converted to dense model...") if not hasattr(optimizer, 'step'): optimizer_class = find_optimizer(optimizer) optimizer = optimizer_class(self.network.parameters()) @@ -141,7 +151,9 @@ class BaseModel(object): "the network in a single pass.") if self.cropped: self.network.eval() - test_input = np_to_var(train_X[0:1], dtype=np.float32) + test_input = np_to_var(np.ones( + (1, train_X.shape[1], input_time_length,) + train_X.shape[3:], + dtype=np.float32)) while len(test_input.size()) < 4: test_input = test_input.unsqueeze(-1) if self.cuda: