Esse commit está contido em:
Kay Gregor Hartmann
2018-06-08 06:46:01 +12:00
commit 83949a4ac7
2 arquivos alterados com 7 adições e 8 exclusões
+4 -4
Ver Arquivo
@@ -92,10 +92,10 @@ class IntermediateOutputWrapper(torch.nn.Module):
network model
Examples
--------
model = Deep4Net()
select_modules = ['conv_spat','conv_2','conv_3','conv_4'] # Specify intermediate outputs
model_pert = IntermediateOutputWrapper(select_modules,model) # Wrap model
--------
>>> model = Deep4Net()
>>> select_modules = ['conv_spat','conv_2','conv_3','conv_4'] # Specify intermediate outputs
>>> model_pert = IntermediateOutputWrapper(select_modules,model) # Wrap model
"""
def __init__(self, to_select, model):
if not len(list(model.children()))==len(list(model.named_children())):
+3 -4
Ver Arquivo
@@ -1,7 +1,6 @@
import logging
import numpy as np
from numpy.random import RandomState
from braindecode.datautil.iterators import get_balanced_batches
from braindecode.util import wrap_reshape_apply_fn, corr
@@ -46,7 +45,7 @@ def phase_perturbation(amps,phases,rng=np.random.RandomState()):
return amps,phases_pert,pert_vals
def amp_perturbation_additive(amps,phases,rng=np.random.RandomState()):
"""Takes amps and phases of BxCxF with B input, C channels, F frequencies
"""Takes amplitudes and phases of BxCxF with B input, C channels, F frequencies
Adds additive noise N(0,0.02) to amplitudes
Parameters
@@ -74,7 +73,7 @@ def amp_perturbation_additive(amps,phases,rng=np.random.RandomState()):
return amps_pert,phases,amp_noise
def amp_perturbation_multiplicative(amps,phases,rng=np.random.RandomState()):
"""Takes amps and phases of BxCxF with B input, C channels, F frequencies
"""Takes amplitude and phases of BxCxF with B input, C channels, F frequencies
Adds multiplicative noise N(1,0.02) to amplitudes
Parameters
@@ -156,7 +155,7 @@ def mean_diff_feature_maps(x,y):
def spectral_perturbation_correlation(pert_fn, diff_fn, pred_fn, n_layers, inputs, n_iterations,
batch_size=30,
seed=((2017, 7, 10))):
"""Calculates phase perturbation correlation for layers in network
"""Calculates perturbation correlations for layers in network by perturbing either amplitudes or phases
Parameters
----------