updated tests

Esse commit está contido em:
Robin Tibor Schirrmeister
2018-08-06 23:26:43 +02:00
commit c6b1fbf350
@@ -82,7 +82,7 @@ def test_cropped_decoding():
import torch.nn.functional as F
from numpy.random import RandomState
import torch as th
from braindecode.experiments.monitors import compute_preds_per_trial_for_set
from braindecode.experiments.monitors import compute_preds_per_trial_from_crops
rng = RandomState((2017, 6, 30))
losses = []
accuracies = []
@@ -137,9 +137,9 @@ def test_cropped_decoding():
print("{:6s} Loss: {:.5f}".format(setname, loss))
losses.append(loss)
# Assign the predictions to the trials
preds_per_trial = compute_preds_per_trial_for_set(all_preds,
preds_per_trial = compute_preds_per_trial_from_crops(all_preds,
input_time_length,
dataset)
dataset.X)
# preds per trial are now trials x classes x timesteps/predictions
# Now mean across timesteps for each trial to get per-trial predictions
meaned_preds_per_trial = np.array(