adapt base model for already dense model
Esse commit está contido em:
@@ -15,6 +15,9 @@ from braindecode.datautil.signal_target import SignalAndTarget
|
||||
from braindecode.models.util import to_dense_prediction_model
|
||||
from braindecode.torch_ext.schedulers import CosineAnnealing, ScheduledOptimizer
|
||||
from braindecode.torch_ext.util import np_to_var, var_to_np
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_optimizer(optimizer_name):
|
||||
@@ -79,7 +82,14 @@ class BaseModel(object):
|
||||
self.loss = loss
|
||||
self._ensure_network_exists()
|
||||
if cropped:
|
||||
to_dense_prediction_model(self.network)
|
||||
model_already_dense = np.any([
|
||||
hasattr(m, 'dilation') and (m.dilation != 1) and
|
||||
(m.dilation) != (1, 1) for m in self.network.modules()
|
||||
])
|
||||
if not model_already_dense:
|
||||
to_dense_prediction_model(self.network)
|
||||
else:
|
||||
log.info("Seems model was already converted to dense model...")
|
||||
if not hasattr(optimizer, 'step'):
|
||||
optimizer_class = find_optimizer(optimizer)
|
||||
optimizer = optimizer_class(self.network.parameters())
|
||||
@@ -141,7 +151,9 @@ class BaseModel(object):
|
||||
"the network in a single pass.")
|
||||
if self.cropped:
|
||||
self.network.eval()
|
||||
test_input = np_to_var(train_X[0:1], dtype=np.float32)
|
||||
test_input = np_to_var(np.ones(
|
||||
(1, train_X.shape[1], input_time_length,) + train_X.shape[3:],
|
||||
dtype=np.float32))
|
||||
while len(test_input.size()) < 4:
|
||||
test_input = test_input.unsqueeze(-1)
|
||||
if self.cuda:
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário