Merge branch 'master' of github.com:robintibor/braindecode
Esse commit está contido em:
@@ -137,7 +137,7 @@ class BBCIDataset(object):
|
||||
def _add_markers(self, cnt):
|
||||
with h5py.File(self.filename, 'r') as h5file:
|
||||
event_times_in_ms = h5file['mrk']['time'][:].squeeze()
|
||||
event_classes = h5file['mrk']['event']['desc'][:].squeeze()
|
||||
event_classes = h5file['mrk']['event']['desc'][:].squeeze().astype(np.int64)
|
||||
|
||||
# Check whether class names known and correct order
|
||||
class_name_set = h5file['nfo']['className'][:].squeeze()
|
||||
|
||||
@@ -8,7 +8,7 @@ from braindecode.torch_ext.util import np_to_var
|
||||
class Expression(torch.nn.Module):
|
||||
"""
|
||||
Compute given expression on forward pass.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expression_fn: function
|
||||
@@ -37,7 +37,7 @@ class Expression(torch.nn.Module):
|
||||
class AvgPool2dWithConv(torch.nn.Module):
|
||||
"""
|
||||
Compute average pooling using a convolution, to have the dilation parameter.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
kernel_size: (int,int)
|
||||
@@ -81,3 +81,42 @@ class AvgPool2dWithConv(torch.nn.Module):
|
||||
dilation=self.dilation,
|
||||
groups=in_channels,)
|
||||
return pooled
|
||||
|
||||
|
||||
class IntermediateOutputWrapper(torch.nn.Module):
|
||||
"""Wraps network model such that outputs of intermediate layers can be returned.
|
||||
forward() returns list of intermediate activations in a network during forward pass.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
to_select : list
|
||||
list of module names for which activation should be returned
|
||||
model : model object
|
||||
network model
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> model = Deep4Net()
|
||||
>>> select_modules = ['conv_spat','conv_2','conv_3','conv_4'] # Specify intermediate outputs
|
||||
>>> model_pert = IntermediateOutputWrapper(select_modules,model) # Wrap model
|
||||
"""
|
||||
def __init__(self, to_select, model):
|
||||
if not len(list(model.children()))==len(list(model.named_children())):
|
||||
raise Exception('All modules in model need to have names!')
|
||||
|
||||
super(IntermediateOutputWrapper, self).__init__()
|
||||
|
||||
modules_list = model.named_children()
|
||||
for key, module in modules_list:
|
||||
self.add_module(key, module)
|
||||
self._modules[key].load_state_dict(module.state_dict())
|
||||
self._to_select = to_select
|
||||
|
||||
def forward(self,x):
|
||||
# Call modules individually and append activation to output if module is in to_select
|
||||
o = []
|
||||
for name, module in self._modules.items():
|
||||
x = module(x)
|
||||
if name in self._to_select:
|
||||
o.append(x)
|
||||
return o
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
def calc_Hout(Hin, kernel, stride, dilation):
|
||||
Hout = np.floor((Hin-dilation*(kernel-1)-1)/stride+1)
|
||||
return Hout
|
||||
|
||||
def calc_Hin(Hout, kernel, stride, dilation):
|
||||
Hin = np.ceil((Hout-1)*stride+1+dilation*(kernel-1))
|
||||
return Hin
|
||||
|
||||
def calc_receptive_field_size(model,layer_ind,start_receptive_field=np.ones((2))):
|
||||
"""Calculate receptive field size for unit in specific layer of the network
|
||||
Only tested for 2d convolutions/poolings. Dimshuffle operations may lead to a wrong result
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : model object
|
||||
Network model
|
||||
layer_ind: int
|
||||
Index of the layer of interest in `model.children()`
|
||||
start_receptive_field: int, optional
|
||||
How many units are looked at in specified layer (default: [1,1])
|
||||
|
||||
Returns
|
||||
-------
|
||||
receptive_field_size : numpy array
|
||||
[HxW] in the input layer
|
||||
"""
|
||||
receptive_field = start_receptive_field
|
||||
children = list(model.children())[:layer_ind][::-1]
|
||||
for child in children:
|
||||
if isinstance(child,torch.nn.Sequential):
|
||||
receptive_field = calc_receptive_field(child,-1)
|
||||
elif isinstance(child,torch.nn.Conv2d) or isinstance(child,torch.nn.MaxPool2d) or isinstance(child,torch.nn.AvgPool2d):
|
||||
receptive_field = calc_Hin(receptive_field,np.asarray(child.kernel_size),
|
||||
np.asarray(child.stride), np.asarray(child.dilation))
|
||||
|
||||
receptive_field_size = receptive_field.astype(np.int)
|
||||
return receptive_field_size
|
||||
|
||||
def get_max_act_index(activations,unique_per_input=True,n_units=None):
|
||||
"""Retrieve index of maximum activation in a feature map
|
||||
|
||||
Parameters
|
||||
----------
|
||||
activations : numpy array
|
||||
[Nx1xHxW] can only take 1 filter
|
||||
unique_per_input : bool, optional
|
||||
Specifies if only 1 index (maximum) for each input is returned (default: True)
|
||||
n_units : int, optional
|
||||
How many indeces are returned in total. If None all (default: None)
|
||||
|
||||
Returns
|
||||
-------
|
||||
units : numpy array
|
||||
[Nx4] with the columns `Input_i`,`Filter(0)`,`H_i`,`W_i` indeces of the units
|
||||
units_activation : numpy array
|
||||
Activation of the units
|
||||
"""
|
||||
assert len(activations.shape)==4,"Has to be 4d array"
|
||||
assert activations.shape[1] == 1,"Can only handle individual filter activations"
|
||||
|
||||
activations_sorted = activations.argsort(axis=None)[::-1]
|
||||
activations_sorted_ind = np.unravel_index(activations_sorted,activations.shape)
|
||||
unique_ind = np.arange(len(activations_sorted_ind[0]))
|
||||
if unique_per_input:
|
||||
a,unique_ind = np.unique(activations_sorted_ind[0],return_index=True)
|
||||
unique_ind = sorted(unique_ind)
|
||||
|
||||
if n_units==None:
|
||||
n_units = len(unique_ind)
|
||||
activations_sorted_ind = np.asarray(activations_sorted_ind).T
|
||||
units = activations_sorted_ind[unique_ind[:n_units],:].astype(np.int)
|
||||
units_activation = activations.flat[activations_sorted[unique_ind[:n_units]]]
|
||||
|
||||
return units,units_activation
|
||||
|
||||
def calc_receptive_field_for_units_2d(units,receptive_field_size):
|
||||
recptive_field_tmp = receptive_field_size[np.newaxis,np.newaxis,:,:]
|
||||
start_inds = units[:,0,]
|
||||
stop_inds = start_inds+receptive_field_tmp
|
||||
|
||||
return start_inds,stop_inds
|
||||
|
||||
def get_input_windows_from_units_2d(inputs,units,receptive_field_size):
|
||||
"""Cut input windows in receptive field of specified units from inputs
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs : numpy array
|
||||
[NxCxHxW] Inputs x Channels x Time x W
|
||||
units : numpy array
|
||||
[Mx4] unit indeces specifying Input and time indeces.
|
||||
Second dimension consists of Input x Filter(1) x H x W indeces. Can only handle 1 filter.
|
||||
receptive_field_size : int
|
||||
Size of receptive field of units on input
|
||||
|
||||
Returns
|
||||
-------
|
||||
windows : numpy array
|
||||
Cut input windows
|
||||
"""
|
||||
windows = np.zeros((units.shape[0],inputs.shape[1],receptive_field_size[0],receptive_field_size[1]))
|
||||
for i,unit in enumerate(units):
|
||||
windows[i] = inputs[unit[0],:,
|
||||
unit[2]:unit[2]+receptive_field_size[0],
|
||||
unit[3]:unit[3]+receptive_field_size[1]]
|
||||
return windows
|
||||
|
||||
|
||||
def most_activating_input_windows(inputs,activations,receptive_field_size,top_percentage=0.05):
|
||||
"""Get the input windows that evoked highest activation in a feature map
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs : numpy array
|
||||
[NxCxTx1] Inputs used for the calculation of activations
|
||||
activations : numpy array
|
||||
[NxFxHx1] Activations
|
||||
receptive_field_size : numpy array
|
||||
[Wx1] Receptive field of a unit of the layer on the input
|
||||
top_percentage : float, optional
|
||||
How many of the most activating input windows should be returned (default: top 5%)
|
||||
|
||||
Returns
|
||||
-------
|
||||
input_windows : numpy array
|
||||
[FxUxCxWx1] Returns `U` (resulting from top_percentage) input windows for each filter
|
||||
"""
|
||||
n_units = int(inputs.shape[0]*top_percentage)
|
||||
input_windows = np.zeros((activations.shape[1],n_units,inputs.shape[1],
|
||||
receptive_field_size[0],receptive_field_size[1]))
|
||||
for filt in range(activations.shape[1]):
|
||||
units,w = get_max_act_index(activations[:,[filt]],n_units=n_units)
|
||||
w = w.squeeze()
|
||||
windows_tmp = get_input_windows_from_units_2d(inputs,units,receptive_field_size)
|
||||
input_windows[filt] = windows_tmp
|
||||
return input_windows
|
||||
|
||||
def activation_reverse_correlation(inputs,activations,receptive_field_size):
|
||||
"""Get reverse correlations for filters
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inputs : numpy array
|
||||
[NxCxTx1] Inputs used for the calculation of activations
|
||||
activations : numpy array
|
||||
[NxFxHx1] Activations
|
||||
receptive_field_size : numpy array
|
||||
[Wx1] Receptive field of a unit of the layer on the input
|
||||
|
||||
Returns
|
||||
-------
|
||||
reverse_corr : numpy array
|
||||
[FxCxWx1] Reverse correlations over all input windows for each filter
|
||||
"""
|
||||
reverse_corr = np.zeros((activations.shape[1],inputs.shape[1],
|
||||
receptive_field_size[0],receptive_field_size[1]))
|
||||
for filt in range(activations.shape[1]):
|
||||
act_tmp = activations[:,filt]
|
||||
divisor = 0
|
||||
for start_ind in range(activations.shape[2]):
|
||||
end_ind = start_ind+receptive_field_size[0]
|
||||
w = act_tmp[:,[start_ind],:,np.newaxis]
|
||||
windows_tmp = inputs[:,:,start_ind:end_ind]
|
||||
reverse_corr[filt] += np.sum(w*windows_tmp,axis=0)
|
||||
divisor += np.sum(np.abs(w))
|
||||
reverse_corr[filt] /= divisor
|
||||
return reverse_corr
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from numpy.random import RandomState
|
||||
|
||||
from braindecode.datautil.iterators import get_balanced_batches
|
||||
from braindecode.util import wrap_reshape_apply_fn, corr
|
||||
@@ -9,28 +8,241 @@ from braindecode.util import wrap_reshape_apply_fn, corr
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def gaussian_perturbation(amps, rng):
|
||||
"""
|
||||
Create gaussian noise tensor with same shape as amplitudes.
|
||||
def phase_perturbation(amps,phases,rng=np.random.RandomState()):
|
||||
"""Takes amps and phases of BxCxF with B input, C channels, F frequencies
|
||||
Shifts spectral phases randomly U(-pi,pi) for input and frequencies, but same for all channels
|
||||
|
||||
Parameters
|
||||
----------
|
||||
amps: ndarray
|
||||
Amplitudes.
|
||||
rng: RandomState
|
||||
Random generator.
|
||||
amps : numpy array
|
||||
Spectral amplitude (not used)
|
||||
phases : numpy array
|
||||
Spectral phases
|
||||
rng : object
|
||||
Random Seed
|
||||
|
||||
Returns
|
||||
-------
|
||||
perturbation: ndarray
|
||||
Perturbations to add to the amplitudes.
|
||||
amps : numpy array
|
||||
Input amps (not modified)
|
||||
phases_pert : numpy array
|
||||
Shifted phases
|
||||
pert_vals : numpy array
|
||||
Absolute phase shifts
|
||||
"""
|
||||
perturbation = rng.randn(*amps.shape).astype(np.float32)
|
||||
return perturbation
|
||||
noise_shape = list(phases.shape)
|
||||
noise_shape[1] = 1 # Do not sample noise for channels individually
|
||||
|
||||
# Sample phase perturbation noise
|
||||
phase_noise = rng.uniform(-np.pi,np.pi,noise_shape).astype(np.float32)
|
||||
phase_noise_rep = phase_noise.repeat(phases.shape[1],axis=1)
|
||||
# Apply noise to inputs
|
||||
phases_pert = phases+phase_noise_rep
|
||||
phases_pert[phases_pert<-np.pi] += 2*np.pi
|
||||
phases_pert[phases_pert>np.pi] -= 2*np.pi
|
||||
|
||||
pert_vals = np.abs(phase_noise)
|
||||
return amps,phases_pert,pert_vals
|
||||
|
||||
def amp_perturbation_additive(amps,phases,rng=np.random.RandomState()):
|
||||
"""Takes amplitudes and phases of BxCxF with B input, C channels, F frequencies
|
||||
Adds additive noise N(0,0.02) to amplitudes
|
||||
|
||||
Parameters
|
||||
----------
|
||||
amps : numpy array
|
||||
Spectral amplitude
|
||||
phases : numpy array
|
||||
Spectral phases (not used)
|
||||
rng : object
|
||||
Random Seed
|
||||
|
||||
Returns
|
||||
-------
|
||||
amps_pert : numpy array
|
||||
Scaled amplitudes
|
||||
phases_pert : numpy array
|
||||
Input phases (not modified)
|
||||
pert_vals : numpy array
|
||||
Amplitude noise
|
||||
"""
|
||||
amp_noise = rng.normal(0,1,amps.shape).astype(np.float32)
|
||||
amps_pert = amps+amp_noise
|
||||
amps_pert[amps_pert<0] = 0
|
||||
amp_noise = amps_pert-amps
|
||||
return amps_pert,phases,amp_noise
|
||||
|
||||
def amp_perturbation_multiplicative(amps,phases,rng=np.random.RandomState()):
|
||||
"""Takes amplitude and phases of BxCxF with B input, C channels, F frequencies
|
||||
Adds multiplicative noise N(1,0.02) to amplitudes
|
||||
|
||||
Parameters
|
||||
----------
|
||||
amps : numpy array
|
||||
Spectral amplitude
|
||||
phases : numpy array
|
||||
Spectral phases (not used)
|
||||
rng : object
|
||||
Random Seed
|
||||
|
||||
Returns
|
||||
-------
|
||||
amps_pert : numpy array
|
||||
Scaled amplitudes
|
||||
phases_pert : numpy array
|
||||
Input phases (not modified)
|
||||
pert_vals : numpy array
|
||||
Amplitude scaling factor
|
||||
"""
|
||||
amp_noise = rng.normal(1,0.02,amps.shape).astype(np.float32)
|
||||
amps_pert = amps*amp_noise
|
||||
amps_pert[amps_pert<0] = 0
|
||||
return amps_pert,phases,amp_noise
|
||||
|
||||
def correlate_feature_maps(x,y):
|
||||
"""Takes two activation matrices of the form Bx[F]xT where B is batch size, F number of filters (optional) and T time points
|
||||
Returns correlations of the corresponding activations over T
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : numpy array
|
||||
Activations Bx[F]xT
|
||||
y : numpy array
|
||||
Activations Bx[F]xT
|
||||
|
||||
Returns
|
||||
correlations : numpy array
|
||||
Correlations of `x` and `y` Bx[F]
|
||||
"""
|
||||
shape_x = x.shape
|
||||
shape_y = y.shape
|
||||
assert np.array_equal(shape_x,shape_y)
|
||||
assert len(shape_x)<4
|
||||
x = x.reshape((-1,shape_x[-1]))
|
||||
y = y.reshape((-1,shape_y[-1]))
|
||||
|
||||
x = (x-x.mean(axis=1,keepdims=True))/x.std(axis=1,keepdims=True)
|
||||
y = (y-y.mean(axis=1,keepdims=True))/y.std(axis=1,keepdims=True)
|
||||
|
||||
tmp_corr = x*y
|
||||
corr_ = tmp_corr.sum(axis=1)
|
||||
#corr_ = np.zeros((x.shape[0]))
|
||||
#for i in range(x.shape[0]):
|
||||
# # Correlation of standardized variables
|
||||
# corr_[i] = np.correlate((x[i]-x[i].mean())/x[i].std(),(y[i]-y[i].mean())/y[i].std())
|
||||
|
||||
correlations = corr_.reshape(*shape_x[:-1])
|
||||
return correlations
|
||||
|
||||
def mean_diff_feature_maps(x,y):
|
||||
"""Takes two activation matrices of the form BxFxT where B is batch size, F number of filters and T time points
|
||||
Returns mean difference between feature map activations
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : numpy array
|
||||
Activations Bx[F]xT
|
||||
y : numpy array
|
||||
Activations Bx[F]xT
|
||||
|
||||
Returns
|
||||
mean_diff : numpy array
|
||||
Mean difference between `x` and `y` Bx[F]
|
||||
"""
|
||||
mean_diff = np.mean(x-y,axis=2)
|
||||
return mean_diff
|
||||
|
||||
def spectral_perturbation_correlation(pert_fn, diff_fn, pred_fn, n_layers, inputs, n_iterations,
|
||||
batch_size=30,
|
||||
seed=((2017, 7, 10))):
|
||||
"""Calculates perturbation correlations for layers in network by perturbing either amplitudes or phases
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pert_fn : function
|
||||
Function that perturbs spectral phase and amplitudes of inputs
|
||||
diff_fn : function
|
||||
Function that calculates difference between original and perturbed activations
|
||||
pred_fn : function
|
||||
Function that returns a list of activations.
|
||||
Each entry in the list corresponds to the output of 1 layer in a network
|
||||
n_layers : int
|
||||
Number of layers pred_fn returns activations for.
|
||||
inputs : numpy array
|
||||
Original inputs that are used for perturbation [B,X,T,1]
|
||||
Phase perturbations are sampled for each input individually, but applied to all X of that input
|
||||
n_iterations : int
|
||||
Number of iterations of correlation computation. The higher the better
|
||||
batch_size : int
|
||||
Number of inputs that are used for one forward pass. (Concatenated for all inputs)
|
||||
|
||||
Returns
|
||||
-------
|
||||
pert_corrs : numpy array
|
||||
List of length n_layers containing average perturbation correlations over iterations
|
||||
L x CxFrxFi (Channels,Frequencies,Filters)
|
||||
"""
|
||||
rng = np.random.RandomState(seed)
|
||||
|
||||
# Get batch indeces
|
||||
batch_inds = get_balanced_batches(
|
||||
n_trials=len(inputs), rng=rng, shuffle=False, batch_size=batch_size)
|
||||
# Calculate layer activations and reshape
|
||||
log.info("Compute original predictions...")
|
||||
orig_preds = [pred_fn(inputs[inds])
|
||||
for inds in batch_inds]
|
||||
use_shape = []
|
||||
for l in range(n_layers):
|
||||
tmp = list(orig_preds[0][l].shape)
|
||||
tmp.extend([1]*(4-len(tmp)))
|
||||
tmp[0] = len(inputs)
|
||||
use_shape.append(tmp)
|
||||
orig_preds_layers = [np.concatenate([orig_preds[o][l] for o in range(len(orig_preds))]).reshape(use_shape[l])
|
||||
for l in range(n_layers)]
|
||||
|
||||
# Compute FFT of inputs
|
||||
fft_input = np.fft.rfft(inputs, n=inputs.shape[2], axis=2)
|
||||
amps = np.abs(fft_input)
|
||||
phases = np.angle(fft_input)
|
||||
|
||||
pert_corrs = [0]*n_layers
|
||||
for i in range(n_iterations):
|
||||
log.info("Iteration {:d}...".format(i))
|
||||
log.info("Sample perturbation...")
|
||||
amps_pert,phases_pert,pert_vals = pert_fn(amps,phases,rng=rng)
|
||||
|
||||
# Compute perturbed inputs
|
||||
log.info("Compute perturbed complex inputs...")
|
||||
fft_pert = amps_pert*np.exp(1j*phases_pert)
|
||||
log.info("Compute perturbed real inputs...")
|
||||
inputs_pert = np.fft.irfft(fft_pert, n=inputs.shape[2], axis=2).astype(np.float32)
|
||||
|
||||
# Calculate layer activations for perturbed inputs
|
||||
log.info("Compute new predictions...")
|
||||
new_preds = [pred_fn(inputs_pert[inds])
|
||||
for inds in batch_inds]
|
||||
new_preds_layers = [np.concatenate([new_preds[o][l] for o in range(len(new_preds))]).reshape(use_shape[l])
|
||||
for l in range(n_layers)]
|
||||
|
||||
for l in range(n_layers):
|
||||
log.info("Layer {:d}...".format(l))
|
||||
# Calculate difference of original and perturbed feature map activations
|
||||
log.info("Compute activation difference...")
|
||||
preds_diff = diff_fn(new_preds_layers[l][:,:,:,0],orig_preds_layers[l][:,:,:,0])
|
||||
|
||||
# Calculate feature map differences with perturbations
|
||||
log.info("Compute correlation...")
|
||||
pert_corrs_tmp = wrap_reshape_apply_fn(corr,
|
||||
pert_vals[:,:,:,0],preds_diff,
|
||||
axis_a=(0,), axis_b=(0))
|
||||
pert_corrs[l] += pert_corrs_tmp
|
||||
|
||||
pert_corrs = [pert_corrs[l]/n_iterations for l in range(n_layers)] #mean over iterations
|
||||
return pert_corrs
|
||||
|
||||
|
||||
def compute_amplitude_prediction_correlations(pred_fn, examples, n_iterations,
|
||||
perturb_fn=gaussian_perturbation,
|
||||
perturb_fn=amp_perturbation_additive,
|
||||
batch_size=30,
|
||||
seed=((2017, 7, 10))):
|
||||
"""
|
||||
@@ -71,46 +283,9 @@ def compute_amplitude_prediction_correlations(pred_fn, examples, n_iterations,
|
||||
visualization.
|
||||
arXiv preprint arXiv:1703.05051.
|
||||
"""
|
||||
inds_per_batch = get_balanced_batches(
|
||||
n_trials=len(examples), rng=None, shuffle=False, batch_size=batch_size)
|
||||
log.info("Compute original predictions...")
|
||||
orig_preds = [pred_fn(examples[example_inds])
|
||||
for example_inds in inds_per_batch]
|
||||
orig_preds_arr = np.concatenate(orig_preds)
|
||||
rng = RandomState(seed)
|
||||
fft_input = np.fft.rfft(examples, axis=2)
|
||||
amps = np.abs(fft_input)
|
||||
phases = np.angle(fft_input)
|
||||
pred_fn_new = lambda x: [pred_fn(x)]
|
||||
pred_corrs = spectral_perturbation_correlation(perturb_fn, mean_diff_feature_maps,
|
||||
pred_fn_new, 1, examples, n_iterations,
|
||||
batch_size=batch_size, seed=seed)
|
||||
|
||||
amp_pred_corrs = []
|
||||
for i_iteration in range(n_iterations):
|
||||
log.info("Iteration {:d}...".format(i_iteration))
|
||||
log.info("Sample perturbation...")
|
||||
perturbation = perturb_fn(amps, rng)
|
||||
log.info("Compute new amplitudes...")
|
||||
# do not allow perturbation to make amplitudes go below
|
||||
# zero
|
||||
perturbation = np.maximum(-amps, perturbation)
|
||||
new_amps = amps + perturbation
|
||||
log.info("Compute new complex inputs...")
|
||||
new_complex = _amplitude_phase_to_complex(new_amps, phases)
|
||||
log.info("Compute new real inputs...")
|
||||
new_in = np.fft.irfft(new_complex, axis=2).astype(np.float32)
|
||||
log.info("Compute new predictions...")
|
||||
new_preds = [pred_fn(new_in[example_inds])
|
||||
for example_inds in inds_per_batch]
|
||||
|
||||
new_preds_arr = np.concatenate(new_preds)
|
||||
|
||||
diff_preds = new_preds_arr - orig_preds_arr
|
||||
|
||||
log.info("Compute correlation...")
|
||||
amp_pred_corr = wrap_reshape_apply_fn(corr, perturbation[:, :, :, 0],
|
||||
diff_preds,
|
||||
axis_a=(0,), axis_b=(0))
|
||||
amp_pred_corrs.append(amp_pred_corr)
|
||||
return amp_pred_corrs
|
||||
|
||||
|
||||
def _amplitude_phase_to_complex(amplitude, phase):
|
||||
return amplitude * np.cos(phase) + amplitude * np.sin(phase) * 1j
|
||||
return pred_corrs[0]
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
from scipy.optimize import leastsq
|
||||
import numpy as np
|
||||
|
||||
def err_fn_sin(p,x,y):
|
||||
return (y-fit_fn_sin(x,*p)).flat
|
||||
|
||||
def err_fn_lin(p,x,y):
|
||||
return (y-fit_fn_lin(x,*p)).flat
|
||||
|
||||
def fit_fn_lin(x,*kwargs):
|
||||
return kwargs[0]+kwargs[1]*x
|
||||
|
||||
def fit_fn_sin(x,*kwargs):
|
||||
freqs = kwargs[0]
|
||||
amps = kwargs[1]
|
||||
phases = kwargs[2]
|
||||
offset = kwargs[3]
|
||||
sig = np.zeros((len(x)))+offset
|
||||
sig += amps*np.cos(x*freqs+phases)
|
||||
return sig
|
||||
|
||||
def signal_fit(signals,fs):
|
||||
"""Fits sinusoid and linear function to signals
|
||||
see sinfit.fit_fn_sin and sinfit.fit_fn_lin
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signals : numpy array
|
||||
[FxCxTx1] Filters x Channels x Time x 1
|
||||
fs : float
|
||||
Sampling frequency
|
||||
|
||||
Returns
|
||||
-------
|
||||
params_sin : numpy array
|
||||
[FxCx4] Parameters of sinusoid fit
|
||||
Parameters are: Frequency,Amplitude,Phase,DCOffset
|
||||
params_lin : numpy array
|
||||
[FxCx2] Parameters of sinusoid fit
|
||||
Parameters are: Frequency,Amplitude,Phase,DCOffset
|
||||
err_sin : numpy array
|
||||
[FxCx1] MSE for sinusoid fit
|
||||
err_lin : numpy array
|
||||
[FxCx1] MSE for linear fit
|
||||
"""
|
||||
params_sin = []
|
||||
params_lin = []
|
||||
|
||||
err_sin = []
|
||||
err_lin = []
|
||||
|
||||
freqs = np.fft.rfftfreq(signals.shape[2], d=1.0/fs)[1:]
|
||||
x = np.linspace(0,signals.shape[2]/fs,signals.shape[2])*2*np.pi
|
||||
for filt in range(signals.shape[0]):
|
||||
params_sin_tmp = []
|
||||
params_lin_tmp = []
|
||||
|
||||
err_sin_tmp = []
|
||||
err_lin_tmp = []
|
||||
for ch in range(signals.shape[1]):
|
||||
X_tmp = signals[filt,ch].squeeze()
|
||||
fft_X = np.fft.rfft(X_tmp,axis=0)
|
||||
amps_mean = np.abs(fft_X)[1:]
|
||||
phases_mean = np.angle(fft_X)[1:]
|
||||
offset = X_tmp.mean()
|
||||
|
||||
sort = np.argsort(amps_mean)[::-1][0]
|
||||
p0 = [freqs[sort],amps_mean[sort],phases_mean[sort],offset]
|
||||
|
||||
fit_sin_ch = leastsq(err_fn_sin, p0, args=(x, X_tmp),maxfev=100000)
|
||||
fit_lin_ch = leastsq(err_fn_lin, [0,0], args=(x, X_tmp),maxfev=100000)
|
||||
|
||||
err_sin_ch = np.square(fit_fn_sin(x,*fit_sin_ch[0]) - X_tmp).mean()
|
||||
err_lin_ch = np.square(fit_fn_lin(x,*fit_lin_ch[0]) - X_tmp).mean()
|
||||
|
||||
params_sin_tmp.append(fit_sin_ch[0])
|
||||
params_lin_tmp.append(fit_lin_ch[0])
|
||||
err_sin_tmp.append(err_sin_ch)
|
||||
err_lin_tmp.append(err_lin_ch)
|
||||
params_sin.append(params_sin_tmp)
|
||||
params_lin.append(params_lin_tmp)
|
||||
err_sin.append(err_sin_tmp)
|
||||
err_lin.append(err_lin_tmp)
|
||||
params_sin = np.asarray(params_sin)
|
||||
params_lin = np.asarray(params_lin)
|
||||
err_sin = np.asarray(err_sin)
|
||||
err_lin = np.asarray(err_lin)
|
||||
|
||||
return params_sin,params_lin,err_sin,err_lin
|
||||
@@ -553,24 +553,6 @@
|
||||
"## Plot correlations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Compute the mean correlations across iterations."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mean_corr = np.mean(amp_pred_corrs, axis=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@@ -738,21 +720,21 @@
|
||||
"keep_outputs": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 2",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
"name": "python2"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.2"
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário