fix using cuda with model.predict

Esse commit está contido em:
Robin Tibor Schirrmeister
2018-09-19 12:33:20 +02:00
commit 651ef3d7a4
2 arquivos alterados com 7 adições e 3 exclusões
+6 -2
Ver Arquivo
@@ -285,8 +285,12 @@ class BaseModel(object):
Predicted labels per trial.
"""
all_preds = []
for b_X, _ in self.iterator.get_batches(SignalAndTarget(X, X), False):
all_preds.append(var_to_np(self.network(np_to_var(b_X))))
with th.no_grad():
for b_X, _ in self.iterator.get_batches(SignalAndTarget(X, X), False):
b_X_var = np_to_var(b_X)
if self.cuda:
b_X_var = b_X_var.cuda()
all_preds.append(var_to_np(self.network(b_X_var)))
if self.cropped:
pred_labels = compute_trial_labels_from_crop_preds(
all_preds, self.iterator.input_time_length, X)
+1 -1
Ver Arquivo
@@ -1 +1 @@
__version__ = "0.4.5"
__version__ = "0.4.6"