updated tests

Esse commit está contido em:
Robin Tibor Schirrmeister
2018-08-06 23:26:43 +02:00
commit c6b1fbf350
@@ -82,7 +82,7 @@ def test_cropped_decoding():
import torch.nn.functional as F import torch.nn.functional as F
from numpy.random import RandomState from numpy.random import RandomState
import torch as th import torch as th
from braindecode.experiments.monitors import compute_preds_per_trial_for_set from braindecode.experiments.monitors import compute_preds_per_trial_from_crops
rng = RandomState((2017, 6, 30)) rng = RandomState((2017, 6, 30))
losses = [] losses = []
accuracies = [] accuracies = []
@@ -137,9 +137,9 @@ def test_cropped_decoding():
print("{:6s} Loss: {:.5f}".format(setname, loss)) print("{:6s} Loss: {:.5f}".format(setname, loss))
losses.append(loss) losses.append(loss)
# Assign the predictions to the trials # Assign the predictions to the trials
preds_per_trial = compute_preds_per_trial_for_set(all_preds, preds_per_trial = compute_preds_per_trial_from_crops(all_preds,
input_time_length, input_time_length,
dataset) dataset.X)
# preds per trial are now trials x classes x timesteps/predictions # preds per trial are now trials x classes x timesteps/predictions
# Now mean across timesteps for each trial to get per-trial predictions # Now mean across timesteps for each trial to get per-trial predictions
meaned_preds_per_trial = np.array( meaned_preds_per_trial = np.array(