fix using cuda with model.predict
Esse commit está contido em:
@@ -285,8 +285,12 @@ class BaseModel(object):
|
||||
Predicted labels per trial.
|
||||
"""
|
||||
all_preds = []
|
||||
for b_X, _ in self.iterator.get_batches(SignalAndTarget(X, X), False):
|
||||
all_preds.append(var_to_np(self.network(np_to_var(b_X))))
|
||||
with th.no_grad():
|
||||
for b_X, _ in self.iterator.get_batches(SignalAndTarget(X, X), False):
|
||||
b_X_var = np_to_var(b_X)
|
||||
if self.cuda:
|
||||
b_X_var = b_X_var.cuda()
|
||||
all_preds.append(var_to_np(self.network(b_X_var)))
|
||||
if self.cropped:
|
||||
pred_labels = compute_trial_labels_from_crop_preds(
|
||||
all_preds, self.iterator.input_time_length, X)
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.4.5"
|
||||
__version__ = "0.4.6"
|
||||
Referência em uma Nova Issue
Bloquear um usuário