now all trial segmentation happens by creating continuous labels first
Esse commit está contido em:
@@ -13,9 +13,12 @@ marker_def = OrderedDict(
|
||||
(['Right', 1], ['Left', [2]], ['Rest', 3], ['Feet', 4]))
|
||||
|
||||
|
||||
def create_signal_target_from_raw_mne(raw, name_to_start_codes, epoch_ival_ms,
|
||||
name_to_stop_codes=None,
|
||||
pad_to_n_samples=None):
|
||||
def create_signal_target_from_raw_mne(
|
||||
raw, name_to_start_codes, epoch_ival_ms,
|
||||
name_to_stop_codes=None,
|
||||
prepad_trials_to_n_samples=None,
|
||||
one_hot_labels=False,
|
||||
one_label_per_trial=True):
|
||||
"""
|
||||
Create SignalTarget set from given `mne.io.RawArray`.
|
||||
|
||||
@@ -36,9 +39,14 @@ def create_signal_target_from_raw_mne(raw, name_to_start_codes, epoch_ival_ms,
|
||||
Dictionary mapping class names to stop marker code or stop marker codes.
|
||||
Order does not matter, dictionary should contain each class in
|
||||
`name_to_codes` dictionary.
|
||||
pad_to_n_samples: int
|
||||
prepad_trials_to_n_samples: int
|
||||
Pad trials that would be too short with the signal before it (only
|
||||
valid name_to_stop_codes is not None.
|
||||
valid if name_to_stop_codes is not None).
|
||||
one_hot_labels: bool, optional
|
||||
Whether to have the labels in a one-hot format, e.g. [0,0,1] or to
|
||||
have them just as an int, e.g. 2
|
||||
one_label_per_trial: bool, optional
|
||||
Whether to have a timeseries of labels or just a single label per trial.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -49,14 +57,18 @@ def create_signal_target_from_raw_mne(raw, name_to_start_codes, epoch_ival_ms,
|
||||
events = np.array([raw.info['events'][:,0],
|
||||
raw.info['events'][:,2]]).T
|
||||
fs = raw.info['sfreq']
|
||||
return create_signal_target(data, events, fs, name_to_start_codes,
|
||||
epoch_ival_ms,
|
||||
name_to_stop_codes=name_to_stop_codes,
|
||||
pad_to_n_samples=pad_to_n_samples)
|
||||
return create_signal_target(
|
||||
data, events, fs, name_to_start_codes,
|
||||
epoch_ival_ms,
|
||||
name_to_stop_codes=name_to_stop_codes,
|
||||
prepad_trials_to_n_samples=prepad_trials_to_n_samples,
|
||||
one_hot_labels=one_hot_labels,
|
||||
one_label_per_trial=one_label_per_trial)
|
||||
|
||||
|
||||
def create_signal_target(data, events, fs, name_to_start_codes, epoch_ival_ms,
|
||||
name_to_stop_codes=None, pad_to_n_samples=None):
|
||||
name_to_stop_codes=None, prepad_trials_to_n_samples=None,
|
||||
one_hot_labels=False, one_label_per_trial=True):
|
||||
"""
|
||||
Create SignalTarget set given continuous data.
|
||||
|
||||
@@ -83,10 +95,14 @@ def create_signal_target(data, events, fs, name_to_start_codes, epoch_ival_ms,
|
||||
Dictionary mapping class names to stop marker code or stop marker codes.
|
||||
Order does not matter, dictionary should contain each class in
|
||||
`name_to_codes` dictionary.
|
||||
pad_to_n_samples: int
|
||||
prepad_trials_to_n_samples: int, optional
|
||||
Pad trials that would be too short with the signal before it (only
|
||||
valid name_to_stop_codes is not None.
|
||||
|
||||
valid if name_to_stop_codes is not None).
|
||||
one_hot_labels: bool, optional
|
||||
Whether to have the labels in a one-hot format, e.g. [0,0,1] or to
|
||||
have them just as an int, e.g. 2
|
||||
one_label_per_trial: bool, optional
|
||||
Whether to have a timeseries of labels or just a single label per trial.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -96,11 +112,15 @@ def create_signal_target(data, events, fs, name_to_start_codes, epoch_ival_ms,
|
||||
"""
|
||||
if name_to_stop_codes is None:
|
||||
return _create_signal_target_from_start_and_ival(
|
||||
data, events, fs, name_to_start_codes, epoch_ival_ms)
|
||||
data, events, fs, name_to_start_codes, epoch_ival_ms,
|
||||
one_hot_labels=one_hot_labels,
|
||||
one_label_per_trial=one_label_per_trial)
|
||||
else:
|
||||
return _create_signal_target_from_start_and_stop(
|
||||
data, events, fs, name_to_start_codes, epoch_ival_ms,
|
||||
name_to_stop_codes, pad_to_n_samples)
|
||||
name_to_stop_codes, prepad_trials_to_n_samples,
|
||||
one_hot_labels=one_hot_labels,
|
||||
one_label_per_trial=one_label_per_trial)
|
||||
|
||||
|
||||
def _to_mrk_code_to_name_and_y(name_to_codes):
|
||||
@@ -119,255 +139,72 @@ def _to_mrk_code_to_name_and_y(name_to_codes):
|
||||
|
||||
|
||||
def _create_signal_target_from_start_and_ival(
|
||||
data, events, fs, name_to_codes, epoch_ival_ms):
|
||||
data, events, fs, name_to_codes, epoch_ival_ms,
|
||||
one_hot_labels, one_label_per_trial):
|
||||
cnt_y, i_start_stops = _create_cnt_y_and_trial_bounds_from_start_and_ival(
|
||||
data.shape[1], events, fs, name_to_codes, epoch_ival_ms
|
||||
)
|
||||
signal_target = _create_signal_target_from_cnt_y_start_stops(
|
||||
data, cnt_y, i_start_stops, prepad_trials_to_n_samples=None,
|
||||
one_hot_labels=one_hot_labels,
|
||||
one_label_per_trial=one_label_per_trial)
|
||||
# make into array as all trials should have the same dimensions
|
||||
signal_target.X = np.array(signal_target.X, dtype=np.float32)
|
||||
signal_target.y = np.array(signal_target.y, dtype=np.int64)
|
||||
return signal_target
|
||||
|
||||
|
||||
def _create_cnt_y_and_trial_bounds_from_start_and_ival(
|
||||
n_samples, events, fs, name_to_start_codes, epoch_ival_ms):
|
||||
ival_in_samples = ms_to_samples(np.array(epoch_ival_ms), fs)
|
||||
start_offset = np.int32(np.round(ival_in_samples[0]))
|
||||
# we will use ceil but exclusive...
|
||||
stop_offset = np.int32(np.ceil(ival_in_samples[1]))
|
||||
mrk_code_to_name_and_y = _to_mrk_code_to_name_and_y(name_to_codes)
|
||||
|
||||
mrk_code_to_name_and_y = _to_mrk_code_to_name_and_y(name_to_start_codes)
|
||||
class_to_n_trials = Counter()
|
||||
X = []
|
||||
y = []
|
||||
|
||||
n_classes = len(name_to_start_codes)
|
||||
cnt_y = np.zeros((n_samples, n_classes), dtype=np.int64)
|
||||
i_start_stops = []
|
||||
for i_sample, mrk_code in zip(events[:, 0], events[:, 1]):
|
||||
start_sample = int(i_sample) + start_offset
|
||||
stop_sample = int(i_sample) + stop_offset
|
||||
if mrk_code in mrk_code_to_name_and_y:
|
||||
if start_sample < 0:
|
||||
log.warning("Ignore trial with marker code {:d}, would start at "
|
||||
"sample {:d}".format(mrk_code, start_sample))
|
||||
log.warning(
|
||||
"Ignore trial with marker code {:d}, would start at "
|
||||
"sample {:d}".format(mrk_code, start_sample))
|
||||
continue
|
||||
if stop_sample > data.shape[1]:
|
||||
if stop_sample > n_samples:
|
||||
log.warning("Ignore trial with marker code {:d}, would end at "
|
||||
"sample {:d} of {:d}".format(mrk_code, stop_sample-1,
|
||||
data.shape[1]-1))
|
||||
"sample {:d} of {:d}".format(
|
||||
mrk_code, stop_sample - 1, n_samples - 1))
|
||||
continue
|
||||
|
||||
name, this_y = mrk_code_to_name_and_y[mrk_code]
|
||||
X.append(data[:, start_sample:stop_sample].astype(np.float32))
|
||||
y.append(np.int64(this_y))
|
||||
i_start_stops.append((start_sample, stop_sample))
|
||||
cnt_y[start_sample:stop_sample, this_y] = 1
|
||||
class_to_n_trials[name] += 1
|
||||
log.info("Trial per class:\n{:s}".format(str(class_to_n_trials)))
|
||||
return SignalAndTarget(np.array(X), np.array(y))
|
||||
return cnt_y, i_start_stops
|
||||
|
||||
|
||||
def _create_signal_target_from_start_and_stop(
|
||||
data, events, fs, name_to_start_codes, epoch_ival_ms,
|
||||
name_to_stop_codes, pad_to_n_samples=None):
|
||||
name_to_stop_codes, prepad_trials_to_n_samples,
|
||||
one_hot_labels, one_label_per_trial, ):
|
||||
assert np.array_equal(list(name_to_start_codes.keys()),
|
||||
list(name_to_stop_codes.keys()))
|
||||
ival_in_samples = ms_to_samples(np.array(epoch_ival_ms), fs)
|
||||
start_offset = np.int32(np.round(ival_in_samples[0]))
|
||||
# we will use ceil but exclusive...
|
||||
stop_offset = np.int32(np.ceil(ival_in_samples[1]))
|
||||
start_code_to_name_and_y = _to_mrk_code_to_name_and_y(name_to_start_codes)
|
||||
# Ensure all stop marker codes are iterables
|
||||
for name in name_to_stop_codes:
|
||||
codes = name_to_stop_codes[name]
|
||||
if not hasattr(codes, '__len__'):
|
||||
name_to_stop_codes[name] = [codes]
|
||||
all_stop_codes = np.concatenate(list(name_to_stop_codes.values()))
|
||||
class_to_n_trials = Counter()
|
||||
X = []
|
||||
y = []
|
||||
|
||||
event_samples = events[:, 0]
|
||||
event_codes = events[:, 1]
|
||||
|
||||
i_event = 0
|
||||
first_start_code_found = False
|
||||
while i_event < len(events):
|
||||
while i_event < len(events) and (
|
||||
event_codes[i_event] not in start_code_to_name_and_y):
|
||||
i_event += 1
|
||||
if i_event < len(events):
|
||||
start_sample = event_samples[i_event]
|
||||
start_code = event_codes[i_event]
|
||||
start_name = start_code_to_name_and_y[start_code][0]
|
||||
start_y = start_code_to_name_and_y[start_code][1]
|
||||
i_event += 1
|
||||
first_start_code_found = True
|
||||
waiting_for_end_code = True
|
||||
|
||||
while i_event < len(events) and (
|
||||
event_codes[i_event] not in all_stop_codes):
|
||||
if event_codes[i_event] in start_code_to_name_and_y:
|
||||
log.warning(
|
||||
"New start marker {:.0f} at {:.0f} samples found, "
|
||||
"no end marker for earlier start marker {:.0f} "
|
||||
"at {:.0f} samples found.".format(
|
||||
event_codes[i_event], event_samples[i_event],
|
||||
start_code, start_sample))
|
||||
start_sample = event_samples[i_event]
|
||||
start_name = start_code_to_name_and_y[start_code][0]
|
||||
start_code = event_codes[i_event]
|
||||
start_y = start_code_to_name_and_y[start_code][1]
|
||||
i_event += 1
|
||||
if i_event == len(events):
|
||||
if waiting_for_end_code:
|
||||
log.warning(("No end marker for start marker code {:.0f} "
|
||||
"at sample {:.0f} found.").format(start_code,
|
||||
start_sample))
|
||||
elif (not first_start_code_found):
|
||||
log.warning("No markers found at all.")
|
||||
break
|
||||
stop_sample = event_samples[i_event]
|
||||
stop_code = event_codes[i_event]
|
||||
assert stop_code in name_to_stop_codes[start_name]
|
||||
i_start = int(start_sample) + start_offset
|
||||
i_stop = int(stop_sample) + stop_offset
|
||||
waiting_for_end_code = False
|
||||
if (pad_to_n_samples is not None) and (
|
||||
(i_stop - i_start) < pad_to_n_samples):
|
||||
if i_stop < pad_to_n_samples:
|
||||
log.warning("Could not pad trial enough, therefore not "
|
||||
"not using trial from {:d} to {:d}".format(
|
||||
i_start, i_stop
|
||||
))
|
||||
continue
|
||||
i_start = i_stop - pad_to_n_samples
|
||||
if i_start < 0:
|
||||
log.warning("Ignore trial with start code {:d}, would start at "
|
||||
"sample {:d}".format(start_code, i_start))
|
||||
continue
|
||||
if i_stop > data.shape[1]:
|
||||
log.warning("Ignore trial with stop code {:d}, would end at "
|
||||
"sample {:d} of {:d}".format(stop_code, i_stop - 1,
|
||||
data.shape[1] - 1))
|
||||
continue
|
||||
|
||||
X.append(data[:, i_start:i_stop].astype(np.float32))
|
||||
y.append(np.int64(start_y))
|
||||
class_to_n_trials[start_name] += 1
|
||||
|
||||
log.info("Trial per class:\n{:s}".format(str(class_to_n_trials)))
|
||||
return SignalAndTarget(X, np.array(y))
|
||||
cnt_y, i_start_stops = _create_cnt_y_and_trial_bounds_from_start_stop(
|
||||
data.shape[1],events, fs,name_to_start_codes, epoch_ival_ms,
|
||||
name_to_stop_codes)
|
||||
signal_target = _create_signal_target_from_cnt_y_start_stops(
|
||||
data, cnt_y, i_start_stops,
|
||||
prepad_trials_to_n_samples=prepad_trials_to_n_samples,
|
||||
one_hot_labels=one_hot_labels, one_label_per_trial=one_label_per_trial)
|
||||
return signal_target
|
||||
|
||||
|
||||
def add_breaks(
        events, fs, break_start_code, break_stop_code, name_to_start_codes,
        name_to_stop_codes, min_break_length_ms=None,
        max_break_length_ms=None, break_start_offset_ms=None,
        break_stop_offset_ms=None):
    """
    Insert break start/stop markers into an array of events.

    A break is the gap between a trial stop marker and the next trial
    start marker, as reported by `_extract_break_start_stop_ms`.

    Parameters
    ----------
    events: 2d-array
        Dimensions: Number of events, 2. For each event, should contain sample
        index and marker code.
    fs: number
        Sampling rate.
    break_start_code: int
        Marker code written for every break start.
    break_stop_code: int
        Marker code written for every break stop.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Ordered dictionary mapping class names to start marker code or
        start marker codes.
    name_to_stop_codes: dict (str -> int or list of int)
        Dictionary mapping class names to stop marker code or stop marker codes.
    min_break_length_ms: number, optional
        Breaks shorter than this many milliseconds are dropped.
    max_break_length_ms: number, optional
        Breaks longer than this many milliseconds are dropped.
    break_start_offset_ms: number, optional
        Offset in milliseconds added (after rounding to samples) to every
        break start.
    break_stop_offset_ms: number, optional
        Offset in milliseconds added (after rounding to samples) to every
        break stop.

    Returns
    -------
    events: 2d-array
        The original events plus the break markers, sorted by sample index.
        A deep copy of `events` if no break passes the duration filter.
    """
    # Duration limits in samples; None means "no limit".
    # NOTE(review): ms_to_samples is assumed to convert milliseconds to a
    # (possibly fractional) sample count at rate fs -- confirm.
    min_samples = None
    if min_break_length_ms is not None:
        min_samples = ms_to_samples(min_break_length_ms, fs)
    max_samples = None
    if max_break_length_ms is not None:
        max_samples = ms_to_samples(max_break_length_ms, fs)

    starts, stops = _extract_break_start_stop_ms(
        events, name_to_start_codes, name_to_stop_codes)

    # Drop breaks whose length violates either limit.
    lengths = stops - starts
    keep = np.ones(len(starts), dtype=bool)
    if min_samples is not None:
        keep[lengths < min_samples] = False
    if max_samples is not None:
        keep[lengths > max_samples] = False
    if not keep.any():
        return deepcopy(events)
    starts = starts[keep]
    stops = stops[keep]

    # The duration filter above deliberately runs before the offsets apply.
    if break_start_offset_ms is not None:
        starts += int(round(ms_to_samples(break_start_offset_ms, fs)))
    if break_stop_offset_ms is not None:
        stops += int(round(ms_to_samples(break_stop_offset_ms, fs)))

    # Interleave the markers: even rows are break starts, odd rows break stops.
    markers = np.zeros((len(starts) * 2, 2))
    markers[0::2, 1] = break_start_code
    markers[1::2, 1] = break_stop_code
    markers[0::2, 0] = starts
    markers[1::2, 0] = stops

    merged = np.concatenate((events, markers))
    # mergesort is stable, so events sharing a sample index keep their order.
    return merged[np.argsort(merged[:, 0], kind='mergesort')]
|
||||
|
||||
|
||||
def _extract_break_start_stop_ms(events, name_to_start_codes,
                                 name_to_stop_codes):
    """
    Locate the gaps between each trial stop marker and the next trial start.

    Despite the `_ms` suffix, the returned values are sample indices taken
    from the first column of `events`.

    Parameters
    ----------
    events: 2d-array
        Dimensions: Number of events, 2. For each event, should contain sample
        index and marker code.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Mapping of class names to start marker code(s).
    name_to_stop_codes: dict (str -> int or list of int)
        Mapping of class names to stop marker code(s). Not modified.

    Returns
    -------
    break_starts: 1d-array
        For each break, the sample one after the trial stop marker.
    break_stops: 1d-array
        For each break, the sample one before the next trial start marker.
    """
    if len(events) == 0:
        # Nothing to scan; avoids indexing events[0] on an empty array.
        return np.array([]), np.array([])
    assert len(events[0]) == 2, "expect only 2dimensional event array here"
    start_code_to_name_and_y = _to_mrk_code_to_name_and_y(name_to_start_codes)
    # Flatten scalar and list-valued entries without writing the normalised
    # lists back into the caller's dictionary.
    all_stop_codes = np.concatenate(
        [np.atleast_1d(codes) for codes in name_to_stop_codes.values()]
    ).astype(np.int32)
    event_samples = events[:, 0]
    event_codes = events[:, 1]
    n_events = len(events)

    break_starts = []
    break_stops = []
    i_event = 0
    while i_event < n_events:
        # Advance to the next trial stop marker.
        while (i_event < n_events) and (
                event_codes[i_event] not in all_stop_codes):
            i_event += 1
        if i_event == n_events:
            break
        stop_sample = event_samples[i_event]
        stop_code = event_codes[i_event]
        i_event += 1

        # Advance to the next trial start marker. A further stop marker seen
        # on the way replaces the earlier one.
        while (i_event < n_events) and (
                event_codes[i_event] not in start_code_to_name_and_y):
            if event_codes[i_event] in all_stop_codes:
                log.warning(
                    "New end marker {:.0f} at {:.0f} samples found, "
                    "no start marker for earlier end marker {:.0f} "
                    "at {:.0f} samples found.".format(
                        event_codes[i_event],
                        event_samples[i_event],
                        stop_code, stop_sample))
                # No "+ 1" here: the single offset is applied on append below.
                stop_sample = event_samples[i_event]
                stop_code = event_codes[i_event]
            i_event += 1

        if i_event == n_events:
            # Trailing stop marker without a following trial: no break.
            break

        start_sample = event_samples[i_event]
        start_code = event_codes[i_event]
        assert start_code in start_code_to_name_and_y
        # Start one sample after the trial stop and stop one sample before
        # the trial start so the markers stay in the right order when merged.
        break_starts.append(stop_sample + 1)
        break_stops.append(start_sample - 1)
        # i_event still points at the start marker; the next outer iteration
        # scans forward from it for the following stop marker.
    return np.array(break_starts), np.array(break_stops)
|
||||
|
||||
|
||||
def create_cnt_y_and_start_stop_samples(
|
||||
def _create_cnt_y_and_trial_bounds_from_start_stop(
|
||||
n_samples, events, fs, name_to_start_codes, epoch_ival_ms,
|
||||
name_to_stop_codes):
|
||||
"""
|
||||
@@ -397,6 +234,13 @@ def create_cnt_y_and_start_stop_samples(
|
||||
Order does not matter, dictionary should contain each class in
|
||||
`name_to_codes` dictionary.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cnt_y: 2d-array
|
||||
Timeseries of one-hot-labels, time x classes.
|
||||
trial_bounds: list of (int,int)
|
||||
List of (trial_start, trial_stop) tuples.
|
||||
|
||||
|
||||
"""
|
||||
assert np.array_equal(list(name_to_start_codes.keys()),
|
||||
@@ -412,7 +256,8 @@ def create_cnt_y_and_start_stop_samples(
|
||||
codes = name_to_stop_codes[name]
|
||||
if not hasattr(codes, '__len__'):
|
||||
name_to_stop_codes[name] = [codes]
|
||||
all_stop_codes = np.concatenate(list(name_to_stop_codes.values())).astype(np.int64)
|
||||
all_stop_codes = np.concatenate(list(name_to_stop_codes.values())
|
||||
).astype(np.int64)
|
||||
class_to_n_trials = Counter()
|
||||
n_classes = len(name_to_start_codes)
|
||||
cnt_y = np.zeros((n_samples, n_classes), dtype=np.int64)
|
||||
@@ -471,137 +316,21 @@ def create_cnt_y_and_start_stop_samples(
|
||||
return cnt_y, i_start_stops
|
||||
|
||||
|
||||
def create_signal_target_with_breaks_from_mne(
        cnt, name_to_start_codes,
        trial_segment_ival_ms,
        name_to_stop_codes,
        min_break_length_ms, max_break_length_ms,
        break_segment_ival_ms,
        pad_trials_to_n_samples=None):
    """
    Create a signal/target set that has breaks as an additional 'Break' class.

    Parameters
    ----------
    cnt: object providing `get_data()`, `info['events']` and `info['sfreq']`
        NOTE(review): presumably an `mne.io.RawArray` -- confirm.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Mapping of class names to start marker code(s). Must not contain
        the key 'Break'.
    trial_segment_ival_ms: iterable of (number, number)
        Offsets in ms from the trial start marker and the trial stop marker.
    name_to_stop_codes: dict (str -> int or list of int)
        Mapping of class names to stop marker code(s).
    min_break_length_ms: number
        Passed to `add_breaks` as minimum break length.
    max_break_length_ms: number
        Passed to `add_breaks` as maximum break length.
    break_segment_ival_ms: iterable of (number, number)
        Offsets in ms for the break start and the break stop.
    pad_trials_to_n_samples: int, optional
        Forwarded as `pad_to_n_samples` to `create_signal_target_with_cnt_y`.

    Returns
    -------
    signal_target
        Whatever `create_signal_target_with_cnt_y` returns for the events
        extended by break markers.
    """
    assert 'Break' not in name_to_start_codes
    # Create new marker codes for start and stop of breaks.
    # Use marker codes that did not exist in the given marker codes...
    all_start_codes = np.concatenate(
        [np.atleast_1d(vals) for vals in name_to_start_codes.values()])
    all_stop_codes = np.concatenate(
        [np.atleast_1d(vals) for vals in name_to_stop_codes.values()])
    # Built once; both searches below test against the same set of codes.
    used_codes = np.concatenate((all_start_codes, all_stop_codes))
    break_start_code = -1
    while break_start_code in used_codes:
        break_start_code -= 1
    break_stop_code = break_start_code - 1
    while break_stop_code in used_codes:
        break_stop_code -= 1

    # Columns 0 and 2 of the events array are kept.
    # NOTE(review): presumably sample index and marker code -- confirm.
    events = cnt.info['events'][:, [0, 2]]
    # later trial segment ival will be added when creating set
    # so remove it here
    break_offsets_ms = np.array(break_segment_ival_ms) - (
        np.array(trial_segment_ival_ms))
    events_with_breaks = add_breaks(
        events, cnt.info['sfreq'],
        break_start_code, break_stop_code,
        name_to_start_codes, name_to_stop_codes,
        min_break_length_ms=min_break_length_ms,
        max_break_length_ms=max_break_length_ms,
        break_start_offset_ms=break_offsets_ms[0],
        break_stop_offset_ms=break_offsets_ms[1])

    # Work on copies so the caller's dictionaries do not gain a 'Break' key.
    name_to_start_codes_with_breaks = deepcopy(name_to_start_codes)
    name_to_start_codes_with_breaks['Break'] = break_start_code
    name_to_stop_codes_with_breaks = deepcopy(name_to_stop_codes)
    name_to_stop_codes_with_breaks['Break'] = break_stop_code

    data = cnt.get_data()
    fs = cnt.info['sfreq']
    signal_target = create_signal_target_with_cnt_y(
        data, events_with_breaks, fs,
        name_to_start_codes_with_breaks, trial_segment_ival_ms,
        name_to_stop_codes_with_breaks,
        pad_to_n_samples=pad_trials_to_n_samples)

    return signal_target
|
||||
|
||||
|
||||
def create_signal_target_with_cnt_y_from_raw_mne(
        raw, name_to_start_codes, epoch_ival_ms,
        name_to_stop_codes,
        pad_to_n_samples=None):
    """
    Unpack data, events and sampling rate from `raw` and delegate to
    `create_signal_target_with_cnt_y`.

    Parameters
    ----------
    raw: object providing `get_data()`, `info['events']` and `info['sfreq']`
        NOTE(review): presumably an `mne.io.RawArray` -- confirm.
    name_to_start_codes, epoch_ival_ms, name_to_stop_codes, pad_to_n_samples
        Forwarded unchanged to `create_signal_target_with_cnt_y`.

    Returns
    -------
    The result of `create_signal_target_with_cnt_y`.
    """
    info = raw.info
    return create_signal_target_with_cnt_y(
        raw.get_data(),
        # keep columns 0 and 2 of the events array
        info['events'][:, [0, 2]],
        info['sfreq'],
        name_to_start_codes, epoch_ival_ms,
        name_to_stop_codes,
        pad_to_n_samples=pad_to_n_samples)
|
||||
|
||||
|
||||
def create_signal_target_with_cnt_y(data, events, fs,
                                    name_to_start_codes, epoch_ival_ms,
                                    name_to_stop_codes,
                                    pad_to_n_samples=None
                                    ):
    """
    Create a signal/target set by first building continuous labels and trial
    bounds from start and stop markers, then cutting trials from them.

    Parameters
    ----------
    data: 2d-array of number
        The continuous recorded data. Channels x times order.
    events: 2d-array
        Dimensions: Number of events, 2. For each event, should contain sample
        index and marker code.
    fs: number
        Sampling rate.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Ordered dictionary mapping class names to marker code or marker codes.
        y-labels will be assigned in increasing key order, i.e.
        first classname gets y-value 0, second classname y-value 1, etc.
    epoch_ival_ms: iterable of (int,int)
        Epoching interval in milliseconds: offset from the start marker
        and offset from the stop marker. E.g. [500, -500] would mean 500ms
        after the start marker until 500 ms before the stop marker.
    name_to_stop_codes: dict (str -> int or list of int)
        Dictionary mapping class names to stop marker code or stop marker codes.
        Order does not matter, dictionary should contain each class in
        `name_to_start_codes`.
    pad_to_n_samples: int, optional
        Use signal before trial start to pad trials that are otherwise too small.

    Returns
    -------
    dataset
        The result of `create_signal_target_from_cnt_y_start_stops`, which is
        called with `one_hot_y=False` and `one_label_per_trial=True`, i.e.
        a single non-one-hot label is requested per trial.
    """
    # data.shape[1] is the number of time samples (channels x times order).
    n_samples = data.shape[1]
    cnt_y, i_start_stops = create_cnt_y_and_start_stop_samples(
        n_samples, events, fs,
        name_to_start_codes,
        epoch_ival_ms, name_to_stop_codes, )
    return create_signal_target_from_cnt_y_start_stops(
        data, cnt_y, i_start_stops, pad_to_n_samples, one_hot_y=False,
        one_label_per_trial=True)
|
||||
|
||||
|
||||
def create_signal_target_from_cnt_y_start_stops(
|
||||
def _create_signal_target_from_cnt_y_start_stops(
|
||||
data,
|
||||
cnt_y,
|
||||
i_start_stops,
|
||||
pad_to_n_samples,
|
||||
one_hot_y,
|
||||
prepad_trials_to_n_samples,
|
||||
one_hot_labels,
|
||||
one_label_per_trial):
|
||||
if pad_to_n_samples is not None:
|
||||
if prepad_trials_to_n_samples is not None:
|
||||
new_i_start_stops = []
|
||||
for i_start, i_stop in i_start_stops:
|
||||
if (i_stop - i_start) > pad_to_n_samples:
|
||||
if (i_stop - i_start) > prepad_trials_to_n_samples:
|
||||
new_i_start_stops.append((i_start, i_stop))
|
||||
elif i_stop >= pad_to_n_samples:
|
||||
elif i_stop >= prepad_trials_to_n_samples:
|
||||
new_i_start_stops.append(
|
||||
(i_stop - pad_to_n_samples, i_stop))
|
||||
(i_stop - prepad_trials_to_n_samples, i_stop))
|
||||
else:
|
||||
log.warning("Could not pad trial enough, therefore not "
|
||||
"not using trial from {:d} to {:d}".format(
|
||||
@@ -622,30 +351,278 @@ def create_signal_target_from_cnt_y_start_stops(
|
||||
))
|
||||
continue
|
||||
if i_stop > data.shape[1]:
|
||||
log.warning("Trial stop too late, therefore not "
|
||||
log.warning("Trial stop too late (past {:d}), therefore not "
|
||||
"not using trial from {:d} to {:d}".format(
|
||||
data.shape[1] - 1,
|
||||
i_start, i_stop
|
||||
))
|
||||
continue
|
||||
X.append(data[:, i_start:i_stop].astype(np.float32))
|
||||
y.append(cnt_y[i_start:i_stop])
|
||||
|
||||
if not one_hot_y:
|
||||
# take last label always
|
||||
if one_label_per_trial:
|
||||
new_y = []
|
||||
for this_y in y:
|
||||
# if destroying one hot later, just set the most frequently occurring class to 1
|
||||
unique_labels, counts = np.unique(
|
||||
this_y, axis=0, return_counts=True)
|
||||
if not one_hot_labels:
|
||||
meaned_y = np.mean(this_y, axis=0)
|
||||
this_new_y = np.zeros_like(meaned_y)
|
||||
this_new_y[np.argmax(meaned_y)] = 1
|
||||
else:
|
||||
# take the most frequently occurring label combination
|
||||
this_new_y = unique_labels[np.argmax(counts)]
|
||||
|
||||
if len(unique_labels) > 1:
|
||||
log.warning("Different labels within one trial: {:s},"
|
||||
"setting single trial label to {:s}".format(
|
||||
str(unique_labels), str(this_new_y)
|
||||
))
|
||||
new_y.append(this_new_y)
|
||||
y = new_y
|
||||
if not one_hot_labels:
|
||||
# change from one-hot-encoding to regular encoding
|
||||
# with -1 as indication none of the classes are present
|
||||
new_y = []
|
||||
for this_y in y:
|
||||
this_new_y = np.argmax(this_y, axis=1)
|
||||
this_new_y[np.sum(this_y, axis=1) == 0] = -1
|
||||
if one_label_per_trial:
|
||||
if np.sum(this_y) == 0:
|
||||
this_new_y = -1
|
||||
else:
|
||||
this_new_y = np.argmax(this_y)
|
||||
if np.sum(this_y) > 1:
|
||||
log.warning(
|
||||
"Have multiple active classes and will convert to "
|
||||
"lowest class")
|
||||
else:
|
||||
if np.max(np.sum(this_y, axis=1)) > 1:
|
||||
log.warning(
|
||||
"Have multiple active classes and will convert to "
|
||||
"lowest class")
|
||||
this_new_y = np.argmax(this_y, axis=1)
|
||||
this_new_y[np.sum(this_y, axis=1) == 0] = -1
|
||||
new_y.append(this_new_y)
|
||||
y = new_y
|
||||
# take last label always
|
||||
if one_label_per_trial:
|
||||
y = [this_y[-1] for this_y in y]
|
||||
y = np.array(y, dtype=np.int64)
|
||||
|
||||
return SignalAndTarget(X, y)
|
||||
|
||||
|
||||
def create_signal_target_with_breaks_from_mne(
        cnt, name_to_start_codes,
        trial_epoch_ival_ms,
        name_to_stop_codes,
        min_break_length_ms, max_break_length_ms,
        break_epoch_ival_ms,
        prepad_trials_to_n_samples=None):
    """
    Create SignalTarget set from given `mne.io.RawArray`, with the breaks
    between trials added as an extra class named 'Break'.

    Parameters
    ----------
    cnt: `mne.io.RawArray`
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Ordered dictionary mapping class names to marker code or marker codes.
        y-labels will be assigned in increasing key order, i.e.
        first classname gets y-value 0, second classname y-value 1, etc.
        Must not contain the key 'Break'.
    trial_epoch_ival_ms: iterable of (int,int)
        Epoching interval in milliseconds. Represents offset from start marker
        and offset from stop marker. E.g. [500, -500] would mean 500ms
        after the start marker until 500 ms before the stop marker.
    name_to_stop_codes: dict (str -> int or list of int)
        Dictionary mapping class names to stop marker code or stop marker codes.
        Order does not matter, dictionary should contain each class in
        `name_to_start_codes`.
    min_break_length_ms: number
        Breaks below this length are excluded.
    max_break_length_ms: number
        Breaks above this length are excluded.
    break_epoch_ival_ms: iterable of (number, number)
        Break ival, offset from trial end to start of the break in ms and
        offset from trial start to end of break in ms.
    prepad_trials_to_n_samples: int, optional
        Pad trials that would be too short with the signal before it.

    Returns
    -------
    dataset: :class:`.SignalAndTarget`
        Dataset with `X` as the trial signals and `y` as the trial labels.
        Labels are requested as a timeseries per trial and not one-hot
        encoded (`one_label_per_trial=False`, `one_hot_labels=False`).
    """
    assert 'Break' not in name_to_start_codes
    # Create new marker codes for start and stop of breaks.
    # Use marker codes that did not exist in the given marker codes...
    all_start_codes = np.concatenate(
        [np.atleast_1d(vals) for vals in name_to_start_codes.values()])
    all_stop_codes = np.concatenate(
        [np.atleast_1d(vals) for vals in name_to_stop_codes.values()])
    # Built once; both searches below test against the same set of codes.
    used_codes = np.concatenate((all_start_codes, all_stop_codes))
    break_start_code = -1
    while break_start_code in used_codes:
        break_start_code -= 1
    break_stop_code = break_start_code - 1
    while break_stop_code in used_codes:
        break_stop_code -= 1

    # Columns 0 and 2 of the events array are kept.
    # NOTE(review): presumably sample index and marker code -- confirm.
    events = cnt.info['events'][:, [0, 2]]
    # later trial segment ival will be added when creating set
    # so remove it here
    break_offsets_ms = np.array(break_epoch_ival_ms) - (
        np.array(trial_epoch_ival_ms))
    events_with_breaks = add_breaks(
        events, cnt.info['sfreq'],
        break_start_code, break_stop_code,
        name_to_start_codes, name_to_stop_codes,
        min_break_length_ms=min_break_length_ms,
        max_break_length_ms=max_break_length_ms,
        break_start_offset_ms=break_offsets_ms[0],
        break_stop_offset_ms=break_offsets_ms[1])

    # Work on copies so the caller's dictionaries do not gain a 'Break' key.
    name_to_start_codes_with_breaks = deepcopy(name_to_start_codes)
    name_to_start_codes_with_breaks['Break'] = break_start_code
    name_to_stop_codes_with_breaks = deepcopy(name_to_stop_codes)
    name_to_stop_codes_with_breaks['Break'] = break_stop_code

    data = cnt.get_data()
    fs = cnt.info['sfreq']
    signal_target = create_signal_target(
        data, events_with_breaks, fs,
        name_to_start_codes_with_breaks, trial_epoch_ival_ms,
        name_to_stop_codes_with_breaks,
        prepad_trials_to_n_samples=prepad_trials_to_n_samples,
        one_hot_labels=False,
        one_label_per_trial=False)

    return signal_target
|
||||
|
||||
|
||||
def add_breaks(
        events, fs, break_start_code, break_stop_code, name_to_start_codes,
        name_to_stop_codes, min_break_length_ms=None,
        max_break_length_ms=None, break_start_offset_ms=None,
        break_stop_offset_ms=None):
    """
    Return the given events extended by break start and break stop markers.

    A break is the gap between a trial stop marker and the next trial
    start marker, as reported by `_extract_break_start_stop_ms`.

    Parameters
    ----------
    events: 2d-array
        Dimensions: Number of events, 2. For each event, should contain sample
        index and marker code.
    fs: number
        Sampling rate.
    break_start_code: int
        Marker code that will be used for break start markers.
    break_stop_code: int
        Marker code that will be used for break stop markers.
    name_to_start_codes: OrderedDict (str -> int or list of int)
        Ordered dictionary mapping class names to start marker code or
        start marker codes.
    name_to_stop_codes: dict (str -> int or list of int)
        Dictionary mapping class names to stop marker code or stop marker codes.
    min_break_length_ms: number, optional
        Minimum length in milliseconds a break should have to be included.
    max_break_length_ms: number, optional
        Maximum length in milliseconds a break can have to be included.
    break_start_offset_ms: number, optional
        Offset in ms from the trial end to the start of the break.
    break_stop_offset_ms: number, optional
        Offset in ms from the next trial start to the end of the break.

    Returns
    -------
    events: 2d-array
        Events with break start and stop markers, sorted by sample index.
        A deep copy of `events` if no break passes the length filter.
    """
    # Length limits in samples; None disables the respective check.
    # NOTE(review): ms_to_samples is assumed to convert milliseconds to a
    # (possibly fractional) sample count at rate fs -- confirm.
    min_samples = None
    if min_break_length_ms is not None:
        min_samples = ms_to_samples(min_break_length_ms, fs)
    max_samples = None
    if max_break_length_ms is not None:
        max_samples = ms_to_samples(max_break_length_ms, fs)

    starts, stops = _extract_break_start_stop_ms(
        events, name_to_start_codes, name_to_stop_codes)

    # Keep only breaks whose length lies within the limits.
    lengths = stops - starts
    keep = np.ones(len(starts), dtype=bool)
    if min_samples is not None:
        keep[lengths < min_samples] = False
    if max_samples is not None:
        keep[lengths > max_samples] = False
    if not keep.any():
        return deepcopy(events)
    starts = starts[keep]
    stops = stops[keep]

    # Offsets are applied only after the length filter.
    if break_start_offset_ms is not None:
        starts += int(round(ms_to_samples(break_start_offset_ms, fs)))
    if break_stop_offset_ms is not None:
        stops += int(round(ms_to_samples(break_stop_offset_ms, fs)))

    # Even rows hold break starts, odd rows the matching break stops.
    markers = np.zeros((len(starts) * 2, 2))
    markers[0::2, 1] = break_start_code
    markers[1::2, 1] = break_stop_code
    markers[0::2, 0] = starts
    markers[1::2, 0] = stops

    merged = np.concatenate((events, markers))
    # mergesort is stable, so events sharing a sample index keep their order.
    return merged[np.argsort(merged[:, 0], kind='mergesort')]
|
||||
|
||||
|
||||
def _extract_break_start_stop_ms(events, name_to_start_codes,
|
||||
name_to_stop_codes):
|
||||
assert len(events[0]) == 2, "expect only 2dimensional event array here"
|
||||
start_code_to_name_and_y = _to_mrk_code_to_name_and_y(name_to_start_codes)
|
||||
# Ensure all stop marker codes are iterables
|
||||
for name in name_to_stop_codes:
|
||||
codes = name_to_stop_codes[name]
|
||||
if not hasattr(codes, '__len__'):
|
||||
name_to_stop_codes[name] = [codes]
|
||||
all_stop_codes = np.concatenate(list(name_to_stop_codes.values())).astype(
|
||||
np.int32)
|
||||
event_samples = events[:, 0]
|
||||
event_codes = events[:, 1]
|
||||
|
||||
break_starts = []
|
||||
break_stops = []
|
||||
i_event = 0
|
||||
while i_event < len(events):
|
||||
while (i_event < len(events)) and (
|
||||
event_codes[i_event] not in all_stop_codes):
|
||||
i_event += 1
|
||||
if i_event < len(events):
|
||||
# one sample after start
|
||||
stop_sample = event_samples[i_event]
|
||||
stop_code = event_codes[i_event]
|
||||
i_event += 1
|
||||
while (i_event < len(events)) and (
|
||||
event_codes[i_event] not in start_code_to_name_and_y):
|
||||
if event_codes[i_event] in all_stop_codes:
|
||||
log.warning(
|
||||
"New end marker {:.0f} at {:.0f} samples found, "
|
||||
"no start marker for earlier end marker {:.0f} "
|
||||
"at {:.0f} samples found.".format(
|
||||
event_codes[i_event],
|
||||
event_samples[i_event],
|
||||
stop_code, stop_sample))
|
||||
stop_sample = event_samples[i_event]
|
||||
stop_code = event_codes[i_event]
|
||||
i_event += 1
|
||||
|
||||
if i_event == len(events):
|
||||
break
|
||||
|
||||
start_sample = event_samples[i_event]
|
||||
start_code = event_codes[i_event]
|
||||
assert start_code in start_code_to_name_and_y
|
||||
# let's start one after stop of the trial and stop one efore
|
||||
# start of the trial to ensure that markers will be
|
||||
# in right order
|
||||
break_starts.append(stop_sample + 1)
|
||||
break_stops.append(start_sample - 1)
|
||||
return np.array(break_starts), np.array(break_stops)
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.2.1"
|
||||
__version__ = "0.3.0"
|
||||
@@ -149,16 +149,6 @@
|
||||
"train_set = create_signal_target_from_raw_mne(cnt, name_to_code, segment_ival_ms)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.save('/data/schirrmr/schirrmr/BBCI_Data-train-set-X-tmp-test.npy', train_set.X)\n",
|
||||
"np.save('/data/schirrmr/schirrmr/BBCI_Data-train-set-y-tmp-test.npy', train_set.y)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"nbsphinx": "hidden"
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -81,39 +82,18 @@
|
||||
"from braindecode.datautil.signalproc import lowpass_cnt\n",
|
||||
"from braindecode.datautil.signalproc import exponential_running_standardize\n",
|
||||
"\n",
|
||||
"def create_cnts(folder, runs, name_to_start_code, name_to_stop_code, break_start_offset_ms,\n",
|
||||
" break_stop_offset_ms, break_start_code, break_stop_code):\n",
|
||||
"def create_cnts(folder, runs,):\n",
|
||||
" # Load data\n",
|
||||
" cnts = load_bbci_sets_from_folder(folder, runs)\n",
|
||||
" \n",
|
||||
" # Now do some preprocessings:\n",
|
||||
" # Adding breaks, resampling to 250 Hz, lowpass below 38, eponential standardization\n",
|
||||
" break_start_code = -1\n",
|
||||
" break_stop_code = -2\n",
|
||||
" # Resampling to 250 Hz, lowpass below 38, eponential standardization\n",
|
||||
" \n",
|
||||
" new_cnts = []\n",
|
||||
" for cnt in cnts:\n",
|
||||
" # Only take some channels \n",
|
||||
" #cnt = cnt.drop_channels(['STI 014']) # This would remove stimulus channel\n",
|
||||
" cnt = cnt.pick_channels(['C3', 'CPz' 'C4'])\n",
|
||||
" # add breaks\n",
|
||||
" new_events = add_breaks(\n",
|
||||
" np.array(cnt.info['events'])[:, [0,2]], cnt.info['sfreq'],\n",
|
||||
" break_start_code=break_start_code, break_stop_code=break_stop_code,\n",
|
||||
" name_to_start_codes=name_to_start_code, name_to_stop_codes=name_to_stop_code,\n",
|
||||
" min_break_length_ms=5000, max_break_length_ms=9000)\n",
|
||||
" n_break_start_offset = int(cnt.info['sfreq'] * break_start_offset_ms / 1000.0)\n",
|
||||
" n_break_stop_offset = int(cnt.info['sfreq'] * break_stop_offset_ms / 1000.0)\n",
|
||||
" # lets add some offset to break start and stop\n",
|
||||
" # new_events[:,2] contains event codes, new_events[:,0] the sample indices\n",
|
||||
" # new_events[:,1] is always 0 for my loading of BBCI data\n",
|
||||
" new_events[new_events[:,1] == break_start_code, 0] += n_break_start_offset\n",
|
||||
" # 0.5 sec for break stop\n",
|
||||
" new_events[new_events[:,1] == break_stop_code, 0] += n_break_stop_offset\n",
|
||||
" # add a middle column with zeros again\n",
|
||||
" new_events_for_mne = np.zeros((len(new_events), 3))\n",
|
||||
" new_events_for_mne[:,[0,2]] = new_events\n",
|
||||
" cnt.info['events'] = new_events_for_mne\n",
|
||||
" cnt = cnt.pick_channels(['C3', 'CPz', 'C4'])\n",
|
||||
" log.info(\"Preprocessing....\")\n",
|
||||
" cnt = mne_apply(lambda a: lowpass_cnt(a, 38,cnt.info['sfreq'], axis=1), cnt)\n",
|
||||
" cnt = resample_cnt(cnt, 250)\n",
|
||||
@@ -135,30 +115,18 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from collections import OrderedDict\n",
|
||||
"\n",
|
||||
"train_runs = [1,2,3]\n",
|
||||
"train_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/', \n",
|
||||
" train_runs,)\n",
|
||||
"\n",
|
||||
"name_to_start_code = OrderedDict([('Right Hand', 1), ('Feet', 4),\n",
|
||||
" ('Rotation', 8), ('Words', [10])])\n",
|
||||
"\n",
|
||||
"name_to_stop_code = OrderedDict([('Right Hand', [20,21,22,23,24,28,30]),\n",
|
||||
" ('Feet', [20,21,22,23,24,28,30]),\n",
|
||||
" ('Rotation', [20,21,22,23,24,28,30]), \n",
|
||||
" ('Words', [20,21,22,23,24,28,30])])\n",
|
||||
"\n",
|
||||
"break_start_offset_ms = 1000\n",
|
||||
"break_stop_offset_ms = -500\n",
|
||||
"# pick some numbers that were not used before/do not exist in markers\n",
|
||||
"break_start_code = -1\n",
|
||||
"break_stop_code = -2\n",
|
||||
"train_runs = [1,2,3]\n",
|
||||
"train_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R01-8/', \n",
|
||||
" train_runs,\n",
|
||||
" name_to_start_code,\n",
|
||||
" name_to_stop_code, break_start_offset_ms,\n",
|
||||
" break_stop_offset_ms, break_start_code, break_stop_code)\n",
|
||||
"\n",
|
||||
"name_to_code_with_breaks = deepcopy(name_to_start_code)\n",
|
||||
"name_to_code_with_breaks['Break'] = break_start_code\n",
|
||||
"name_to_stop_code_with_breaks = deepcopy(name_to_stop_code)\n",
|
||||
"name_to_stop_code_with_breaks['Break'] = break_stop_code"
|
||||
" ('Words', [20,21,22,23,24,28,30])])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -168,9 +136,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_runs = [9,10]\n",
|
||||
"test_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/', test_runs, name_to_start_code,\n",
|
||||
" name_to_stop_code, break_start_offset_ms,\n",
|
||||
" break_stop_offset_ms, break_start_code, break_stop_code)"
|
||||
"test_cnts = create_cnts('/home/schirrmr/data/robot-hall/AnLa/AnLaNBD1R09-10/', test_runs,)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -241,8 +207,17 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_sets = [create_signal_target_from_raw_mne(cnt, name_to_code_with_breaks, [-receptive_field_ms,0], \n",
|
||||
" name_to_stop_code_with_breaks) for cnt in train_cnts]\n",
|
||||
"from braindecode.datautil.trial_segment import create_signal_target_with_breaks_from_mne\n",
|
||||
"\n",
|
||||
"break_start_offset_ms = 1000\n",
|
||||
"break_stop_offset_ms = -500\n",
|
||||
"\n",
|
||||
"train_sets = [create_signal_target_with_breaks_from_mne(\n",
|
||||
" cnt, name_to_start_code, [0,0], \n",
|
||||
" name_to_stop_code, min_break_length_ms=1000, max_break_length_ms=10000,\n",
|
||||
" break_epoch_ival_ms=[500,-500],\n",
|
||||
" prepad_trials_to_n_samples=input_time_length) \n",
|
||||
" for cnt in train_cnts]\n",
|
||||
"train_set = concatenate_sets(train_sets)"
|
||||
]
|
||||
},
|
||||
@@ -252,32 +227,15 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_sets = [create_signal_target_from_raw_mne(cnt, name_to_code_with_breaks, [-receptive_field_ms,0], \n",
|
||||
" name_to_stop_code_with_breaks) for cnt in test_cnts]\n",
|
||||
"test_sets = [create_signal_target_with_breaks_from_mne(\n",
|
||||
" cnt, name_to_start_code, [0,0], \n",
|
||||
" name_to_stop_code, min_break_length_ms=1000, max_break_length_ms=10000,\n",
|
||||
" break_epoch_ival_ms=[500,-500],\n",
|
||||
" prepad_trials_to_n_samples=input_time_length) \n",
|
||||
" for cnt in test_cnts]\n",
|
||||
"test_set = concatenate_sets(test_sets)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_set_2 = np.load('/data/schirrmr/schirrmr/BBCI_Data_Start_Stop-train-set-tmp-test.npy')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pickle\n",
|
||||
"pickle.dump(train_set, open('/data/schirrmr/schirrmr/BBCI_Data_Start_Stop-train-set-tmp-test.npy', 'wb'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -345,9 +303,10 @@
|
||||
"import torch.nn.functional as F\n",
|
||||
"import torch as th\n",
|
||||
"from braindecode.torch_ext.modules import Expression\n",
|
||||
"from braindecode.torch_ext.losses import log_categorical_crossentropy\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2, keepdim=False), targets)\n",
|
||||
"loss_function = log_categorical_crossentropy\n",
|
||||
"\n",
|
||||
"model_constraint = None\n",
|
||||
"monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),\n",
|
||||
@@ -380,7 +339,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We arrive only at 31% accuracy. With only 3 sensors and 3 training runs, cannot expect too much great performance :)"
|
||||
"We arrive at about 54% accuracy. With only 3 sensors and 3 training runs, we cannot get much better :)"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@@ -217,9 +217,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
|
||||
@@ -323,61 +321,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We arrive at 24.3% or 26.3% depending on stars :))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"exp2 = Experiment(model, train_set, valid_set, test_set, iterator, loss_function, None, model_constraint,\n",
|
||||
" monitors, stop_criterion, remember_best_column='valid_misclass',\n",
|
||||
" run_after_early_stop=True, batch_modifier=None, cuda=cuda)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exp2.setup_training()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exp2.monitor_epoch(exp2.datasets)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exp2.print_epoch()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"exp2.epochs_df"
|
||||
"We arrive around 26%, exact value depending on stars :))"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
Diff do arquivo suprimido porque uma ou mais linhas são muito longas
@@ -1,9 +1,17 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from braindecode.datautil.trial_segment import (
|
||||
create_cnt_y_and_start_stop_samples)
|
||||
_create_cnt_y_and_trial_bounds_from_start_stop)
|
||||
from braindecode.datautil.trial_segment import (
|
||||
_create_signal_target_from_start_and_ival)
|
||||
from braindecode.datautil.trial_segment import (
|
||||
_create_signal_target_from_start_and_stop)
|
||||
from braindecode.datautil.trial_segment import add_breaks
|
||||
|
||||
|
||||
|
||||
|
||||
def check_cnt_y_start_stop_samples(n_samples, events, fs, epoch_ival_ms,
|
||||
@@ -11,8 +19,9 @@ def check_cnt_y_start_stop_samples(n_samples, events, fs, epoch_ival_ms,
|
||||
name_to_stop_codes, cnt_y, start_stop):
|
||||
|
||||
cnt_y = np.array(cnt_y).T
|
||||
real_cnt_y, real_start_stop = create_cnt_y_and_start_stop_samples(
|
||||
n_samples, events ,fs, name_to_start_codes, epoch_ival_ms, name_to_stop_codes)
|
||||
real_cnt_y, real_start_stop = _create_cnt_y_and_trial_bounds_from_start_stop(
|
||||
n_samples, events ,fs, name_to_start_codes, epoch_ival_ms,
|
||||
name_to_stop_codes)
|
||||
np.testing.assert_array_equal(cnt_y, real_cnt_y)
|
||||
np.testing.assert_array_equal(start_stop, real_start_stop)
|
||||
|
||||
@@ -61,3 +70,279 @@ def test_cnt_y_start_stop_samples_two_class_with_both_appearing():
|
||||
cnt_y=[[1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 1, 1, 1, 0, 0]],
|
||||
start_stop=[(0, 2), (3, 6)])
|
||||
|
||||
|
||||
def check_signal_target_from_start_and_ival(data, events, fs, name_to_codes,
                                            epoch_ival_ms, expected_X, expected_y):
    """
    Run `_create_signal_target_from_start_and_ival` and compare its output.

    `data` and `events` are converted to arrays and `name_to_codes` to an
    `OrderedDict` before the call. Labels are requested as one integer label
    per trial (`one_hot_labels=False`, `one_label_per_trial=True`).
    The resulting `y` must equal `expected_y` exactly, the resulting `X`
    must be close to `expected_X`.
    """
    out_set = _create_signal_target_from_start_and_ival(
        np.array(data),
        np.array(events),
        fs,
        OrderedDict(name_to_codes),
        epoch_ival_ms,
        one_hot_labels=False,
        one_label_per_trial=True,
    )
    np.testing.assert_array_equal(out_set.y, expected_y)
    np.testing.assert_allclose(out_set.X, expected_X)
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_ival():
    """One start marker at sample 0 with ival [0, 30] ms at fs=100:
    expects a single trial [0, 1, 2] with label 0."""
    check_signal_target_from_start_and_ival(
        data=[np.arange(10)],
        events=[(0, 1)],
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, 30],
        expected_X=[[[0, 1, 2]]],
        expected_y=[0],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_ival_ignored_marker():
    """A marker code (2) that is not part of `name_to_codes` must not
    produce a trial: only the trial for code 1 is expected."""
    check_signal_target_from_start_and_ival(
        data=[np.arange(10)],
        events=[(0, 1), (5, 2)],
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, 30],
        expected_X=[[[0, 1, 2]]],
        expected_y=[0],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_ival_two_class():
    """Two classes with one start marker each: expects one trial per
    marker, labelled 0 and 1 in the order of `name_to_codes`."""
    check_signal_target_from_start_and_ival(
        data=[np.arange(10)],
        events=[(0, 1), (5, 2)],
        fs=100,
        name_to_codes=[('A', 1), ('B', 2)],
        epoch_ival_ms=[0, 30],
        expected_X=[[[0, 1, 2]], [[5, 6, 7]]],
        expected_y=[0, 1],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_ival_too_early():
    """Marker at sample 0 with an interval starting at -10 ms:
    expects no trials at all (empty X and y)."""
    check_signal_target_from_start_and_ival(
        data=[np.arange(10)],
        events=[(0, 1)],
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[-10, 30],
        expected_X=[],
        expected_y=[],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_ival_too_late():
    """Marker at sample 8 of 10 with ival [0, 30] ms at fs=100:
    expects no trials at all (empty X and y)."""
    check_signal_target_from_start_and_ival(
        data=[np.arange(10)],
        events=[(8, 1)],
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, 30],
        expected_X=[],
        expected_y=[],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_ival_overlapping():
    """Start markers at samples 0 and 1 give overlapping intervals:
    expects both trials, [0, 1, 2] and [1, 2, 3]."""
    check_signal_target_from_start_and_ival(
        data=[np.arange(10)],
        events=[(0, 1), (1, 2)],
        fs=100,
        name_to_codes=[('A', 1), ('B', 2)],
        epoch_ival_ms=[0, 30],
        expected_X=[[[0, 1, 2]], [[1, 2, 3]]],
        expected_y=[0, 1],
    )
|
||||
|
||||
|
||||
def check_signal_target_from_start_and_stop(data, events, fs, name_to_codes,
                                            epoch_ival_ms, name_to_stop_codes,
                                            pad_to_n_samples,
                                            expected_X, expected_y, ):
    """
    Run `_create_signal_target_from_start_and_stop` and compare its output.

    `data` and `events` are converted to arrays, both code mappings to
    `OrderedDict`. Labels are requested as one integer label per trial.
    `y` must equal `expected_y`; `X` must have as many trials as
    `expected_X` and each trial is compared separately (trial by trial,
    not as a single array).
    """
    out_set = _create_signal_target_from_start_and_stop(
        np.array(data),
        np.array(events),
        fs,
        OrderedDict(name_to_codes),
        epoch_ival_ms,
        OrderedDict(name_to_stop_codes),
        pad_to_n_samples,
        one_hot_labels=False,
        one_label_per_trial=True,
    )
    np.testing.assert_array_equal(out_set.y, expected_y)
    assert len(out_set.X) == len(expected_X)
    for trial_out, trial_expected in zip(out_set.X, expected_X):
        np.testing.assert_allclose(trial_out, trial_expected)
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_stop():
    """Start at sample 0, stop at sample 2, ival [0, 20] ms at fs=100:
    expects a single trial [0, 1, 2, 3] with label 0."""
    check_signal_target_from_start_and_stop(
        data=[np.arange(10)],
        events=[(0, 1), (2, -1)],
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, 20],
        name_to_stop_codes=[('A', -1)],
        pad_to_n_samples=None,
        expected_X=[[[0, 1, 2, 3]]],
        expected_y=[0],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_stop_different_ival():
    """Same markers as the basic start/stop test, but with ival
    [0, -10] ms: expects the trial to shrink to the single sample [0]."""
    check_signal_target_from_start_and_stop(
        data=[np.arange(10)],
        events=[(0, 1), (2, -1)],
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, -10],
        name_to_stop_codes=[('A', -1)],
        pad_to_n_samples=None,
        expected_X=[[[0, ]]],
        expected_y=[0],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_stop_ignored_marker():
    """An unrelated marker code (3) between start and stop must not
    change the result: expects the same single trial [0]."""
    check_signal_target_from_start_and_stop(
        data=[np.arange(10)],
        events=[(0, 1), (1, 3), (2, -1)],
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, -10],
        name_to_stop_codes=[('A', -1)],
        pad_to_n_samples=None,
        expected_X=[[[0, ]]],
        expected_y=[0],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_stop_too_early():
    """Start marker at sample 0 with an interval starting at -10 ms:
    expects no trials at all (empty X and y)."""
    check_signal_target_from_start_and_stop(
        data=[np.arange(10)],
        events=[(0, 1), (2, -1)],
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[-10, 20],
        name_to_stop_codes=[('A', -1)],
        pad_to_n_samples=None,
        expected_X=[],
        expected_y=[],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_stop_too_late():
    """A second trial (start 8, stop 9) close to the end of the 10 samples
    is expected to be dropped; only the first trial [0, 1, 2, 3] remains."""
    check_signal_target_from_start_and_stop(
        data=[np.arange(10)],
        events=[(0, 1), (2, -1), (8, 1), (9, -1)],
        fs=100,
        name_to_codes=[('A', 1)],
        epoch_ival_ms=[0, 20],
        name_to_stop_codes=[('A', -1)],
        pad_to_n_samples=None,
        expected_X=[[[0, 1, 2, 3]]],
        expected_y=[0],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_stop_overlapping():
    """Class B starts at the sample where class A stops, so the extracted
    trials overlap: expects [0, 1, 2, 3] for A and [2, 3, 4, 5, 6] for B."""
    check_signal_target_from_start_and_stop(
        data=[np.arange(10)],
        events=[(0, 1), (2, -1), (2, 2), (5, -2)],
        fs=100,
        name_to_codes=[('A', 1), ('B', 2)],
        epoch_ival_ms=[0, 20],
        name_to_stop_codes=[('A', -1), ('B', -2)],
        pad_to_n_samples=None,
        expected_X=[[[0, 1, 2, 3]], [[2, 3, 4, 5, 6]]],
        expected_y=[0, 1],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_stop_stop_missing():
    """Class B has a start marker but no stop marker: expects only the
    complete class A trial [0, 1, 2, 3]."""
    check_signal_target_from_start_and_stop(
        data=[np.arange(10)],
        events=[(0, 1), (2, -1), (2, 2), ],
        fs=100,
        name_to_codes=[('A', 1), ('B', 2)],
        epoch_ival_ms=[0, 20],
        name_to_stop_codes=[('A', -1), ('B', -2)],
        pad_to_n_samples=None,
        expected_X=[[[0, 1, 2, 3]], ],
        expected_y=[0, ],
    )
|
||||
|
||||
|
||||
def test_signal_target_from_start_and_stop_wrong_stop_for_start():
    """The class B start marker (2) is followed by the class A stop
    marker (-1): an AssertionError is expected."""
    # The stop marker does not belong to the started class,
    # so the call must fail with an assertion.
    with pytest.raises(AssertionError):
        check_signal_target_from_start_and_stop(
            data=[np.arange(10)],
            events=[(0, 1), (2, -1), (2, 2), (3, -1)],
            fs=100,
            name_to_codes=[('A', 1), ('B', 2)],
            epoch_ival_ms=[0, 20],
            name_to_stop_codes=[('A', -1), ('B', -2)],
            pad_to_n_samples=None,
            expected_X=[[[0, 1, 2, 3]], ],
            expected_y=[0, ],
        )
|
||||
|
||||
def check_add_breaks(
        events,
        fs,
        break_start_code,
        break_stop_code,
        name_to_start_codes,
        name_to_stop_codes,
        min_break_length_ms,
        max_break_length_ms,
        break_start_offset_ms,
        break_stop_offset_ms,
        expected_events,
):
    """
    Run `add_breaks` and compare the returned events to `expected_events`.

    `events` is converted to an array and both code mappings to
    `OrderedDict` before the call; the remaining arguments are passed on
    unchanged. The comparison is exact.
    """
    events_with_breaks = add_breaks(
        np.array(events),
        fs,
        break_start_code,
        break_stop_code,
        OrderedDict(name_to_start_codes),
        OrderedDict(name_to_stop_codes),
        min_break_length_ms=min_break_length_ms,
        max_break_length_ms=max_break_length_ms,
        break_start_offset_ms=break_start_offset_ms,
        break_stop_offset_ms=break_stop_offset_ms,
    )
    np.testing.assert_array_equal(events_with_breaks, expected_events)
|
||||
|
||||
|
||||
def test_add_breaks_no_break():
    """A single trial with no following start marker: the events are
    expected to come back unchanged."""
    check_add_breaks(
        events=np.array([(0, 1), (2, -1)]),
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=[('A', 1), ],
        name_to_stop_codes=[('A', -1), ],
        min_break_length_ms=None,
        max_break_length_ms=None,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), ],
    )
|
||||
|
||||
|
||||
def test_add_breaks_one_break():
    """Stop marker at sample 2 followed by a start marker at sample 5:
    a break (codes -3 / -4) is expected at samples 3 and 4."""
    # No length limits are set, so the break must be added.
    check_add_breaks(
        events=np.array([(0, 1), (2, -1), (5, 2)]),
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=[('A', 1), ('B', 2), ],
        name_to_stop_codes=[('A', -1), ],
        min_break_length_ms=None,
        max_break_length_ms=None,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), (3, -3), (4, -4), (5, 2), ],
    )
|
||||
|
||||
|
||||
def test_add_breaks_break_within_bound():
    """With a minimum break length of 10 ms at fs=100 the break between
    samples 3 and 4 is still expected to be added."""
    # The break satisfies the lower bound, so it must be added.
    check_add_breaks(
        events=np.array([(0, 1), (2, -1), (5, 2)]),
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=[('A', 1), ('B', 2), ],
        name_to_stop_codes=[('A', -1), ],
        min_break_length_ms=10,
        max_break_length_ms=None,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), (3, -3), (4, -4), (5, 2), ],
    )
|
||||
|
||||
|
||||
def test_add_breaks_too_short():
    """With a minimum break length of 20 ms at fs=100 the break between
    samples 3 and 4 is too short: the events are expected unchanged."""
    # The break violates the lower bound, so it must not be added.
    check_add_breaks(
        events=np.array([(0, 1), (2, -1), (5, 2)]),
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=[('A', 1), ('B', 2), ],
        name_to_stop_codes=[('A', -1), ],
        min_break_length_ms=20,
        max_break_length_ms=None,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), (5, 2), ],
    )
|
||||
|
||||
def test_add_breaks_too_long():
    """With a maximum break length of 10 ms at fs=100 the gap between the
    stop marker at sample 2 and the start marker at sample 6 is too long:
    the events are expected to come back unchanged."""
    # not added, too long!
    check_add_breaks(
        events=np.array([(0, 1), (2, -1), (6, 2)]),
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=[('A', 1), ('B', 2), ],
        name_to_stop_codes=[('A', -1), ],
        min_break_length_ms=None,
        max_break_length_ms=10,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), (6, 2), ])
|
||||
|
||||
|
||||
def test_add_breaks_within_both_bounds():
    """With a minimum of 10 ms and a maximum of 30 ms at fs=100 the break
    between samples 3 and 4 is expected to be added."""
    # The break satisfies both the lower and the upper bound.
    check_add_breaks(
        events=np.array([(0, 1), (2, -1), (5, 2)]),
        fs=100,
        break_start_code=-3,
        break_stop_code=-4,
        name_to_start_codes=[('A', 1), ('B', 2), ],
        name_to_stop_codes=[('A', -1), ],
        min_break_length_ms=10,
        max_break_length_ms=30,
        break_start_offset_ms=None,
        break_stop_offset_ms=None,
        expected_events=[(0, 1), (2, -1), (3, -3), (4, -4), (5, 2), ],
    )
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário