iterator and concatenate_two_sets that can deal with per-time-point y
Esse commit está contido em:
@@ -0,0 +1 @@
|
||||
from braindecode.version import __version__
|
||||
@@ -218,7 +218,7 @@ class CropsFromTrialsIterator(object):
|
||||
for i_blocks in blocks_per_batch:
|
||||
start_stop_blocks = i_trial_start_stop_block[i_blocks]
|
||||
batch = _create_batch_from_i_trial_start_stop_blocks(
|
||||
X, y, start_stop_blocks)
|
||||
X, y, start_stop_blocks, self.n_preds_per_input)
|
||||
yield batch
|
||||
|
||||
|
||||
@@ -231,7 +231,7 @@ def _compute_start_stop_block_inds(i_trial_starts, i_trial_stops,
|
||||
----------
|
||||
i_trial_starts: 1darray/list of int
|
||||
Indices of first samples to predict(!).
|
||||
i_trial_stops: 1daray/list of int
|
||||
i_trial_stops: 1darray/list of int
|
||||
Indices one past last sample to predict.
|
||||
input_time_length: int
|
||||
n_preds_per_input: int
|
||||
@@ -299,12 +299,17 @@ def _get_start_stop_blocks_for_trial(i_trial_start, i_trial_stop,
|
||||
return start_stop_blocks
|
||||
|
||||
|
||||
def _create_batch_from_i_trial_start_stop_blocks(X, y, i_trial_start_stop_block):
|
||||
def _create_batch_from_i_trial_start_stop_blocks(X, y, i_trial_start_stop_block,
|
||||
n_preds_per_input=None):
|
||||
Xs = []
|
||||
ys = []
|
||||
for i_trial, start, stop in i_trial_start_stop_block:
|
||||
Xs.append(X[i_trial][:,start:stop])
|
||||
ys.append(y[i_trial])
|
||||
if not hasattr(y[i_trial], '__len__'):
|
||||
ys.append(y[i_trial])
|
||||
else:
|
||||
assert n_preds_per_input is not None
|
||||
ys.append(y[i_trial][stop-n_preds_per_input:stop])
|
||||
batch_X = np.array(Xs)
|
||||
batch_y = np.array(ys)
|
||||
# add empty fourth dimension if necessary
|
||||
|
||||
@@ -33,17 +33,22 @@ def concatenate_two_sets(set_a, set_b):
|
||||
-------
|
||||
concatenated_set: :class:`.SignalAndTarget`
|
||||
"""
|
||||
if hasattr(set_a.X, 'ndim') and hasattr(set_b.X, 'ndim'):
|
||||
new_X = np.concatenate((set_a.X, set_b.X), axis=0)
|
||||
else:
|
||||
if hasattr(set_a.X, 'ndim'):
|
||||
set_a.X = set_a.X.tolist()
|
||||
if hasattr(set_b.X, 'ndim'):
|
||||
set_b.X = set_b.X.tolist()
|
||||
new_X = set_a.X + set_b.X
|
||||
new_y = np.concatenate((set_a.y, set_b.y), axis=0)
|
||||
new_X = concatenate_np_array_or_add_lists(set_a.X, set_b.X)
|
||||
new_y = concatenate_np_array_or_add_lists(set_a.y, set_b.y)
|
||||
return SignalAndTarget(new_X, new_y)
|
||||
|
||||
def concatenate_np_array_or_add_lists(a, b):
|
||||
if hasattr(a, 'ndim') and hasattr(b, 'ndim'):
|
||||
new = np.concatenate((a, b), axis=0)
|
||||
else:
|
||||
if hasattr(a, 'ndim'):
|
||||
a = a.tolist()
|
||||
if hasattr(b, 'ndim'):
|
||||
b = b.tolist()
|
||||
new = a + b
|
||||
return new
|
||||
|
||||
|
||||
|
||||
def split_into_two_sets(dataset, first_set_fraction=None, n_first_set=None):
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import torch as th
|
||||
|
||||
|
||||
def log_categorical_crossentropy(logpreds, targets, dims=None):
|
||||
def log_categorical_crossentropy_1_hot(logpreds, targets, dims=None):
|
||||
"""
|
||||
Returns log categorical crossentropy for given log-predictions and targets.
|
||||
Returns log categorical crossentropy for given log-predictions and targets,
|
||||
targets should be one-hot-encoded.
|
||||
|
||||
Computes :math:`-\mathrm{logpreds} \cdot \mathrm{targets}`
|
||||
|
||||
@@ -12,6 +13,7 @@ def log_categorical_crossentropy(logpreds, targets, dims=None):
|
||||
logpreds: `torch.autograd.Variable`
|
||||
Logarithm of softmax output.
|
||||
targets: `torch.autograd.Variable`
|
||||
One-hot encoded targets
|
||||
dims: int or iterable of int, optional.
|
||||
Compute sum across these dims
|
||||
|
||||
@@ -31,6 +33,39 @@ def log_categorical_crossentropy(logpreds, targets, dims=None):
|
||||
return result
|
||||
|
||||
|
||||
def log_categorical_crossentropy(log_preds, targets):
|
||||
"""
|
||||
Returns log categorical crossentropy for given log-predictions and targets.
|
||||
|
||||
Computes :math:`-\mathrm{logpreds} \cdot \mathrm{targets}` if you assume
|
||||
targets to be one-hot-encoded. Also works for targets that are not
|
||||
one-hot-encoded, in this case only uses targets that are in the range
|
||||
of the expected class labels, i.e., [0,log_preds.size()[1]-1].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
log_preds: torch.autograd.Variable`
|
||||
Logarithm of softmax output.
|
||||
targets: `torch.autograd.Variable`
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
loss: `torch.autograd.Variable`
|
||||
"""
|
||||
if log_preds.size() == targets.size():
|
||||
return log_categorical_crossentropy_1_hot(log_preds, targets)
|
||||
n_classes = log_preds.size()[1]
|
||||
n_elements = 0
|
||||
losses = []
|
||||
for i_class in range(n_classes):
|
||||
mask = targets == i_class
|
||||
mask = mask.type_as(log_preds)
|
||||
n_elements -= th.sum(mask)
|
||||
losses.append(th.sum(mask * log_preds[:,i_class]))
|
||||
return th.sum(th.stack(losses)) / n_elements
|
||||
|
||||
|
||||
def l2_loss(model):
|
||||
losses = [th.sum(p * p) for p in model.parameters()]
|
||||
return sum(losses)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = "0.2.0"
|
||||
@@ -21,3 +21,13 @@ apidoc:
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
|
||||
removeipynbcheckpoints: Makefile
|
||||
rm -rf notebooks/.ipynb_checkpoints/ notebooks/visualization/.ipynb_checkpoints/
|
||||
|
||||
removesource: Makefile
|
||||
rm -rf source/
|
||||
|
||||
rmanddoc: removesource removeipynbcheckpoints apidoc html
|
||||
echo "Done"
|
||||
|
||||
+5
-20
@@ -82,7 +82,8 @@ autodoc_member_order = 'bysource'
|
||||
## Default flags used by autodoc directives
|
||||
autodoc_default_flags = ['members', 'show-inheritance']
|
||||
|
||||
exclude_patterns = ['_build', '_templates',]
|
||||
exclude_patterns = ['_build', '_templates']
|
||||
|
||||
|
||||
napoleon_google_docstring = False
|
||||
napoleon_use_param = False
|
||||
@@ -110,9 +111,10 @@ author = 'Robin Tibor Schirrmeister'
|
||||
# built documents.
|
||||
#
|
||||
# The short X.Y version.
|
||||
version = '0.1.9'
|
||||
import braindecode
|
||||
version = braindecode.__version__
|
||||
# The full version, including alpha/beta/rc tags.
|
||||
release = '0.1.9'
|
||||
release = version
|
||||
|
||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||
# for a list of supported languages.
|
||||
@@ -211,20 +213,3 @@ texinfo_documents = [
|
||||
author, 'Braindecode', 'One line description of project.',
|
||||
'Miscellaneous'),
|
||||
]
|
||||
|
||||
|
||||
## mock stuff
|
||||
"""
|
||||
import sys
|
||||
from mock import Mock as MagicMock
|
||||
#from unittest.mock import MagicMock
|
||||
|
||||
class Mock(MagicMock):
|
||||
@classmethod
|
||||
def __getattr__(cls, name):
|
||||
return MagicMock()
|
||||
|
||||
MOCK_MODULES = ['torch', 'h5py',]
|
||||
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
|
||||
|
||||
"""
|
||||
+1
-1
@@ -3,7 +3,7 @@ Welcome to Braindecode
|
||||
|
||||
A deep learning toolbox to decode raw time-domain EEG.
|
||||
|
||||
For EEG researchers that want to want to work with deep learning and
|
||||
For EEG researchers that want to work with deep learning and
|
||||
deep learning researchers that want to work with EEG data.
|
||||
For now focussed on convolutional networks.
|
||||
|
||||
|
||||
@@ -449,6 +449,9 @@
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"git": {
|
||||
"keep_outputs": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
|
||||
@@ -30,4 +30,12 @@ braindecode\.util module
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
braindecode\.version module
|
||||
---------------------------
|
||||
|
||||
.. automodule:: braindecode.version
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
|
||||
+7
-4
@@ -9,13 +9,16 @@ with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
|
||||
long_description = f.read()
|
||||
|
||||
|
||||
# This will add __version__ to version dict
|
||||
version = {}
|
||||
with open(path.join(here, 'braindecode/version.py'), encoding='utf-8') as (
|
||||
version_file):
|
||||
exec(version_file.read(), version)
|
||||
|
||||
setup(
|
||||
name='Braindecode',
|
||||
|
||||
# Versions should comply with PEP440. For a discussion on single-sourcing
|
||||
# the version across setup.py and the project code, see
|
||||
# http://packaging.python.org/en/latest/tutorial.html#version
|
||||
version='0.1.9', # TODO: read from __init__.py?
|
||||
version=version['__version__'],
|
||||
|
||||
description='A deep learning toolbox to decode raw time-domain EEG.',
|
||||
long_description=long_description,
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário