Fixes
Esse commit está contido em:
@@ -92,10 +92,10 @@ class IntermediateOutputWrapper(torch.nn.Module):
|
||||
network model
|
||||
|
||||
Examples
|
||||
--------
|
||||
model = Deep4Net()
|
||||
select_modules = ['conv_spat','conv_2','conv_3','conv_4'] # Specify intermediate outputs
|
||||
model_pert = IntermediateOutputWrapper(select_modules,model) # Wrap model
|
||||
--------
|
||||
>>> model = Deep4Net()
|
||||
>>> select_modules = ['conv_spat','conv_2','conv_3','conv_4'] # Specify intermediate outputs
|
||||
>>> model_pert = IntermediateOutputWrapper(select_modules,model) # Wrap model
|
||||
"""
|
||||
def __init__(self, to_select, model):
|
||||
if not len(list(model.children()))==len(list(model.named_children())):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from numpy.random import RandomState
|
||||
|
||||
from braindecode.datautil.iterators import get_balanced_batches
|
||||
from braindecode.util import wrap_reshape_apply_fn, corr
|
||||
@@ -46,7 +45,7 @@ def phase_perturbation(amps,phases,rng=np.random.RandomState()):
|
||||
return amps,phases_pert,pert_vals
|
||||
|
||||
def amp_perturbation_additive(amps,phases,rng=np.random.RandomState()):
|
||||
"""Takes amps and phases of BxCxF with B input, C channels, F frequencies
|
||||
"""Takes amplitudes and phases of BxCxF with B input, C channels, F frequencies
|
||||
Adds additive noise N(0,0.02) to amplitudes
|
||||
|
||||
Parameters
|
||||
@@ -74,7 +73,7 @@ def amp_perturbation_additive(amps,phases,rng=np.random.RandomState()):
|
||||
return amps_pert,phases,amp_noise
|
||||
|
||||
def amp_perturbation_multiplicative(amps,phases,rng=np.random.RandomState()):
|
||||
"""Takes amps and phases of BxCxF with B input, C channels, F frequencies
|
||||
"""Takes amplitude and phases of BxCxF with B input, C channels, F frequencies
|
||||
Adds multiplicative noise N(1,0.02) to amplitudes
|
||||
|
||||
Parameters
|
||||
@@ -156,7 +155,7 @@ def mean_diff_feature_maps(x,y):
|
||||
def spectral_perturbation_correlation(pert_fn, diff_fn, pred_fn, n_layers, inputs, n_iterations,
|
||||
batch_size=30,
|
||||
seed=((2017, 7, 10))):
|
||||
"""Calculates phase perturbation correlation for layers in network
|
||||
"""Calculates perturbation correlations for layers in network by perturbing either amplitudes or phases
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário