updated tests
Esse commit está contido em:
@@ -82,7 +82,7 @@ def test_cropped_decoding():
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from numpy.random import RandomState
|
from numpy.random import RandomState
|
||||||
import torch as th
|
import torch as th
|
||||||
from braindecode.experiments.monitors import compute_preds_per_trial_for_set
|
from braindecode.experiments.monitors import compute_preds_per_trial_from_crops
|
||||||
rng = RandomState((2017, 6, 30))
|
rng = RandomState((2017, 6, 30))
|
||||||
losses = []
|
losses = []
|
||||||
accuracies = []
|
accuracies = []
|
||||||
@@ -137,9 +137,9 @@ def test_cropped_decoding():
|
|||||||
print("{:6s} Loss: {:.5f}".format(setname, loss))
|
print("{:6s} Loss: {:.5f}".format(setname, loss))
|
||||||
losses.append(loss)
|
losses.append(loss)
|
||||||
# Assign the predictions to the trials
|
# Assign the predictions to the trials
|
||||||
preds_per_trial = compute_preds_per_trial_for_set(all_preds,
|
preds_per_trial = compute_preds_per_trial_from_crops(all_preds,
|
||||||
input_time_length,
|
input_time_length,
|
||||||
dataset)
|
dataset.X)
|
||||||
# preds per trial are now trials x classes x timesteps/predictions
|
# preds per trial are now trials x classes x timesteps/predictions
|
||||||
# Now mean across timesteps for each trial to get per-trial predictions
|
# Now mean across timesteps for each trial to get per-trial predictions
|
||||||
meaned_preds_per_trial = np.array(
|
meaned_preds_per_trial = np.array(
|
||||||
|
|||||||
Referência em uma Nova Issue
Bloquear um usuário