iterator and concatenate_two_sets that can deal with per-time-point y
Esse commit está contido em:
@@ -0,0 +1 @@
|
|||||||
|
from braindecode.version import __version__
|
||||||
@@ -218,7 +218,7 @@ class CropsFromTrialsIterator(object):
|
|||||||
for i_blocks in blocks_per_batch:
|
for i_blocks in blocks_per_batch:
|
||||||
start_stop_blocks = i_trial_start_stop_block[i_blocks]
|
start_stop_blocks = i_trial_start_stop_block[i_blocks]
|
||||||
batch = _create_batch_from_i_trial_start_stop_blocks(
|
batch = _create_batch_from_i_trial_start_stop_blocks(
|
||||||
X, y, start_stop_blocks)
|
X, y, start_stop_blocks, self.n_preds_per_input)
|
||||||
yield batch
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
@@ -231,7 +231,7 @@ def _compute_start_stop_block_inds(i_trial_starts, i_trial_stops,
|
|||||||
----------
|
----------
|
||||||
i_trial_starts: 1darray/list of int
|
i_trial_starts: 1darray/list of int
|
||||||
Indices of first samples to predict(!).
|
Indices of first samples to predict(!).
|
||||||
i_trial_stops: 1daray/list of int
|
i_trial_stops: 1darray/list of int
|
||||||
Indices one past last sample to predict.
|
Indices one past last sample to predict.
|
||||||
input_time_length: int
|
input_time_length: int
|
||||||
n_preds_per_input: int
|
n_preds_per_input: int
|
||||||
@@ -299,12 +299,17 @@ def _get_start_stop_blocks_for_trial(i_trial_start, i_trial_stop,
|
|||||||
return start_stop_blocks
|
return start_stop_blocks
|
||||||
|
|
||||||
|
|
||||||
def _create_batch_from_i_trial_start_stop_blocks(X, y, i_trial_start_stop_block):
|
def _create_batch_from_i_trial_start_stop_blocks(X, y, i_trial_start_stop_block,
|
||||||
|
n_preds_per_input=None):
|
||||||
Xs = []
|
Xs = []
|
||||||
ys = []
|
ys = []
|
||||||
for i_trial, start, stop in i_trial_start_stop_block:
|
for i_trial, start, stop in i_trial_start_stop_block:
|
||||||
Xs.append(X[i_trial][:,start:stop])
|
Xs.append(X[i_trial][:,start:stop])
|
||||||
ys.append(y[i_trial])
|
if not hasattr(y[i_trial], '__len__'):
|
||||||
|
ys.append(y[i_trial])
|
||||||
|
else:
|
||||||
|
assert n_preds_per_input is not None
|
||||||
|
ys.append(y[i_trial][stop-n_preds_per_input:stop])
|
||||||
batch_X = np.array(Xs)
|
batch_X = np.array(Xs)
|
||||||
batch_y = np.array(ys)
|
batch_y = np.array(ys)
|
||||||
# add empty fourth dimension if necessary
|
# add empty fourth dimension if necessary
|
||||||
|
|||||||
@@ -33,17 +33,22 @@ def concatenate_two_sets(set_a, set_b):
|
|||||||
-------
|
-------
|
||||||
concatenated_set: :class:`.SignalAndTarget`
|
concatenated_set: :class:`.SignalAndTarget`
|
||||||
"""
|
"""
|
||||||
if hasattr(set_a.X, 'ndim') and hasattr(set_b.X, 'ndim'):
|
new_X = concatenate_np_array_or_add_lists(set_a.X, set_b.X)
|
||||||
new_X = np.concatenate((set_a.X, set_b.X), axis=0)
|
new_y = concatenate_np_array_or_add_lists(set_a.y, set_b.y)
|
||||||
else:
|
|
||||||
if hasattr(set_a.X, 'ndim'):
|
|
||||||
set_a.X = set_a.X.tolist()
|
|
||||||
if hasattr(set_b.X, 'ndim'):
|
|
||||||
set_b.X = set_b.X.tolist()
|
|
||||||
new_X = set_a.X + set_b.X
|
|
||||||
new_y = np.concatenate((set_a.y, set_b.y), axis=0)
|
|
||||||
return SignalAndTarget(new_X, new_y)
|
return SignalAndTarget(new_X, new_y)
|
||||||
|
|
||||||
|
def concatenate_np_array_or_add_lists(a, b):
|
||||||
|
if hasattr(a, 'ndim') and hasattr(b, 'ndim'):
|
||||||
|
new = np.concatenate((a, b), axis=0)
|
||||||
|
else:
|
||||||
|
if hasattr(a, 'ndim'):
|
||||||
|
a = a.tolist()
|
||||||
|
if hasattr(b, 'ndim'):
|
||||||
|
b = b.tolist()
|
||||||
|
new = a + b
|
||||||
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def split_into_two_sets(dataset, first_set_fraction=None, n_first_set=None):
|
def split_into_two_sets(dataset, first_set_fraction=None, n_first_set=None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import torch as th
|
import torch as th
|
||||||
|
|
||||||
|
|
||||||
def log_categorical_crossentropy(logpreds, targets, dims=None):
|
def log_categorical_crossentropy_1_hot(logpreds, targets, dims=None):
|
||||||
"""
|
"""
|
||||||
Returns log categorical crossentropy for given log-predictions and targets.
|
Returns log categorical crossentropy for given log-predictions and targets,
|
||||||
|
targets should be one-hot-encoded.
|
||||||
|
|
||||||
Computes :math:`-\mathrm{logpreds} \cdot \mathrm{targets}`
|
Computes :math:`-\mathrm{logpreds} \cdot \mathrm{targets}`
|
||||||
|
|
||||||
@@ -12,6 +13,7 @@ def log_categorical_crossentropy(logpreds, targets, dims=None):
|
|||||||
logpreds: `torch.autograd.Variable`
|
logpreds: `torch.autograd.Variable`
|
||||||
Logarithm of softmax output.
|
Logarithm of softmax output.
|
||||||
targets: `torch.autograd.Variable`
|
targets: `torch.autograd.Variable`
|
||||||
|
One-hot encoded targets
|
||||||
dims: int or iterable of int, optional.
|
dims: int or iterable of int, optional.
|
||||||
Compute sum across these dims
|
Compute sum across these dims
|
||||||
|
|
||||||
@@ -31,6 +33,39 @@ def log_categorical_crossentropy(logpreds, targets, dims=None):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def log_categorical_crossentropy(log_preds, targets):
|
||||||
|
"""
|
||||||
|
Returns log categorical crossentropy for given log-predictions and targets.
|
||||||
|
|
||||||
|
Computes :math:`-\mathrm{logpreds} \cdot \mathrm{targets}` if you assume
|
||||||
|
targets to be one-hot-encoded. Also works for targets that are not
|
||||||
|
one-hot-encoded, in this case only uses targets that are in the range
|
||||||
|
of the expected class labels, i.e., [0,log_preds.size()[1]-1].
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
log_preds: torch.autograd.Variable`
|
||||||
|
Logarithm of softmax output.
|
||||||
|
targets: `torch.autograd.Variable`
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
|
||||||
|
loss: `torch.autograd.Variable`
|
||||||
|
"""
|
||||||
|
if log_preds.size() == targets.size():
|
||||||
|
return log_categorical_crossentropy_1_hot(log_preds, targets)
|
||||||
|
n_classes = log_preds.size()[1]
|
||||||
|
n_elements = 0
|
||||||
|
losses = []
|
||||||
|
for i_class in range(n_classes):
|
||||||
|
mask = targets == i_class
|
||||||
|
mask = mask.type_as(log_preds)
|
||||||
|
n_elements -= th.sum(mask)
|
||||||
|
losses.append(th.sum(mask * log_preds[:,i_class]))
|
||||||
|
return th.sum(th.stack(losses)) / n_elements
|
||||||
|
|
||||||
|
|
||||||
def l2_loss(model):
|
def l2_loss(model):
|
||||||
losses = [th.sum(p * p) for p in model.parameters()]
|
losses = [th.sum(p * p) for p in model.parameters()]
|
||||||
return sum(losses)
|
return sum(losses)
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = "0.2.0"
|
||||||
@@ -21,3 +21,13 @@ apidoc:
|
|||||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||||
%: Makefile
|
%: Makefile
|
||||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||||
|
|
||||||
|
|
||||||
|
removeipynbcheckpoints: Makefile
|
||||||
|
rm -rf notebooks/.ipynb_checkpoints/ notebooks/visualization/.ipynb_checkpoints/
|
||||||
|
|
||||||
|
removesource: Makefile
|
||||||
|
rm -rf source/
|
||||||
|
|
||||||
|
rmanddoc: removesource removeipynbcheckpoints apidoc html
|
||||||
|
echo "Done"
|
||||||
|
|||||||
+6
-21
@@ -82,8 +82,9 @@ autodoc_member_order = 'bysource'
|
|||||||
## Default flags used by autodoc directives
|
## Default flags used by autodoc directives
|
||||||
autodoc_default_flags = ['members', 'show-inheritance']
|
autodoc_default_flags = ['members', 'show-inheritance']
|
||||||
|
|
||||||
exclude_patterns = ['_build', '_templates',]
|
exclude_patterns = ['_build', '_templates']
|
||||||
|
|
||||||
|
|
||||||
napoleon_google_docstring = False
|
napoleon_google_docstring = False
|
||||||
napoleon_use_param = False
|
napoleon_use_param = False
|
||||||
napoleon_use_ivar = True
|
napoleon_use_ivar = True
|
||||||
@@ -110,9 +111,10 @@ author = 'Robin Tibor Schirrmeister'
|
|||||||
# built documents.
|
# built documents.
|
||||||
#
|
#
|
||||||
# The short X.Y version.
|
# The short X.Y version.
|
||||||
version = '0.1.9'
|
import braindecode
|
||||||
|
version = braindecode.__version__
|
||||||
# The full version, including alpha/beta/rc tags.
|
# The full version, including alpha/beta/rc tags.
|
||||||
release = '0.1.9'
|
release = version
|
||||||
|
|
||||||
# The language for content autogenerated by Sphinx. Refer to documentation
|
# The language for content autogenerated by Sphinx. Refer to documentation
|
||||||
# for a list of supported languages.
|
# for a list of supported languages.
|
||||||
@@ -211,20 +213,3 @@ texinfo_documents = [
|
|||||||
author, 'Braindecode', 'One line description of project.',
|
author, 'Braindecode', 'One line description of project.',
|
||||||
'Miscellaneous'),
|
'Miscellaneous'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
## mock stuff
|
|
||||||
"""
|
|
||||||
import sys
|
|
||||||
from mock import Mock as MagicMock
|
|
||||||
#from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
class Mock(MagicMock):
|
|
||||||
@classmethod
|
|
||||||
def __getattr__(cls, name):
|
|
||||||
return MagicMock()
|
|
||||||
|
|
||||||
MOCK_MODULES = ['torch', 'h5py',]
|
|
||||||
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
|
|
||||||
|
|
||||||
"""
|
|
||||||
+1
-1
@@ -3,7 +3,7 @@ Welcome to Braindecode
|
|||||||
|
|
||||||
A deep learning toolbox to decode raw time-domain EEG.
|
A deep learning toolbox to decode raw time-domain EEG.
|
||||||
|
|
||||||
For EEG researchers that want to want to work with deep learning and
|
For EEG researchers that want to work with deep learning and
|
||||||
deep learning researchers that want to work with EEG data.
|
deep learning researchers that want to work with EEG data.
|
||||||
For now focussed on convolutional networks.
|
For now focussed on convolutional networks.
|
||||||
|
|
||||||
|
|||||||
@@ -449,6 +449,9 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
"git": {
|
||||||
|
"keep_outputs": true
|
||||||
|
},
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3",
|
"display_name": "Python 3",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
|
|||||||
@@ -30,4 +30,12 @@ braindecode\.util module
|
|||||||
:undoc-members:
|
:undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
||||||
|
braindecode\.version module
|
||||||
|
---------------------------
|
||||||
|
|
||||||
|
.. automodule:: braindecode.version
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
:show-inheritance:
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+7
-4
@@ -9,13 +9,16 @@ with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
|
|||||||
long_description = f.read()
|
long_description = f.read()
|
||||||
|
|
||||||
|
|
||||||
|
# This will add __version__ to version dict
|
||||||
|
version = {}
|
||||||
|
with open(path.join(here, 'braindecode/version.py'), encoding='utf-8') as (
|
||||||
|
version_file):
|
||||||
|
exec(version_file.read(), version)
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='Braindecode',
|
name='Braindecode',
|
||||||
|
|
||||||
# Versions should comply with PEP440. For a discussion on single-sourcing
|
version=version['__version__'],
|
||||||
# the version across setup.py and the project code, see
|
|
||||||
# http://packaging.python.org/en/latest/tutorial.html#version
|
|
||||||
version='0.1.9', # TODO: read from __init__.py?
|
|
||||||
|
|
||||||
description='A deep learning toolbox to decode raw time-domain EEG.',
|
description='A deep learning toolbox to decode raw time-domain EEG.',
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
|
|||||||
Referência em uma Nova Issue
Bloquear um usuário