put hyperparameters into one place in code
Esse commit está contido em:
+20
-12
@@ -28,6 +28,15 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
|
||||
ival = [-500, 4000]
|
||||
max_epochs = 1600
|
||||
max_increase_epochs = 160
|
||||
batch_size = 60
|
||||
high_cut_hz = 38
|
||||
factor_new = 1e-3
|
||||
init_block_size = 1000
|
||||
valid_set_fraction = 0.2
|
||||
|
||||
train_filename = 'A{:02d}T.gdf'.format(subject_id)
|
||||
test_filename = 'A{:02d}E.gdf'.format(subject_id)
|
||||
train_filepath = os.path.join(data_folder, train_filename)
|
||||
@@ -50,12 +59,12 @@ def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
|
||||
# lets convert to millvolt for numerical stability of next operations
|
||||
train_cnt = mne_apply(lambda a: a * 1e6, train_cnt)
|
||||
train_cnt = mne_apply(
|
||||
lambda a: bandpass_cnt(a, low_cut_hz, 38, train_cnt.info['sfreq'],
|
||||
lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, train_cnt.info['sfreq'],
|
||||
filt_order=3,
|
||||
axis=1), train_cnt)
|
||||
train_cnt = mne_apply(
|
||||
lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
|
||||
init_block_size=1000,
|
||||
lambda a: exponential_running_standardize(a.T, factor_new=factor_new,
|
||||
init_block_size=init_block_size,
|
||||
eps=1e-4).T,
|
||||
train_cnt)
|
||||
|
||||
@@ -64,24 +73,23 @@ def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
|
||||
assert len(test_cnt.ch_names) == 22
|
||||
test_cnt = mne_apply(lambda a: a * 1e6, test_cnt)
|
||||
test_cnt = mne_apply(
|
||||
lambda a: bandpass_cnt(a, low_cut_hz, 38, test_cnt.info['sfreq'],
|
||||
lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, test_cnt.info['sfreq'],
|
||||
filt_order=3,
|
||||
axis=1), test_cnt)
|
||||
test_cnt = mne_apply(
|
||||
lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
|
||||
init_block_size=1000,
|
||||
lambda a: exponential_running_standardize(a.T, factor_new=factor_new,
|
||||
init_block_size=init_block_size,
|
||||
eps=1e-4).T,
|
||||
test_cnt)
|
||||
|
||||
marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2],),
|
||||
('Foot', [3]), ('Tongue', [4])])
|
||||
ival = [-500, 4000]
|
||||
|
||||
train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
|
||||
test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)
|
||||
|
||||
train_set, valid_set = split_into_two_sets(train_set,
|
||||
first_set_fraction=0.8)
|
||||
train_set, valid_set = split_into_two_sets(
|
||||
train_set, first_set_fraction=1-valid_set_fraction)
|
||||
|
||||
set_random_seeds(seed=20190706, cuda=cuda)
|
||||
|
||||
@@ -100,10 +108,10 @@ def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
|
||||
|
||||
optimizer = optim.Adam(model.parameters())
|
||||
|
||||
iterator = BalancedBatchSizeIterator(batch_size=60)
|
||||
iterator = BalancedBatchSizeIterator(batch_size=batch_size)
|
||||
|
||||
stop_criterion = Or([MaxEpochs(1600),
|
||||
NoDecrease('valid_misclass', 160)])
|
||||
stop_criterion = Or([MaxEpochs(max_epochs),
|
||||
NoDecrease('valid_misclass', max_increase_epochs)])
|
||||
|
||||
monitors = [LossMonitor(), MisclassMonitor(), RuntimeMonitor()]
|
||||
|
||||
|
||||
@@ -30,6 +30,16 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
|
||||
ival = [-500, 4000]
|
||||
input_time_length = 1000
|
||||
max_epochs = 800
|
||||
max_increase_epochs = 80
|
||||
batch_size = 60
|
||||
high_cut_hz = 38
|
||||
factor_new = 1e-3
|
||||
init_block_size = 1000
|
||||
valid_set_fraction = 0.2
|
||||
|
||||
train_filename = 'A{:02d}T.gdf'.format(subject_id)
|
||||
test_filename = 'A{:02d}E.gdf'.format(subject_id)
|
||||
train_filepath = os.path.join(data_folder, train_filename)
|
||||
@@ -52,12 +62,12 @@ def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
|
||||
# lets convert to millvolt for numerical stability of next operations
|
||||
train_cnt = mne_apply(lambda a: a * 1e6, train_cnt)
|
||||
train_cnt = mne_apply(
|
||||
lambda a: bandpass_cnt(a, low_cut_hz, 38, train_cnt.info['sfreq'],
|
||||
lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, train_cnt.info['sfreq'],
|
||||
filt_order=3,
|
||||
axis=1), train_cnt)
|
||||
train_cnt = mne_apply(
|
||||
lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
|
||||
init_block_size=1000,
|
||||
lambda a: exponential_running_standardize(a.T, factor_new=factor_new,
|
||||
init_block_size=init_block_size,
|
||||
eps=1e-4).T,
|
||||
train_cnt)
|
||||
|
||||
@@ -66,30 +76,28 @@ def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
|
||||
assert len(test_cnt.ch_names) == 22
|
||||
test_cnt = mne_apply(lambda a: a * 1e6, test_cnt)
|
||||
test_cnt = mne_apply(
|
||||
lambda a: bandpass_cnt(a, low_cut_hz, 38, test_cnt.info['sfreq'],
|
||||
lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, test_cnt.info['sfreq'],
|
||||
filt_order=3,
|
||||
axis=1), test_cnt)
|
||||
test_cnt = mne_apply(
|
||||
lambda a: exponential_running_standardize(a.T, factor_new=1e-3,
|
||||
init_block_size=1000,
|
||||
lambda a: exponential_running_standardize(a.T, factor_new=factor_new,
|
||||
init_block_size=init_block_size,
|
||||
eps=1e-4).T,
|
||||
test_cnt)
|
||||
|
||||
marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2],),
|
||||
('Foot', [3]), ('Tongue', [4])])
|
||||
ival = [-500, 4000]
|
||||
|
||||
train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
|
||||
test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)
|
||||
|
||||
train_set, valid_set = split_into_two_sets(train_set,
|
||||
first_set_fraction=0.8)
|
||||
train_set, valid_set = split_into_two_sets(
|
||||
train_set, first_set_fraction=1-valid_set_fraction)
|
||||
|
||||
set_random_seeds(seed=20190706, cuda=cuda)
|
||||
|
||||
n_classes = 4
|
||||
n_chans = int(train_set.X.shape[1])
|
||||
input_time_length=1000
|
||||
if model == 'shallow':
|
||||
model = ShallowFBCSPNet(n_chans, n_classes, input_time_length=input_time_length,
|
||||
final_conv_length=30).create_network()
|
||||
@@ -112,12 +120,12 @@ def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
|
||||
|
||||
optimizer = optim.Adam(model.parameters())
|
||||
|
||||
iterator = CropsFromTrialsIterator(batch_size=60,
|
||||
iterator = CropsFromTrialsIterator(batch_size=batch_size,
|
||||
input_time_length=input_time_length,
|
||||
n_preds_per_input=n_preds_per_input)
|
||||
|
||||
stop_criterion = Or([MaxEpochs(800),
|
||||
NoDecrease('valid_misclass', 80)])
|
||||
stop_criterion = Or([MaxEpochs(max_epochs),
|
||||
NoDecrease('valid_misclass', max_increase_epochs)])
|
||||
|
||||
monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
|
||||
CroppedTrialMisclassMonitor(
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário