now all trial segmentation happens by creating continuous labels first

Esse commit está contido em:
Robin Tibor Schirrmeister
2017-10-18 15:33:12 +02:00
commit 750e060f3d
7 arquivos alterados com 848 adições e 687 exclusões
+346 -369
Ver Arquivo
@@ -13,9 +13,12 @@ marker_def = OrderedDict(
(['Right', 1], ['Left', [2]], ['Rest', 3], ['Feet', 4]))
def create_signal_target_from_raw_mne(raw, name_to_start_codes, epoch_ival_ms,
name_to_stop_codes=None,
pad_to_n_samples=None):
def create_signal_target_from_raw_mne(
raw, name_to_start_codes, epoch_ival_ms,
name_to_stop_codes=None,
prepad_trials_to_n_samples=None,
one_hot_labels=False,
one_label_per_trial=True):
"""
Create SignalTarget set from given `mne.io.RawArray`.
@@ -36,9 +39,14 @@ def create_signal_target_from_raw_mne(raw, name_to_start_codes, epoch_ival_ms,
Dictionary mapping class names to stop marker code or stop marker codes.
Order does not matter, dictionary should contain each class in
`name_to_codes` dictionary.
pad_to_n_samples: int
prepad_trials_to_n_samples: int
Pad trials that would be too short with the signal before it (only
valid name_to_stop_codes is not None.
valid if name_to_stop_codes is not None).
one_hot_labels: bool, optional
Whether to have the labels in a one-hot format, e.g. [0,0,1] or to
have them just as an int, e.g. 2
one_label_per_trial: bool, optional
Whether to have a timeseries of labels or just a single label per trial.
Returns
-------
@@ -49,14 +57,18 @@ def create_signal_target_from_raw_mne(raw, name_to_start_codes, epoch_ival_ms,
events = np.array([raw.info['events'][:,0],
raw.info['events'][:,2]]).T
fs = raw.info['sfreq']
return create_signal_target(data, events, fs, name_to_start_codes,
epoch_ival_ms,
name_to_stop_codes=name_to_stop_codes,
pad_to_n_samples=pad_to_n_samples)
return create_signal_target(
data, events, fs, name_to_start_codes,
epoch_ival_ms,
name_to_stop_codes=name_to_stop_codes,
prepad_trials_to_n_samples=prepad_trials_to_n_samples,
one_hot_labels=one_hot_labels,
one_label_per_trial=one_label_per_trial)
def create_signal_target(data, events, fs, name_to_start_codes, epoch_ival_ms,
name_to_stop_codes=None, pad_to_n_samples=None):
name_to_stop_codes=None, prepad_trials_to_n_samples=None,
one_hot_labels=False, one_label_per_trial=True):
"""
Create SignalTarget set given continuous data.
@@ -83,10 +95,14 @@ def create_signal_target(data, events, fs, name_to_start_codes, epoch_ival_ms,
Dictionary mapping class names to stop marker code or stop marker codes.
Order does not matter, dictionary should contain each class in
`name_to_codes` dictionary.
pad_to_n_samples: int
prepad_trials_to_n_samples: int, optional
Pad trials that would be too short with the signal before it (only
valid name_to_stop_codes is not None.
valid if name_to_stop_codes is not None).
one_hot_labels: bool, optional
Whether to have the labels in a one-hot format, e.g. [0,0,1] or to
have them just as an int, e.g. 2
one_label_per_trial: bool, optional
Whether to have a timeseries of labels or just a single label per trial.
Returns
-------
@@ -96,11 +112,15 @@ def create_signal_target(data, events, fs, name_to_start_codes, epoch_ival_ms,
"""
if name_to_stop_codes is None:
return _create_signal_target_from_start_and_ival(
data, events, fs, name_to_start_codes, epoch_ival_ms)
data, events, fs, name_to_start_codes, epoch_ival_ms,
one_hot_labels=one_hot_labels,
one_label_per_trial=one_label_per_trial)
else:
return _create_signal_target_from_start_and_stop(
data, events, fs, name_to_start_codes, epoch_ival_ms,
name_to_stop_codes, pad_to_n_samples)
name_to_stop_codes, prepad_trials_to_n_samples,
one_hot_labels=one_hot_labels,
one_label_per_trial=one_label_per_trial)
def _to_mrk_code_to_name_and_y(name_to_codes):
@@ -119,255 +139,72 @@ def _to_mrk_code_to_name_and_y(name_to_codes):
def _create_signal_target_from_start_and_ival(
data, events, fs, name_to_codes, epoch_ival_ms):
data, events, fs, name_to_codes, epoch_ival_ms,
one_hot_labels, one_label_per_trial):
cnt_y, i_start_stops = _create_cnt_y_and_trial_bounds_from_start_and_ival(
data.shape[1], events, fs, name_to_codes, epoch_ival_ms
)
signal_target = _create_signal_target_from_cnt_y_start_stops(
data, cnt_y, i_start_stops, prepad_trials_to_n_samples=None,
one_hot_labels=one_hot_labels,
one_label_per_trial=one_label_per_trial)
# make into array as all should have same dimensions
signal_target.X = np.array(signal_target.X, dtype=np.float32)
signal_target.y = np.array(signal_target.y, dtype=np.int64)
return signal_target
def _create_cnt_y_and_trial_bounds_from_start_and_ival(
n_samples, events, fs, name_to_start_codes, epoch_ival_ms):
ival_in_samples = ms_to_samples(np.array(epoch_ival_ms), fs)
start_offset = np.int32(np.round(ival_in_samples[0]))
# we will use ceil but exclusive...
stop_offset = np.int32(np.ceil(ival_in_samples[1]))
mrk_code_to_name_and_y = _to_mrk_code_to_name_and_y(name_to_codes)
mrk_code_to_name_and_y = _to_mrk_code_to_name_and_y(name_to_start_codes)
class_to_n_trials = Counter()
X = []
y = []
n_classes = len(name_to_start_codes)
cnt_y = np.zeros((n_samples, n_classes), dtype=np.int64)
i_start_stops = []
for i_sample, mrk_code in zip(events[:, 0], events[:, 1]):
start_sample = int(i_sample) + start_offset
stop_sample = int(i_sample) + stop_offset
if mrk_code in mrk_code_to_name_and_y:
if start_sample < 0:
log.warning("Ignore trial with marker code {:d}, would start at "
"sample {:d}".format(mrk_code, start_sample))
log.warning(
"Ignore trial with marker code {:d}, would start at "
"sample {:d}".format(mrk_code, start_sample))
continue
if stop_sample > data.shape[1]:
if stop_sample > n_samples:
log.warning("Ignore trial with marker code {:d}, would end at "
"sample {:d} of {:d}".format(mrk_code, stop_sample-1,
data.shape[1]-1))
"sample {:d} of {:d}".format(
mrk_code, stop_sample - 1, n_samples - 1))
continue
name, this_y = mrk_code_to_name_and_y[mrk_code]
X.append(data[:, start_sample:stop_sample].astype(np.float32))
y.append(np.int64(this_y))
i_start_stops.append((start_sample, stop_sample))
cnt_y[start_sample:stop_sample, this_y] = 1
class_to_n_trials[name] += 1
log.info("Trial per class:\n{:s}".format(str(class_to_n_trials)))
return SignalAndTarget(np.array(X), np.array(y))
return cnt_y, i_start_stops
def _create_signal_target_from_start_and_stop(
data, events, fs, name_to_start_codes, epoch_ival_ms,
name_to_stop_codes, pad_to_n_samples=None):
name_to_stop_codes, prepad_trials_to_n_samples,
one_hot_labels, one_label_per_trial, ):
assert np.array_equal(list(name_to_start_codes.keys()),
list(name_to_stop_codes.keys()))
ival_in_samples = ms_to_samples(np.array(epoch_ival_ms), fs)
start_offset = np.int32(np.round(ival_in_samples[0]))
# we will use ceil but exclusive...
stop_offset = np.int32(np.ceil(ival_in_samples[1]))
start_code_to_name_and_y = _to_mrk_code_to_name_and_y(name_to_start_codes)
# Ensure all stop marker codes are iterables
for name in name_to_stop_codes:
codes = name_to_stop_codes[name]
if not hasattr(codes, '__len__'):
name_to_stop_codes[name] = [codes]
all_stop_codes = np.concatenate(list(name_to_stop_codes.values()))
class_to_n_trials = Counter()
X = []
y = []
event_samples = events[:, 0]
event_codes = events[:, 1]
i_event = 0
first_start_code_found = False
while i_event < len(events):
while i_event < len(events) and (
event_codes[i_event] not in start_code_to_name_and_y):
i_event += 1
if i_event < len(events):
start_sample = event_samples[i_event]
start_code = event_codes[i_event]
start_name = start_code_to_name_and_y[start_code][0]
start_y = start_code_to_name_and_y[start_code][1]
i_event += 1
first_start_code_found = True
waiting_for_end_code = True
while i_event < len(events) and (
event_codes[i_event] not in all_stop_codes):
if event_codes[i_event] in start_code_to_name_and_y:
log.warning(
"New start marker {:.0f} at {:.0f} samples found, "
"no end marker for earlier start marker {:.0f} "
"at {:.0f} samples found.".format(
event_codes[i_event], event_samples[i_event],
start_code, start_sample))
start_sample = event_samples[i_event]
start_name = start_code_to_name_and_y[start_code][0]
start_code = event_codes[i_event]
start_y = start_code_to_name_and_y[start_code][1]
i_event += 1
if i_event == len(events):
if waiting_for_end_code:
log.warning(("No end marker for start marker code {:.0f} "
"at sample {:.0f} found.").format(start_code,
start_sample))
elif (not first_start_code_found):
log.warning("No markers found at all.")
break
stop_sample = event_samples[i_event]
stop_code = event_codes[i_event]
assert stop_code in name_to_stop_codes[start_name]
i_start = int(start_sample) + start_offset
i_stop = int(stop_sample) + stop_offset
waiting_for_end_code = False
if (pad_to_n_samples is not None) and (
(i_stop - i_start) < pad_to_n_samples):
if i_stop < pad_to_n_samples:
log.warning("Could not pad trial enough, therefore not "
"not using trial from {:d} to {:d}".format(
i_start, i_stop
))
continue
i_start = i_stop - pad_to_n_samples
if i_start < 0:
log.warning("Ignore trial with start code {:d}, would start at "
"sample {:d}".format(start_code, i_start))
continue
if i_stop > data.shape[1]:
log.warning("Ignore trial with stop code {:d}, would end at "
"sample {:d} of {:d}".format(stop_code, i_stop - 1,
data.shape[1] - 1))
continue
X.append(data[:, i_start:i_stop].astype(np.float32))
y.append(np.int64(start_y))
class_to_n_trials[start_name] += 1
log.info("Trial per class:\n{:s}".format(str(class_to_n_trials)))
return SignalAndTarget(X, np.array(y))
cnt_y, i_start_stops = _create_cnt_y_and_trial_bounds_from_start_stop(
data.shape[1],events, fs,name_to_start_codes, epoch_ival_ms,
name_to_stop_codes)
signal_target = _create_signal_target_from_cnt_y_start_stops(
data, cnt_y, i_start_stops,
prepad_trials_to_n_samples=prepad_trials_to_n_samples,
one_hot_labels=one_hot_labels, one_label_per_trial=one_label_per_trial)
return signal_target
def add_breaks(
        events, fs, break_start_code, break_stop_code, name_to_start_codes,
        name_to_stop_codes, min_break_length_ms=None,
        max_break_length_ms=None, break_start_offset_ms=None,
        break_stop_offset_ms=None):
    """
    Insert break start and break stop markers into the given events.

    A break is the gap between a trial stop marker and the next trial
    start marker, as found by `_extract_break_start_stop_ms`.

    Parameters
    ----------
    events: 2d-array
        Dimensions: Number of events, 2. For each event, should contain sample
        index and marker code.
    fs: number
        Sampling rate.
    break_start_code: int
        Marker code that will be used for break start markers.
    break_stop_code: int
        Marker code that will be used for break stop markers.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Ordered dictionary mapping class names to start marker code or
        start marker codes.
    name_to_stop_codes: dict (str -> int or list of int)
        Dictionary mapping class names to stop marker code or stop marker
        codes.
    min_break_length_ms: number, optional
        Minimum length in milliseconds a break should have to be included.
    max_break_length_ms: number, optional
        Maximum length in milliseconds a break can have to be included.
    break_start_offset_ms: number, optional
        Offset in milliseconds that is added to every break start.
    break_stop_offset_ms: number, optional
        Offset in milliseconds that is added to every break stop.

    Returns
    -------
    events: 2d-array
        Given events together with the break start and stop markers,
        ordered by sample index. If no break passes the length limits,
        a deep copy of the given events.
    """
    if min_break_length_ms is None:
        shortest_allowed = None
    else:
        shortest_allowed = ms_to_samples(min_break_length_ms, fs)
    if max_break_length_ms is None:
        longest_allowed = None
    else:
        longest_allowed = ms_to_samples(max_break_length_ms, fs)

    starts, stops = _extract_break_start_stop_ms(
        events, name_to_start_codes, name_to_stop_codes)

    # Length limits are checked before any offset is applied
    lengths = stops - starts
    keep = np.array([True] * len(starts))
    if shortest_allowed is not None:
        keep[lengths < shortest_allowed] = False
    if longest_allowed is not None:
        keep[lengths > longest_allowed] = False
    if sum(keep) == 0:
        return deepcopy(events)
    starts = starts[keep]
    stops = stops[keep]

    if break_start_offset_ms is not None:
        starts += int(round(ms_to_samples(break_start_offset_ms, fs)))
    if break_stop_offset_ms is not None:
        stops += int(round(ms_to_samples(break_stop_offset_ms, fs)))

    # Even rows hold break starts, odd rows the matching break stops
    markers = np.zeros((len(starts) * 2, 2))
    markers[0::2, 0] = starts
    markers[0::2, 1] = break_start_code
    markers[1::2, 0] = stops
    markers[1::2, 1] = break_stop_code

    merged = np.concatenate((events, markers))
    # mergesort is stable, so events on the same sample keep their order
    order = np.argsort(merged[:, 0], kind='mergesort')
    return merged[order]
def _extract_break_start_stop_ms(events, name_to_start_codes,
                                 name_to_stop_codes):
    """
    Find the gaps between trial stop markers and the next start marker.

    Despite the `_ms` suffix, the returned values are taken directly from
    the first column of `events` (plus/minus one), no unit conversion is
    done here.

    Parameters
    ----------
    events: 2d-array
        Dimensions: Number of events, 2. For each event, should contain sample
        index and marker code.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Mapping from class names to start marker code(s).
    name_to_stop_codes: dict (str -> int or list of int)
        Mapping from class names to stop marker code(s). Not modified.

    Returns
    -------
    break_starts: 1d-array
        For each break, position of the preceding stop marker plus one.
    break_stops: 1d-array
        For each break, position of the following start marker minus one.
    """
    if len(events) == 0:
        # Nothing to scan; also avoids indexing into an empty event array
        return np.array([]), np.array([])
    assert len(events[0]) == 2, "expect only 2dimensional event array here"
    start_code_to_name_and_y = _to_mrk_code_to_name_and_y(name_to_start_codes)
    # Accept single codes as well as lists of codes, without changing the
    # dictionary of the caller
    all_stop_codes = np.concatenate(
        [np.atleast_1d(codes) for codes in name_to_stop_codes.values()]
    ).astype(np.int32)
    event_samples = events[:, 0]
    event_codes = events[:, 1]
    break_starts = []
    break_stops = []
    n_events = len(events)
    i_event = 0
    while i_event < n_events:
        # Skip forward to the next stop marker
        while (i_event < n_events) and (
                event_codes[i_event] not in all_stop_codes):
            i_event += 1
        if i_event == n_events:
            break
        stop_sample = event_samples[i_event]
        stop_code = event_codes[i_event]
        i_event += 1
        # Skip forward to the next start marker; if another stop marker
        # shows up first, the break begins after that later one
        while (i_event < n_events) and (
                event_codes[i_event] not in start_code_to_name_and_y):
            if event_codes[i_event] in all_stop_codes:
                log.warning(
                    "New end marker  {:.0f} at {:.0f} samples found, "
                    "no start marker for earlier end marker {:.0f} "
                    "at {:.0f} samples found.".format(
                        event_codes[i_event],
                        event_samples[i_event],
                        stop_code, stop_sample))
                # The one-sample shift is applied once when appending below
                stop_sample = event_samples[i_event]
                stop_code = event_codes[i_event]
            i_event += 1
        if i_event == n_events:
            # Last stop marker has no following start marker -> no break
            break
        start_sample = event_samples[i_event]
        start_code = event_codes[i_event]
        assert start_code in start_code_to_name_and_y
        # let's start one after stop of the trial and stop one before
        # start of the trial to ensure that markers will be
        # in right order
        break_starts.append(stop_sample + 1)
        break_stops.append(start_sample - 1)
    return np.array(break_starts), np.array(break_stops)
def create_cnt_y_and_start_stop_samples(
def _create_cnt_y_and_trial_bounds_from_start_stop(
n_samples, events, fs, name_to_start_codes, epoch_ival_ms,
name_to_stop_codes):
"""
@@ -397,6 +234,13 @@ def create_cnt_y_and_start_stop_samples(
Order does not matter, dictionary should contain each class in
`name_to_codes` dictionary.
Returns
-------
cnt_y: 2d-array
Timeseries of one-hot-labels, time x classes.
trial_bounds: list of (int,int)
List of (trial_start, trial_stop) tuples.
"""
assert np.array_equal(list(name_to_start_codes.keys()),
@@ -412,7 +256,8 @@ def create_cnt_y_and_start_stop_samples(
codes = name_to_stop_codes[name]
if not hasattr(codes, '__len__'):
name_to_stop_codes[name] = [codes]
all_stop_codes = np.concatenate(list(name_to_stop_codes.values())).astype(np.int64)
all_stop_codes = np.concatenate(list(name_to_stop_codes.values())
).astype(np.int64)
class_to_n_trials = Counter()
n_classes = len(name_to_start_codes)
cnt_y = np.zeros((n_samples, n_classes), dtype=np.int64)
@@ -471,137 +316,21 @@ def create_cnt_y_and_start_stop_samples(
return cnt_y, i_start_stops
def create_signal_target_with_breaks_from_mne(
        cnt, name_to_start_codes,
        trial_segment_ival_ms,
        name_to_stop_codes,
        min_break_length_ms, max_break_length_ms,
        break_segment_ival_ms,
        pad_trials_to_n_samples=None):
    """
    Segment trials and the breaks between them, adding a 'Break' class.

    Break markers are generated with `add_breaks`, registered under the
    class name 'Break', and the resulting events are passed on to
    `create_signal_target_with_cnt_y`.

    Parameters
    ----------
    cnt:
        Object providing `info['events']`, `info['sfreq']` and `get_data()`.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Mapping from class names to start marker code(s). Must not contain
        the key 'Break'.
    trial_segment_ival_ms: iterable of (number, number)
        Segmentation interval for the trials in milliseconds.
    name_to_stop_codes: dict (str -> int or list of int)
        Mapping from class names to stop marker code(s).
    min_break_length_ms: number
        Passed to `add_breaks` as minimum break length.
    max_break_length_ms: number
        Passed to `add_breaks` as maximum break length.
    break_segment_ival_ms: iterable of (number, number)
        Segmentation interval for the breaks in milliseconds.
    pad_trials_to_n_samples: int, optional
        Passed on as `pad_to_n_samples`.

    Returns
    -------
    Whatever `create_signal_target_with_cnt_y` returns for the events
    including the break markers.
    """
    assert 'Break' not in name_to_start_codes
    # Pick break marker codes that are not used by any given marker:
    # count down from -1 until a free code is found
    all_start_codes = np.concatenate(
        [np.atleast_1d(codes) for codes in name_to_start_codes.values()])
    all_stop_codes = np.concatenate(
        [np.atleast_1d(codes) for codes in name_to_stop_codes.values()])
    taken_codes = np.concatenate((all_start_codes, all_stop_codes))
    break_start_code = -1
    while break_start_code in taken_codes:
        break_start_code -= 1
    break_stop_code = break_start_code - 1
    while break_stop_code in taken_codes:
        break_stop_code -= 1

    # Keep only the first and third event column
    events = cnt.info['events'][:, [0, 2]]
    # The trial interval is applied to every class (including breaks) when
    # the set is created, so subtract it from the break interval here
    break_offsets_ms = np.array(break_segment_ival_ms) - (
        np.array(trial_segment_ival_ms))
    events_with_breaks = add_breaks(
        events, cnt.info['sfreq'],
        break_start_code, break_stop_code,
        name_to_start_codes, name_to_stop_codes,
        min_break_length_ms=min_break_length_ms,
        max_break_length_ms=max_break_length_ms,
        break_start_offset_ms=break_offsets_ms[0],
        break_stop_offset_ms=break_offsets_ms[1])

    # Work on copies so the dictionaries of the caller get no 'Break' key
    start_codes_with_break = deepcopy(name_to_start_codes)
    start_codes_with_break['Break'] = break_start_code
    stop_codes_with_break = deepcopy(name_to_stop_codes)
    stop_codes_with_break['Break'] = break_stop_code

    return create_signal_target_with_cnt_y(
        cnt.get_data(), events_with_breaks, cnt.info['sfreq'],
        start_codes_with_break, trial_segment_ival_ms,
        stop_codes_with_break,
        pad_to_n_samples=pad_trials_to_n_samples)
def create_signal_target_with_cnt_y_from_raw_mne(
        raw, name_to_start_codes, epoch_ival_ms,
        name_to_stop_codes,
        pad_to_n_samples=None):
    """
    Call `create_signal_target_with_cnt_y` with data taken from `raw`.

    Parameters
    ----------
    raw:
        Object providing `get_data()`, `info['events']` and `info['sfreq']`.
        Only the first and third column of the events are used.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Mapping from class names to start marker code(s).
    epoch_ival_ms: iterable of (int,int)
        Epoching interval in milliseconds.
    name_to_stop_codes: dict (str -> int or list of int)
        Mapping from class names to stop marker code(s).
    pad_to_n_samples: int, optional
        Passed on unchanged.

    Returns
    -------
    Whatever `create_signal_target_with_cnt_y` returns.
    """
    return create_signal_target_with_cnt_y(
        raw.get_data(),
        raw.info['events'][:, [0, 2]],
        raw.info['sfreq'],
        name_to_start_codes, epoch_ival_ms,
        name_to_stop_codes,
        pad_to_n_samples=pad_to_n_samples)
def create_signal_target_with_cnt_y(data, events, fs,
                                    name_to_start_codes, epoch_ival_ms,
                                    name_to_stop_codes,
                                    pad_to_n_samples=None
                                    ):
    """
    Create a signal/target set by first creating continuous labels.

    Parameters
    ----------
    data: 2d-array of number
        The continuous recorded data. Channels x times order.
    events: 2d-array
        Dimensions: Number of events, 2. For each event, should contain sample
        index and marker code.
    fs: number
        Sampling rate.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Ordered dictionary mapping class names to marker code or marker codes.
        y-labels will be assigned in increasing key order, i.e.
        first classname gets y-value 0, second classname y-value 1, etc.
    epoch_ival_ms: iterable of (int,int)
        Epoching interval in milliseconds, as offset from the start marker
        and offset from the stop marker. E.g. [500, -500] would mean 500ms
        after the start marker until 500 ms before the stop marker.
    name_to_stop_codes: dict (str -> int or list of int)
        Dictionary mapping class names to stop marker code or stop marker
        codes. Order does not matter, dictionary should contain each class
        in `name_to_start_codes` dictionary.
    pad_to_n_samples: int, optional
        Use signal before trial start to pad trials that are otherwise too
        small.

    Returns
    -------
    dataset: :class:`.SignalAndTarget`
        Result of `create_signal_target_from_cnt_y_start_stops`, requested
        with `one_hot_y=False` and `one_label_per_trial=True`.
    """
    n_samples = data.shape[1]
    # Step 1: continuous labels and the trial boundaries
    cnt_y, trial_bounds = create_cnt_y_and_start_stop_samples(
        n_samples, events, fs, name_to_start_codes, epoch_ival_ms,
        name_to_stop_codes)
    # Step 2: cut signal and labels at the trial boundaries
    dataset = create_signal_target_from_cnt_y_start_stops(
        data, cnt_y, trial_bounds, pad_to_n_samples,
        one_hot_y=False, one_label_per_trial=True)
    return dataset
def create_signal_target_from_cnt_y_start_stops(
def _create_signal_target_from_cnt_y_start_stops(
data,
cnt_y,
i_start_stops,
pad_to_n_samples,
one_hot_y,
prepad_trials_to_n_samples,
one_hot_labels,
one_label_per_trial):
if pad_to_n_samples is not None:
if prepad_trials_to_n_samples is not None:
new_i_start_stops = []
for i_start, i_stop in i_start_stops:
if (i_stop - i_start) > pad_to_n_samples:
if (i_stop - i_start) > prepad_trials_to_n_samples:
new_i_start_stops.append((i_start, i_stop))
elif i_stop >= pad_to_n_samples:
elif i_stop >= prepad_trials_to_n_samples:
new_i_start_stops.append(
(i_stop - pad_to_n_samples, i_stop))
(i_stop - prepad_trials_to_n_samples, i_stop))
else:
log.warning("Could not pad trial enough, therefore not "
"not using trial from {:d} to {:d}".format(
@@ -622,30 +351,278 @@ def create_signal_target_from_cnt_y_start_stops(
))
continue
if i_stop > data.shape[1]:
log.warning("Trial stop too late, therefore not "
log.warning("Trial stop too late (past {:d}), therefore not "
"not using trial from {:d} to {:d}".format(
data.shape[1] - 1,
i_start, i_stop
))
continue
X.append(data[:, i_start:i_stop].astype(np.float32))
y.append(cnt_y[i_start:i_stop])
if not one_hot_y:
# take last label always
if one_label_per_trial:
new_y = []
for this_y in y:
# if destroying one hot later, just set most occuring class to 1
unique_labels, counts = np.unique(
this_y, axis=0, return_counts=True)
if not one_hot_labels:
meaned_y = np.mean(this_y, axis=0)
this_new_y = np.zeros_like(meaned_y)
this_new_y[np.argmax(meaned_y)] = 1
else:
# take most frequency occurring label combination
this_new_y = unique_labels[np.argmax(counts)]
if len(unique_labels) > 1:
log.warning("Different labels within one trial: {:s},"
"setting single trial label to {:s}".format(
str(unique_labels), str(this_new_y)
))
new_y.append(this_new_y)
y = new_y
if not one_hot_labels:
# change from one-hot-encoding to regular encoding
# with -1 as indication none of the classes are present
new_y = []
for this_y in y:
this_new_y = np.argmax(this_y, axis=1)
this_new_y[np.sum(this_y, axis=1) == 0] = -1
if one_label_per_trial:
if np.sum(this_y) == 0:
this_new_y = -1
else:
this_new_y = np.argmax(this_y)
if np.sum(this_y) > 1:
log.warning(
"Have multiple active classes and will convert to "
"lowest class")
else:
if np.max(np.sum(this_y, axis=1)) > 1:
log.warning(
"Have multiple active classes and will convert to "
"lowest class")
this_new_y = np.argmax(this_y, axis=1)
this_new_y[np.sum(this_y, axis=1) == 0] = -1
new_y.append(this_new_y)
y = new_y
# take last label always
if one_label_per_trial:
y = [this_y[-1] for this_y in y]
y = np.array(y, dtype=np.int64)
return SignalAndTarget(X, y)
def create_signal_target_with_breaks_from_mne(
        cnt, name_to_start_codes,
        trial_epoch_ival_ms,
        name_to_stop_codes,
        min_break_length_ms, max_break_length_ms,
        break_epoch_ival_ms,
        prepad_trials_to_n_samples=None):
    """
    Create SignalTarget set with an extra 'Break' class from given
    `mne.io.RawArray`.

    Break markers are generated with `add_breaks`, registered under the
    class name 'Break', and the resulting events are passed on to
    `create_signal_target`.

    Parameters
    ----------
    cnt: `mne.io.RawArray`
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Ordered dictionary mapping class names to marker code or marker codes.
        y-labels will be assigned in increasing key order, i.e.
        first classname gets y-value 0, second classname y-value 1, etc.
        Must not contain the key 'Break'.
    trial_epoch_ival_ms: iterable of (int,int)
        Epoching interval in milliseconds. Represents offset from start marker
        and offset from stop marker. E.g. [500, -500] would mean 500ms
        after the start marker until 500 ms before the stop marker.
    name_to_stop_codes: dict (str -> int or list of int)
        Dictionary mapping class names to stop marker code or stop marker
        codes. Order does not matter, dictionary should contain each class
        in `name_to_start_codes` dictionary.
    min_break_length_ms: number
        Breaks below this length are excluded.
    max_break_length_ms: number
        Breaks above this length are excluded.
    break_epoch_ival_ms: iterable of (number, number)
        Break ival, offset from trial end to start of the break in ms and
        offset from trial start to end of break in ms.
    prepad_trials_to_n_samples: int, optional
        Pad trials that would be too short with the signal before it.

    Returns
    -------
    dataset: :class:`.SignalAndTarget`
        Dataset with `X` as the trial signals and `y` as the trial labels.
        Labels as timeseries and of integers, i.e., not one-hot encoded.
    """
    assert 'Break' not in name_to_start_codes
    # Pick break marker codes that are not used by any given marker:
    # count down from -1 until a free code is found
    all_start_codes = np.concatenate(
        [np.atleast_1d(codes) for codes in name_to_start_codes.values()])
    all_stop_codes = np.concatenate(
        [np.atleast_1d(codes) for codes in name_to_stop_codes.values()])
    taken_codes = np.concatenate((all_start_codes, all_stop_codes))
    break_start_code = -1
    while break_start_code in taken_codes:
        break_start_code -= 1
    break_stop_code = break_start_code - 1
    while break_stop_code in taken_codes:
        break_stop_code -= 1

    # Keep only the first and third event column
    events = cnt.info['events'][:, [0, 2]]
    # The trial interval is applied to every class (including breaks) when
    # the set is created, so subtract it from the break interval here
    break_offsets_ms = np.array(break_epoch_ival_ms) - (
        np.array(trial_epoch_ival_ms))
    events_with_breaks = add_breaks(
        events, cnt.info['sfreq'],
        break_start_code, break_stop_code,
        name_to_start_codes, name_to_stop_codes,
        min_break_length_ms=min_break_length_ms,
        max_break_length_ms=max_break_length_ms,
        break_start_offset_ms=break_offsets_ms[0],
        break_stop_offset_ms=break_offsets_ms[1])

    # Work on copies so the dictionaries of the caller get no 'Break' key
    start_codes_with_break = deepcopy(name_to_start_codes)
    start_codes_with_break['Break'] = break_start_code
    stop_codes_with_break = deepcopy(name_to_stop_codes)
    stop_codes_with_break['Break'] = break_stop_code

    return create_signal_target(
        cnt.get_data(), events_with_breaks, cnt.info['sfreq'],
        start_codes_with_break, trial_epoch_ival_ms,
        stop_codes_with_break,
        prepad_trials_to_n_samples=prepad_trials_to_n_samples,
        one_hot_labels=False,
        one_label_per_trial=False)
def add_breaks(
        events, fs, break_start_code, break_stop_code, name_to_start_codes,
        name_to_stop_codes, min_break_length_ms=None,
        max_break_length_ms=None, break_start_offset_ms=None,
        break_stop_offset_ms=None):
    """
    Add break events to given events.

    A break is the gap between a trial stop marker and the next trial
    start marker, as found by `_extract_break_start_stop_ms`.

    Parameters
    ----------
    events: 2d-array
        Dimensions: Number of events, 2. For each event, should contain sample
        index and marker code.
    fs: number
        Sampling rate.
    break_start_code: int
        Marker code that will be used for break start markers.
    break_stop_code: int
        Marker code that will be used for break stop markers.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Ordered dictionary mapping class names to start marker code or
        start marker codes.
    name_to_stop_codes: dict (str -> int or list of int)
        Dictionary mapping class names to stop marker code or stop marker
        codes.
    min_break_length_ms: number, optional
        Minimum length in milliseconds a break should have to be included.
    max_break_length_ms: number, optional
        Maximum length in milliseconds a break can have to be included.
    break_start_offset_ms: number, optional
        Offset in milliseconds that is added to the start of every break.
    break_stop_offset_ms: number, optional
        Offset in milliseconds that is added to the stop of every break.

    Returns
    -------
    events: 2d-array
        Events with break start and stop markers, ordered by sample index.
        A deep copy of the given events if no break fulfils the length
        limits.
    """
    min_samples = None
    if min_break_length_ms is not None:
        min_samples = ms_to_samples(min_break_length_ms, fs)
    max_samples = None
    if max_break_length_ms is not None:
        max_samples = ms_to_samples(max_break_length_ms, fs)

    first_samples, last_samples = _extract_break_start_stop_ms(
        events, name_to_start_codes, name_to_stop_codes)

    # Decide which breaks to keep, based on their length before offsets
    n_samples_per_break = last_samples - first_samples
    is_valid = np.array([True] * len(first_samples))
    if min_samples is not None:
        is_valid[n_samples_per_break < min_samples] = False
    if max_samples is not None:
        is_valid[n_samples_per_break > max_samples] = False
    if sum(is_valid) == 0:
        return deepcopy(events)
    first_samples = first_samples[is_valid]
    last_samples = last_samples[is_valid]

    if break_start_offset_ms is not None:
        first_samples += int(
            round(ms_to_samples(break_start_offset_ms, fs)))
    if break_stop_offset_ms is not None:
        last_samples += int(
            round(ms_to_samples(break_stop_offset_ms, fs)))

    # Alternate rows: start marker of a break, then its stop marker
    n_breaks = len(first_samples)
    break_rows = np.zeros((n_breaks * 2, 2))
    break_rows[0::2, 0] = first_samples
    break_rows[0::2, 1] = break_start_code
    break_rows[1::2, 0] = last_samples
    break_rows[1::2, 1] = break_stop_code

    combined = np.concatenate((events, break_rows))
    # Stable sort by sample index keeps same-sample events in given order
    return combined[np.argsort(combined[:, 0], kind='mergesort')]
def _extract_break_start_stop_ms(events, name_to_start_codes,
                                 name_to_stop_codes):
    """
    Find the gaps between trial stop markers and the next start marker.

    Despite the `_ms` suffix, the returned values are taken directly from
    the first column of `events` (plus/minus one), no unit conversion is
    done here.

    Parameters
    ----------
    events: 2d-array
        Dimensions: Number of events, 2. For each event, should contain sample
        index and marker code.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Mapping from class names to start marker code(s).
    name_to_stop_codes: dict (str -> int or list of int)
        Mapping from class names to stop marker code(s). Not modified.

    Returns
    -------
    break_starts: 1d-array
        For each break, position of the preceding stop marker plus one.
    break_stops: 1d-array
        For each break, position of the following start marker minus one.
    """
    if len(events) == 0:
        # Nothing to scan; also avoids indexing into an empty event array
        return np.array([]), np.array([])
    assert len(events[0]) == 2, "expect only 2dimensional event array here"
    start_code_to_name_and_y = _to_mrk_code_to_name_and_y(name_to_start_codes)
    # Accept single codes as well as lists of codes, without changing the
    # dictionary of the caller
    all_stop_codes = np.concatenate(
        [np.atleast_1d(codes) for codes in name_to_stop_codes.values()]
    ).astype(np.int32)
    event_samples = events[:, 0]
    event_codes = events[:, 1]
    break_starts = []
    break_stops = []
    n_events = len(events)
    i_event = 0
    while i_event < n_events:
        # Skip forward to the next stop marker
        while (i_event < n_events) and (
                event_codes[i_event] not in all_stop_codes):
            i_event += 1
        if i_event == n_events:
            break
        stop_sample = event_samples[i_event]
        stop_code = event_codes[i_event]
        i_event += 1
        # Skip forward to the next start marker; if another stop marker
        # shows up first, the break begins after that later one
        while (i_event < n_events) and (
                event_codes[i_event] not in start_code_to_name_and_y):
            if event_codes[i_event] in all_stop_codes:
                log.warning(
                    "New end marker  {:.0f} at {:.0f} samples found, "
                    "no start marker for earlier end marker {:.0f} "
                    "at {:.0f} samples found.".format(
                        event_codes[i_event],
                        event_samples[i_event],
                        stop_code, stop_sample))
                stop_sample = event_samples[i_event]
                stop_code = event_codes[i_event]
            i_event += 1
        if i_event == n_events:
            # Last stop marker has no following start marker -> no break
            break
        start_sample = event_samples[i_event]
        start_code = event_codes[i_event]
        assert start_code in start_code_to_name_and_y
        # let's start one after stop of the trial and stop one before
        # start of the trial to ensure that markers will be
        # in right order
        break_starts.append(stop_sample + 1)
        break_stops.append(start_sample - 1)
    return np.array(break_starts), np.array(break_stops)
+1 -1
Ver Arquivo
@@ -1 +1 @@
__version__ = "0.2.1"
__version__ = "0.3.0"
-10
Ver Arquivo
@@ -149,16 +149,6 @@
"train_set = create_signal_target_from_raw_mne(cnt, name_to_code, segment_ival_ms)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.save('/data/schirrmr/schirrmr/BBCI_Data-train-set-X-tmp-test.npy', train_set.X)\n",
"np.save('/data/schirrmr/schirrmr/BBCI_Data-train-set-y-tmp-test.npy', train_set.y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
+31 -72
Ver Arquivo
@@ -4,6 +4,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"nbsphinx": "hidden"
},
"outputs": [],
@@ -81,39 +82,18 @@
"from braindecode.datautil.signalproc import lowpass_cnt\n",
"from braindecode.datautil.signalproc import exponential_running_standardize\n",
"\n",
"def create_cnts(folder, runs, name_to_start_code, name_to_stop_code, break_start_offset_ms,\n",
" break_stop_offset_ms, break_start_code, break_stop_code):\n",
"def create_cnts(folder, runs,):\n",
" # Load data\n",
" cnts = load_bbci_sets_from_folder(folder, runs)\n",
" \n",
" # Now do some preprocessings:\n",
"    # Adding breaks, resampling to 250 Hz, lowpass below 38, exponential standardization\n",
" break_start_code = -1\n",
" break_stop_code = -2\n",
"    # Resampling to 250 Hz, lowpass below 38, exponential standardization\n",
" \n",
" new_cnts = []\n",
" for cnt in cnts:\n",
" # Only take some channels \n",
" #cnt = cnt.drop_channels(['STI 014']) # This would remove stimulus channel\n",
" cnt = cnt.pick_channels(['C3', 'CPz' 'C4'])\n",
" # add breaks\n",
" new_events = add_breaks(\n",
" np.array(cnt.info['events'])[:, [0,2]], cnt.info['sfreq'],\n",
" break_start_code=break_start_code, break_stop_code=break_stop_code,\n",
" name_to_start_codes=name_to_start_code, name_to_stop_codes=name_to_stop_code,\n",
" min_break_length_ms=5000, max_break_length_ms=9000)\n",
" n_break_start_offset = int(cnt.info['sfreq'] * break_start_offset_ms / 1000.0)\n",
" n_break_stop_offset = int(cnt.info['sfreq'] * break_stop_offset_ms / 1000.0)\n",
" # lets add some offset to break start and stop\n",
" # new_events[:,2] contains event codes, new_events[:,0] the sample indices\n",
" # new_events[:,1] is always 0 for my loading of BBCI data\n",
" new_events[new_events[:,1] == break_start_code, 0] += n_break_start_offset\n",
" # 0.5 sec for break stop\n",
" new_events[new_events[:,1] == break_stop_code, 0] += n_break_stop_offset\n",
" # add a middle column with zeros again\n",
" new_events_for_mne = np.zeros((len(new_events), 3))\n",
" new_events_for_mne[:,[0,2]] = new_events\n",
" cnt.info['events'] = new_events_for_mne\n",
" cnt = cnt.pick_channels(['C3', 'CPz', 'C4'])\n",
" log.info(\"Preprocessing....\")\n",
" cnt = mne_apply(lambda a: lowpass_cnt(a, 38,cnt.info['sfreq'], axis=1), cnt)\n",
" cnt = resample_cnt(cnt, 250)\n",
@@ -135,30 +115,18 @@
"outputs": [],
"source": [
"from collections import OrderedDict\n",
"\n",
"train_runs = [1,2,3]\n",
"train_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/', \n",
" train_runs,)\n",
"\n",
"name_to_start_code = OrderedDict([('Right Hand', 1), ('Feet', 4),\n",
" ('Rotation', 8), ('Words', [10])])\n",
"\n",
"name_to_stop_code = OrderedDict([('Right Hand', [20,21,22,23,24,28,30]),\n",
" ('Feet', [20,21,22,23,24,28,30]),\n",
" ('Rotation', [20,21,22,23,24,28,30]), \n",
" ('Words', [20,21,22,23,24,28,30])])\n",
"\n",
"break_start_offset_ms = 1000\n",
"break_stop_offset_ms = -500\n",
"# pick some numbers that were not used before/do not exist in markers\n",
"break_start_code = -1\n",
"break_stop_code = -2\n",
"train_runs = [1,2,3]\n",
"train_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/', \n",
" train_runs,\n",
" name_to_start_code,\n",
" name_to_stop_code, break_start_offset_ms,\n",
" break_stop_offset_ms, break_start_code, break_stop_code)\n",
"\n",
"name_to_code_with_breaks = deepcopy(name_to_start_code)\n",
"name_to_code_with_breaks['Break'] = break_start_code\n",
"name_to_stop_code_with_breaks = deepcopy(name_to_stop_code)\n",
"name_to_stop_code_with_breaks['Break'] = break_stop_code"
" ('Words', [20,21,22,23,24,28,30])])\n"
]
},
{
@@ -168,9 +136,7 @@
"outputs": [],
"source": [
"test_runs = [9,10]\n",
"test_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/', test_runs, name_to_start_code,\n",
" name_to_stop_code, break_start_offset_ms,\n",
" break_stop_offset_ms, break_start_code, break_stop_code)"
"test_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/', test_runs,)"
]
},
{
@@ -241,8 +207,17 @@
"metadata": {},
"outputs": [],
"source": [
"train_sets = [create_signal_target_from_raw_mne(cnt, name_to_code_with_breaks, [-receptive_field_ms,0], \n",
" name_to_stop_code_with_breaks) for cnt in train_cnts]\n",
"from braindecode.datautil.trial_segment import create_signal_target_with_breaks_from_mne\n",
"\n",
"break_start_offset_ms = 1000\n",
"break_stop_offset_ms = -500\n",
"\n",
"train_sets = [create_signal_target_with_breaks_from_mne(\n",
" cnt, name_to_start_code, [0,0], \n",
" name_to_stop_code, min_break_length_ms=1000, max_break_length_ms=10000,\n",
" break_epoch_ival_ms=[500,-500],\n",
" prepad_trials_to_n_samples=input_time_length) \n",
" for cnt in train_cnts]\n",
"train_set = concatenate_sets(train_sets)"
]
},
@@ -252,32 +227,15 @@
"metadata": {},
"outputs": [],
"source": [
"test_sets = [create_signal_target_from_raw_mne(cnt, name_to_code_with_breaks, [-receptive_field_ms,0], \n",
" name_to_stop_code_with_breaks) for cnt in test_cnts]\n",
"test_sets = [create_signal_target_with_breaks_from_mne(\n",
" cnt, name_to_start_code, [0,0], \n",
" name_to_stop_code, min_break_length_ms=1000, max_break_length_ms=10000,\n",
" break_epoch_ival_ms=[500,-500],\n",
" prepad_trials_to_n_samples=input_time_length) \n",
" for cnt in test_cnts]\n",
"test_set = concatenate_sets(test_sets)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"train_set_2 = np.load('/data/schirrmr/schirrmr/BBCI_Data_Start_Stop-train-set-tmp-test.npy')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"pickle.dump(train_set, open('/data/schirrmr/schirrmr/BBCI_Data_Start_Stop-train-set-tmp-test.npy', 'wb'))"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -345,9 +303,10 @@
"import torch.nn.functional as F\n",
"import torch as th\n",
"from braindecode.torch_ext.modules import Expression\n",
"from braindecode.torch_ext.losses import log_categorical_crossentropy\n",
"\n",
"\n",
"loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2, keepdim=False), targets)\n",
"loss_function = log_categorical_crossentropy\n",
"\n",
"model_constraint = None\n",
"monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),\n",
@@ -380,7 +339,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We arrive only at 31% accuracy. With only 3 sensors and 3 training runs, cannot expect too much great performance :)"
"We arrive at about 54% accuracy. With only 3 sensors and 3 training runs, we cannot get much better :)"
]
}
],
+2 -58
Ver Arquivo
@@ -217,9 +217,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
@@ -323,61 +321,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We arrive at 24.3% or 26.3% depending on stars :))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"\n",
"exp2 = Experiment(model, train_set, valid_set, test_set, iterator, loss_function, None, model_constraint,\n",
" monitors, stop_criterion, remember_best_column='valid_misclass',\n",
" run_after_early_stop=True, batch_modifier=None, cuda=cuda)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"exp2.setup_training()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"exp2.monitor_epoch(exp2.datasets)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"exp2.print_epoch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"exp2.epochs_df"
"We arrive around 26%, exact value depending on stars :))"
]
}
],
Diff do arquivo suprimido porque uma ou mais linhas são muito longas
+288 -3
Ver Arquivo
@@ -1,9 +1,17 @@
from collections import OrderedDict
import numpy as np
import pytest
from braindecode.datautil.trial_segment import (
create_cnt_y_and_start_stop_samples)
_create_cnt_y_and_trial_bounds_from_start_stop)
from braindecode.datautil.trial_segment import (
_create_signal_target_from_start_and_ival)
from braindecode.datautil.trial_segment import (
_create_signal_target_from_start_and_stop)
from braindecode.datautil.trial_segment import add_breaks
def check_cnt_y_start_stop_samples(n_samples, events, fs, epoch_ival_ms,
@@ -11,8 +19,9 @@ def check_cnt_y_start_stop_samples(n_samples, events, fs, epoch_ival_ms,
name_to_stop_codes, cnt_y, start_stop):
cnt_y = np.array(cnt_y).T
real_cnt_y, real_start_stop = create_cnt_y_and_start_stop_samples(
n_samples, events ,fs, name_to_start_codes, epoch_ival_ms, name_to_stop_codes)
real_cnt_y, real_start_stop = _create_cnt_y_and_trial_bounds_from_start_stop(
n_samples, events ,fs, name_to_start_codes, epoch_ival_ms,
name_to_stop_codes)
np.testing.assert_array_equal(cnt_y, real_cnt_y)
np.testing.assert_array_equal(start_stop, real_start_stop)
@@ -61,3 +70,279 @@ def test_cnt_y_start_stop_samples_two_class_with_both_appearing():
cnt_y=[[1, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 0, 0]],
start_stop=[(0, 2), (3, 6)])
def check_signal_target_from_start_and_ival(data, events, fs, name_to_codes,
                                            epoch_ival_ms, expected_X, expected_y):
    """
    Run `_create_signal_target_from_start_and_ival` and compare its output.

    The raw inputs are converted (`data` and `events` to numpy arrays,
    `name_to_codes` to an `OrderedDict`) before the call; the resulting
    set's `y` must equal `expected_y` exactly and its `X` must be close to
    `expected_X`.
    """
    signal_target = _create_signal_target_from_start_and_ival(
        np.array(data),
        np.array(events),
        fs,
        OrderedDict(name_to_codes),
        epoch_ival_ms,
        one_hot_labels=False,
        one_label_per_trial=True)
    np.testing.assert_array_equal(signal_target.y, expected_y)
    np.testing.assert_allclose(signal_target.X, expected_X)
def test_signal_target_from_start_and_ival():
    """One start marker with a [0, 30] ms interval at fs=100 should give a
    single trial holding samples 0-2 with label 0."""
    signal = [np.arange(10)]
    markers = [(0, 1)]
    check_signal_target_from_start_and_ival(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, 30],
        expected_X=[[[0, 1, 2]]],
        expected_y=[0])
def test_signal_target_from_start_and_ival_ignored_marker():
    """A marker whose code (2) is not listed in `name_to_codes` should not
    produce a trial; only the trial for code 1 is expected."""
    signal = [np.arange(10)]
    markers = [(0, 1), (5, 2)]
    check_signal_target_from_start_and_ival(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, 30],
        expected_X=[[[0, 1, 2]]],
        expected_y=[0])
def test_signal_target_from_start_and_ival_two_class():
    """With both marker codes mapped to classes, two trials are expected:
    samples 0-2 labelled 0 and samples 5-7 labelled 1."""
    signal = [np.arange(10)]
    markers = [(0, 1), (5, 2)]
    class_codes = [('A', 1), ('B', 2)]
    check_signal_target_from_start_and_ival(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=class_codes,
        epoch_ival_ms=[0, 30],
        expected_X=[[[0, 1, 2]], [[5, 6, 7]]],
        expected_y=[0, 1])
def test_signal_target_from_start_and_ival_too_early():
    """A marker at sample 0 combined with an interval starting at -10 ms is
    expected to yield no trials."""
    signal = [np.arange(10)]
    markers = [(0, 1)]
    check_signal_target_from_start_and_ival(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[-10, 30],
        expected_X=[],
        expected_y=[])
def test_signal_target_from_start_and_ival_too_late():
    """A marker at sample 8 of a 10-sample signal with a [0, 30] ms interval
    at fs=100 is expected to yield no trials."""
    signal = [np.arange(10)]
    markers = [(8, 1)]
    check_signal_target_from_start_and_ival(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, 30],
        expected_X=[],
        expected_y=[])
def test_signal_target_from_start_and_ival_overlapping():
    """Markers at samples 0 and 1 give overlapping windows; both trials
    (samples 0-2 and 1-3) are expected in the output."""
    signal = [np.arange(10)]
    markers = [(0, 1), (1, 2)]
    class_codes = [('A', 1), ('B', 2)]
    check_signal_target_from_start_and_ival(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=class_codes,
        epoch_ival_ms=[0, 30],
        expected_X=[[[0, 1, 2]], [[1, 2, 3]]],
        expected_y=[0, 1])
def check_signal_target_from_start_and_stop(data, events, fs, name_to_codes,
                                            epoch_ival_ms, name_to_stop_codes,
                                            pad_to_n_samples,
                                            expected_X, expected_y, ):
    """
    Run `_create_signal_target_from_start_and_stop` and compare its output.

    `data` and `events` are converted to numpy arrays and both code
    mappings to `OrderedDict` before the call. Labels must match
    `expected_y` exactly. Trials are compared one by one against
    `expected_X` (after checking the trial count), since trials may differ
    in length.
    """
    signal_target = _create_signal_target_from_start_and_stop(
        np.array(data),
        np.array(events),
        fs,
        OrderedDict(name_to_codes),
        epoch_ival_ms,
        OrderedDict(name_to_stop_codes),
        pad_to_n_samples,
        one_hot_labels=False,
        one_label_per_trial=True)
    np.testing.assert_array_equal(signal_target.y, expected_y)
    assert len(signal_target.X) == len(expected_X)
    for trial_actual, trial_expected in zip(signal_target.X, expected_X):
        np.testing.assert_allclose(trial_actual, trial_expected)
def test_signal_target_from_start_and_stop():
    """Start marker at sample 0, stop marker at sample 2 and a [0, 20] ms
    interval at fs=100: one trial with samples 0-3 is expected."""
    signal = [np.arange(10)]
    markers = [(0, 1), (2, -1)]
    check_signal_target_from_start_and_stop(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, 20],
        name_to_stop_codes=[('A', -1)],
        pad_to_n_samples=None,
        expected_X=[[[0, 1, 2, 3]]],
        expected_y=[0])
def test_signal_target_from_start_and_stop_different_ival():
    """Same markers as the basic start/stop test but with a [0, -10] ms
    interval: the expected trial shrinks to sample 0 only."""
    signal = [np.arange(10)]
    markers = [(0, 1), (2, -1)]
    check_signal_target_from_start_and_stop(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, -10],
        name_to_stop_codes=[('A', -1)],
        pad_to_n_samples=None,
        expected_X=[[[0, ]]],
        expected_y=[0])
def test_signal_target_from_start_and_stop_ignored_marker():
    """A marker with code 3, which is neither a start nor a stop code, sits
    between start and stop; the expected trial is unaffected by it."""
    signal = [np.arange(10)]
    markers = [(0, 1), (1, 3), (2, -1)]
    check_signal_target_from_start_and_stop(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, -10],
        name_to_stop_codes=[('A', -1)],
        pad_to_n_samples=None,
        expected_X=[[[0, ]]],
        expected_y=[0])
def test_signal_target_from_start_and_stop_too_early():
    """A start marker at sample 0 with an interval beginning at -10 ms is
    expected to yield no trials."""
    signal = [np.arange(10)]
    markers = [(0, 1), (2, -1)]
    check_signal_target_from_start_and_stop(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[-10, 20],
        name_to_stop_codes=[('A', -1)],
        pad_to_n_samples=None,
        expected_X=[],
        expected_y=[])
def test_signal_target_from_start_and_stop_too_late():
    """Two start/stop pairs are given, the second one stopping at sample 9
    of a 10-sample signal; only the first trial is expected."""
    signal = [np.arange(10)]
    markers = [(0, 1), (2, -1), (8, 1), (9, -1)]
    check_signal_target_from_start_and_stop(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, 20],
        name_to_stop_codes=[('A', -1)],
        pad_to_n_samples=None,
        expected_X=[[[0, 1, 2, 3]]],
        expected_y=[0])
def test_signal_target_from_start_and_stop_overlapping():
    """Class B starts on the same sample where class A stops; with a
    [0, 20] ms interval both trials are expected and share samples 2-3."""
    signal = [np.arange(10)]
    markers = [(0, 1), (2, -1), (2, 2), (5, -2)]
    start_codes = [('A', 1), ('B', 2)]
    stop_codes = [('A', -1), ('B', -2)]
    check_signal_target_from_start_and_stop(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=start_codes,
        epoch_ival_ms=[0, 20],
        name_to_stop_codes=stop_codes,
        pad_to_n_samples=None,
        expected_X=[[[0, 1, 2, 3]], [[2, 3, 4, 5, 6]]],
        expected_y=[0, 1])
def test_signal_target_from_start_and_stop_stop_missing():
    """Class B has a start marker but no stop marker; only the complete
    class A trial is expected."""
    signal = [np.arange(10)]
    markers = [(0, 1), (2, -1), (2, 2), ]
    start_codes = [('A', 1), ('B', 2)]
    stop_codes = [('A', -1), ('B', -2)]
    check_signal_target_from_start_and_stop(
        data=signal,
        events=markers,
        fs=100,
        name_to_codes=start_codes,
        epoch_ival_ms=[0, 20],
        name_to_stop_codes=stop_codes,
        pad_to_n_samples=None,
        expected_X=[[[0, 1, 2, 3]], ],
        expected_y=[0, ])
def test_signal_target_from_start_and_stop_wrong_stop_for_start():
    """A class B start marker followed by class A's stop code must make the
    check raise an AssertionError."""
    signal = [np.arange(10)]
    # The last marker (3, -1) is the stop code of class A, although the
    # preceding start marker (2, 2) belongs to class B.
    markers = [(0, 1), (2, -1), (2, 2), (3, -1)]
    start_codes = [('A', 1), ('B', 2)]
    stop_codes = [('A', -1), ('B', -2)]
    with pytest.raises(AssertionError):
        check_signal_target_from_start_and_stop(
            data=signal,
            events=markers,
            fs=100,
            name_to_codes=start_codes,
            epoch_ival_ms=[0, 20],
            name_to_stop_codes=stop_codes,
            pad_to_n_samples=None,
            expected_X=[[[0, 1, 2, 3]], ],
            expected_y=[0, ])
def check_add_breaks(
        events,
        fs,
        break_start_code,
        break_stop_code,
        name_to_start_codes,
        name_to_stop_codes,
        min_break_length_ms,
        max_break_length_ms,
        break_start_offset_ms,
        break_stop_offset_ms,
        expected_events,
):
    """
    Call `add_breaks` and assert that it returns `expected_events`.

    `events` is converted to a numpy array and the two code mappings to
    `OrderedDict` before the call; the break length limits and offsets are
    forwarded as keyword arguments.
    """
    marker_array = np.array(events)
    start_codes = OrderedDict(name_to_start_codes)
    stop_codes = OrderedDict(name_to_stop_codes)
    break_options = dict(
        min_break_length_ms=min_break_length_ms,
        max_break_length_ms=max_break_length_ms,
        break_start_offset_ms=break_start_offset_ms,
        break_stop_offset_ms=break_stop_offset_ms)
    actual_events = add_breaks(
        marker_array, fs, break_start_code, break_stop_code,
        start_codes, stop_codes, **break_options)
    np.testing.assert_array_equal(actual_events, expected_events)
def test_add_breaks_no_break():
    """With a single start/stop pair and nothing after it, the events are
    expected to come back unchanged (no break markers)."""
    markers = np.array([(0, 1), (2, -1)])
    check_add_breaks(
        events=markers,
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=[('A', 1), ],
        name_to_stop_codes=[('A', -1), ],
        min_break_length_ms=None,
        max_break_length_ms=None,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), ])
def test_add_breaks_one_break():
    """Between the stop at sample 2 and the next start at sample 5, a break
    start (3, -3) and break stop (4, -4) are expected to be inserted."""
    markers = np.array([(0, 1), (2, -1), (5, 2)])
    start_codes = [('A', 1), ('B', 2), ]
    stop_codes = [('A', -1), ]
    # No length limits and no offsets: the break should simply be added.
    check_add_breaks(
        events=markers,
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=start_codes,
        name_to_stop_codes=stop_codes,
        min_break_length_ms=None,
        max_break_length_ms=None,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), (3, -3), (4, -4), (5, 2), ])
def test_add_breaks_break_within_bound():
    """With a minimum break length of 10 ms the break between samples 2 and
    5 is still expected to be inserted."""
    markers = np.array([(0, 1), (2, -1), (5, 2)])
    start_codes = [('A', 1), ('B', 2), ]
    stop_codes = [('A', -1), ]
    # Only a lower bound is set; the break should satisfy it and be added.
    check_add_breaks(
        events=markers,
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=start_codes,
        name_to_stop_codes=stop_codes,
        min_break_length_ms=10,
        max_break_length_ms=None,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), (3, -3), (4, -4), (5, 2), ])
def test_add_breaks_too_short():
    """With a minimum break length of 20 ms the gap between samples 2 and 5
    is expected to be too short, so no break markers are inserted."""
    markers = np.array([(0, 1), (2, -1), (5, 2)])
    start_codes = [('A', 1), ('B', 2), ]
    stop_codes = [('A', -1), ]
    # Lower bound exceeds the available gap: events should stay unchanged.
    check_add_breaks(
        events=markers,
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=start_codes,
        name_to_stop_codes=stop_codes,
        min_break_length_ms=20,
        max_break_length_ms=None,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), (5, 2), ])
def test_add_breaks_too_long():
    """With a maximum break length of 10 ms, the gap between the stop at
    sample 2 and the next start at sample 6 is expected to be too long, so
    no break markers are inserted."""
    # not added, too long!
    check_add_breaks(
        events=np.array([(0, 1), (2, -1), (6, 2)]),
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=[('A', 1), ('B', 2), ],
        name_to_stop_codes=[('A', -1), ],
        min_break_length_ms=None,
        max_break_length_ms=10,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), (6, 2), ])
def test_add_breaks_within_both_bounds():
    """With a 10 ms minimum and a 30 ms maximum break length, the break
    between samples 2 and 5 is expected to be inserted."""
    markers = np.array([(0, 1), (2, -1), (5, 2)])
    start_codes = [('A', 1), ('B', 2), ]
    stop_codes = [('A', -1), ]
    # Both a lower and an upper bound are set; the break should fit both.
    check_add_breaks(
        events=markers,
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=start_codes,
        name_to_stop_codes=stop_codes,
        min_break_length_ms=10,
        max_break_length_ms=30,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), (3, -3), (4, -4), (5, 2), ])