Merge branch 'master' of github.com:robintibor/braindecode

Esse commit está contido em:
Robin Tibor Schirrmeister
2018-06-12 17:14:08 +02:00
6 arquivos alterados com 536 adições e 81 exclusões
+1 -1
Ver Arquivo
@@ -137,7 +137,7 @@ class BBCIDataset(object):
def _add_markers(self, cnt):
with h5py.File(self.filename, 'r') as h5file:
event_times_in_ms = h5file['mrk']['time'][:].squeeze()
event_classes = h5file['mrk']['event']['desc'][:].squeeze()
event_classes = h5file['mrk']['event']['desc'][:].squeeze().astype(np.int64)
# Check whether class names known and correct order
class_name_set = h5file['nfo']['className'][:].squeeze()
+41 -2
Ver Arquivo
@@ -8,7 +8,7 @@ from braindecode.torch_ext.util import np_to_var
class Expression(torch.nn.Module):
"""
Compute given expression on forward pass.
Parameters
----------
expression_fn: function
@@ -37,7 +37,7 @@ class Expression(torch.nn.Module):
class AvgPool2dWithConv(torch.nn.Module):
"""
Compute average pooling using a convolution, to have the dilation parameter.
Parameters
----------
kernel_size: (int,int)
@@ -81,3 +81,42 @@ class AvgPool2dWithConv(torch.nn.Module):
dilation=self.dilation,
groups=in_channels,)
return pooled
class IntermediateOutputWrapper(torch.nn.Module):
"""Wraps network model such that outputs of intermediate layers can be returned.
forward() returns list of intermediate activations in a network during forward pass.
Parameters
----------
to_select : list
list of module names for which activation should be returned
model : model object
network model
Examples
--------
>>> model = Deep4Net()
>>> select_modules = ['conv_spat','conv_2','conv_3','conv_4'] # Specify intermediate outputs
>>> model_pert = IntermediateOutputWrapper(select_modules,model) # Wrap model
"""
def __init__(self, to_select, model):
if not len(list(model.children()))==len(list(model.named_children())):
raise Exception('All modules in model need to have names!')
super(IntermediateOutputWrapper, self).__init__()
modules_list = model.named_children()
for key, module in modules_list:
self.add_module(key, module)
self._modules[key].load_state_dict(module.state_dict())
self._to_select = to_select
def forward(self,x):
# Call modules individually and append activation to output if module is in to_select
o = []
for name, module in self._modules.items():
x = module(x)
if name in self._to_select:
o.append(x)
return o
+170
Ver Arquivo
@@ -0,0 +1,170 @@
import numpy as np
import torch
def calc_Hout(Hin, kernel, stride, dilation):
Hout = np.floor((Hin-dilation*(kernel-1)-1)/stride+1)
return Hout
def calc_Hin(Hout, kernel, stride, dilation):
Hin = np.ceil((Hout-1)*stride+1+dilation*(kernel-1))
return Hin
def calc_receptive_field_size(model,layer_ind,start_receptive_field=np.ones((2))):
"""Calculate receptive field size for unit in specific layer of the network
Only tested for 2d convolutions/poolings. Dimshuffle operations may lead to a wrong result
Parameters
----------
model : model object
Network model
layer_ind: int
Index of the layer of interest in `model.children()`
start_receptive_field: int, optional
How many units are looked at in specified layer (default: [1,1])
Returns
-------
receptive_field_size : numpy array
[HxW] in the input layer
"""
receptive_field = start_receptive_field
children = list(model.children())[:layer_ind][::-1]
for child in children:
if isinstance(child,torch.nn.Sequential):
receptive_field = calc_receptive_field(child,-1)
elif isinstance(child,torch.nn.Conv2d) or isinstance(child,torch.nn.MaxPool2d) or isinstance(child,torch.nn.AvgPool2d):
receptive_field = calc_Hin(receptive_field,np.asarray(child.kernel_size),
np.asarray(child.stride), np.asarray(child.dilation))
receptive_field_size = receptive_field.astype(np.int)
return receptive_field_size
def get_max_act_index(activations,unique_per_input=True,n_units=None):
"""Retrieve index of maximum activation in a feature map
Parameters
----------
activations : numpy array
[Nx1xHxW] can only take 1 filter
unique_per_input : bool, optional
Specifies if only 1 index (maximum) for each input is returned (default: True)
n_units : int, optional
How many indeces are returned in total. If None all (default: None)
Returns
-------
units : numpy array
[Nx4] with the columns `Input_i`,`Filter(0)`,`H_i`,`W_i` indeces of the units
units_activation : numpy array
Activation of the units
"""
assert len(activations.shape)==4,"Has to be 4d array"
assert activations.shape[1] == 1,"Can only handle individual filter activations"
activations_sorted = activations.argsort(axis=None)[::-1]
activations_sorted_ind = np.unravel_index(activations_sorted,activations.shape)
unique_ind = np.arange(len(activations_sorted_ind[0]))
if unique_per_input:
a,unique_ind = np.unique(activations_sorted_ind[0],return_index=True)
unique_ind = sorted(unique_ind)
if n_units==None:
n_units = len(unique_ind)
activations_sorted_ind = np.asarray(activations_sorted_ind).T
units = activations_sorted_ind[unique_ind[:n_units],:].astype(np.int)
units_activation = activations.flat[activations_sorted[unique_ind[:n_units]]]
return units,units_activation
def calc_receptive_field_for_units_2d(units,receptive_field_size):
recptive_field_tmp = receptive_field_size[np.newaxis,np.newaxis,:,:]
start_inds = units[:,0,]
stop_inds = start_inds+receptive_field_tmp
return start_inds,stop_inds
def get_input_windows_from_units_2d(inputs,units,receptive_field_size):
"""Cut input windows in receptive field of specified units from inputs
Parameters
----------
inputs : numpy array
[NxCxHxW] Inputs x Channels x Time x W
units : numpy array
[Mx4] unit indeces specifying Input and time indeces.
Second dimension consists of Input x Filter(1) x H x W indeces. Can only handle 1 filter.
receptive_field_size : int
Size of receptive field of units on input
Returns
-------
windows : numpy array
Cut input windows
"""
windows = np.zeros((units.shape[0],inputs.shape[1],receptive_field_size[0],receptive_field_size[1]))
for i,unit in enumerate(units):
windows[i] = inputs[unit[0],:,
unit[2]:unit[2]+receptive_field_size[0],
unit[3]:unit[3]+receptive_field_size[1]]
return windows
def most_activating_input_windows(inputs,activations,receptive_field_size,top_percentage=0.05):
"""Get the input windows that evoked highest activation in a feature map
Parameters
----------
inputs : numpy array
[NxCxTx1] Inputs used for the calculation of activations
activations : numpy array
[NxFxHx1] Activations
receptive_field_size : numpy array
[Wx1] Receptive field of a unit of the layer on the input
top_percentage : float, optional
How many of the most activating input windows should be returned (default: top 5%)
Returns
-------
input_windows : numpy array
[FxUxCxWx1] Returns `U` (resulting from top_percentage) input windows for each filter
"""
n_units = int(inputs.shape[0]*top_percentage)
input_windows = np.zeros((activations.shape[1],n_units,inputs.shape[1],
receptive_field_size[0],receptive_field_size[1]))
for filt in range(activations.shape[1]):
units,w = get_max_act_index(activations[:,[filt]],n_units=n_units)
w = w.squeeze()
windows_tmp = get_input_windows_from_units_2d(inputs,units,receptive_field_size)
input_windows[filt] = windows_tmp
return input_windows
def activation_reverse_correlation(inputs,activations,receptive_field_size):
"""Get reverse correlations for filters
Parameters
----------
inputs : numpy array
[NxCxTx1] Inputs used for the calculation of activations
activations : numpy array
[NxFxHx1] Activations
receptive_field_size : numpy array
[Wx1] Receptive field of a unit of the layer on the input
Returns
-------
reverse_corr : numpy array
[FxCxWx1] Reverse correlations over all input windows for each filter
"""
reverse_corr = np.zeros((activations.shape[1],inputs.shape[1],
receptive_field_size[0],receptive_field_size[1]))
for filt in range(activations.shape[1]):
act_tmp = activations[:,filt]
divisor = 0
for start_ind in range(activations.shape[2]):
end_ind = start_ind+receptive_field_size[0]
w = act_tmp[:,[start_ind],:,np.newaxis]
windows_tmp = inputs[:,:,start_ind:end_ind]
reverse_corr[filt] += np.sum(w*windows_tmp,axis=0)
divisor += np.sum(np.abs(w))
reverse_corr[filt] /= divisor
return reverse_corr
+230 -55
Ver Arquivo
@@ -1,7 +1,6 @@
import logging
import numpy as np
from numpy.random import RandomState
from braindecode.datautil.iterators import get_balanced_batches
from braindecode.util import wrap_reshape_apply_fn, corr
@@ -9,28 +8,241 @@ from braindecode.util import wrap_reshape_apply_fn, corr
log = logging.getLogger(__name__)
def gaussian_perturbation(amps, rng):
"""
Create gaussian noise tensor with same shape as amplitudes.
def phase_perturbation(amps,phases,rng=np.random.RandomState()):
"""Takes amps and phases of BxCxF with B input, C channels, F frequencies
Shifts spectral phases randomly U(-pi,pi) for input and frequencies, but same for all channels
Parameters
----------
amps: ndarray
Amplitudes.
rng: RandomState
Random generator.
amps : numpy array
Spectral amplitude (not used)
phases : numpy array
Spectral phases
rng : object
Random Seed
Returns
-------
perturbation: ndarray
Perturbations to add to the amplitudes.
amps : numpy array
Input amps (not modified)
phases_pert : numpy array
Shifted phases
pert_vals : numpy array
Absolute phase shifts
"""
perturbation = rng.randn(*amps.shape).astype(np.float32)
return perturbation
noise_shape = list(phases.shape)
noise_shape[1] = 1 # Do not sample noise for channels individually
# Sample phase perturbation noise
phase_noise = rng.uniform(-np.pi,np.pi,noise_shape).astype(np.float32)
phase_noise_rep = phase_noise.repeat(phases.shape[1],axis=1)
# Apply noise to inputs
phases_pert = phases+phase_noise_rep
phases_pert[phases_pert<-np.pi] += 2*np.pi
phases_pert[phases_pert>np.pi] -= 2*np.pi
pert_vals = np.abs(phase_noise)
return amps,phases_pert,pert_vals
def amp_perturbation_additive(amps,phases,rng=np.random.RandomState()):
"""Takes amplitudes and phases of BxCxF with B input, C channels, F frequencies
Adds additive noise N(0,0.02) to amplitudes
Parameters
----------
amps : numpy array
Spectral amplitude
phases : numpy array
Spectral phases (not used)
rng : object
Random Seed
Returns
-------
amps_pert : numpy array
Scaled amplitudes
phases_pert : numpy array
Input phases (not modified)
pert_vals : numpy array
Amplitude noise
"""
amp_noise = rng.normal(0,1,amps.shape).astype(np.float32)
amps_pert = amps+amp_noise
amps_pert[amps_pert<0] = 0
amp_noise = amps_pert-amps
return amps_pert,phases,amp_noise
def amp_perturbation_multiplicative(amps,phases,rng=np.random.RandomState()):
"""Takes amplitude and phases of BxCxF with B input, C channels, F frequencies
Adds multiplicative noise N(1,0.02) to amplitudes
Parameters
----------
amps : numpy array
Spectral amplitude
phases : numpy array
Spectral phases (not used)
rng : object
Random Seed
Returns
-------
amps_pert : numpy array
Scaled amplitudes
phases_pert : numpy array
Input phases (not modified)
pert_vals : numpy array
Amplitude scaling factor
"""
amp_noise = rng.normal(1,0.02,amps.shape).astype(np.float32)
amps_pert = amps*amp_noise
amps_pert[amps_pert<0] = 0
return amps_pert,phases,amp_noise
def correlate_feature_maps(x,y):
"""Takes two activation matrices of the form Bx[F]xT where B is batch size, F number of filters (optional) and T time points
Returns correlations of the corresponding activations over T
Parameters
----------
x : numpy array
Activations Bx[F]xT
y : numpy array
Activations Bx[F]xT
Returns
correlations : numpy array
Correlations of `x` and `y` Bx[F]
"""
shape_x = x.shape
shape_y = y.shape
assert np.array_equal(shape_x,shape_y)
assert len(shape_x)<4
x = x.reshape((-1,shape_x[-1]))
y = y.reshape((-1,shape_y[-1]))
x = (x-x.mean(axis=1,keepdims=True))/x.std(axis=1,keepdims=True)
y = (y-y.mean(axis=1,keepdims=True))/y.std(axis=1,keepdims=True)
tmp_corr = x*y
corr_ = tmp_corr.sum(axis=1)
#corr_ = np.zeros((x.shape[0]))
#for i in range(x.shape[0]):
# # Correlation of standardized variables
# corr_[i] = np.correlate((x[i]-x[i].mean())/x[i].std(),(y[i]-y[i].mean())/y[i].std())
correlations = corr_.reshape(*shape_x[:-1])
return correlations
def mean_diff_feature_maps(x,y):
"""Takes two activation matrices of the form BxFxT where B is batch size, F number of filters and T time points
Returns mean difference between feature map activations
Parameters
----------
x : numpy array
Activations Bx[F]xT
y : numpy array
Activations Bx[F]xT
Returns
mean_diff : numpy array
Mean difference between `x` and `y` Bx[F]
"""
mean_diff = np.mean(x-y,axis=2)
return mean_diff
def spectral_perturbation_correlation(pert_fn, diff_fn, pred_fn, n_layers, inputs, n_iterations,
batch_size=30,
seed=((2017, 7, 10))):
"""Calculates perturbation correlations for layers in network by perturbing either amplitudes or phases
Parameters
----------
pert_fn : function
Function that perturbs spectral phase and amplitudes of inputs
diff_fn : function
Function that calculates difference between original and perturbed activations
pred_fn : function
Function that returns a list of activations.
Each entry in the list corresponds to the output of 1 layer in a network
n_layers : int
Number of layers pred_fn returns activations for.
inputs : numpy array
Original inputs that are used for perturbation [B,X,T,1]
Phase perturbations are sampled for each input individually, but applied to all X of that input
n_iterations : int
Number of iterations of correlation computation. The higher the better
batch_size : int
Number of inputs that are used for one forward pass. (Concatenated for all inputs)
Returns
-------
pert_corrs : numpy array
List of length n_layers containing average perturbation correlations over iterations
L x CxFrxFi (Channels,Frequencies,Filters)
"""
rng = np.random.RandomState(seed)
# Get batch indeces
batch_inds = get_balanced_batches(
n_trials=len(inputs), rng=rng, shuffle=False, batch_size=batch_size)
# Calculate layer activations and reshape
log.info("Compute original predictions...")
orig_preds = [pred_fn(inputs[inds])
for inds in batch_inds]
use_shape = []
for l in range(n_layers):
tmp = list(orig_preds[0][l].shape)
tmp.extend([1]*(4-len(tmp)))
tmp[0] = len(inputs)
use_shape.append(tmp)
orig_preds_layers = [np.concatenate([orig_preds[o][l] for o in range(len(orig_preds))]).reshape(use_shape[l])
for l in range(n_layers)]
# Compute FFT of inputs
fft_input = np.fft.rfft(inputs, n=inputs.shape[2], axis=2)
amps = np.abs(fft_input)
phases = np.angle(fft_input)
pert_corrs = [0]*n_layers
for i in range(n_iterations):
log.info("Iteration {:d}...".format(i))
log.info("Sample perturbation...")
amps_pert,phases_pert,pert_vals = pert_fn(amps,phases,rng=rng)
# Compute perturbed inputs
log.info("Compute perturbed complex inputs...")
fft_pert = amps_pert*np.exp(1j*phases_pert)
log.info("Compute perturbed real inputs...")
inputs_pert = np.fft.irfft(fft_pert, n=inputs.shape[2], axis=2).astype(np.float32)
# Calculate layer activations for perturbed inputs
log.info("Compute new predictions...")
new_preds = [pred_fn(inputs_pert[inds])
for inds in batch_inds]
new_preds_layers = [np.concatenate([new_preds[o][l] for o in range(len(new_preds))]).reshape(use_shape[l])
for l in range(n_layers)]
for l in range(n_layers):
log.info("Layer {:d}...".format(l))
# Calculate difference of original and perturbed feature map activations
log.info("Compute activation difference...")
preds_diff = diff_fn(new_preds_layers[l][:,:,:,0],orig_preds_layers[l][:,:,:,0])
# Calculate feature map differences with perturbations
log.info("Compute correlation...")
pert_corrs_tmp = wrap_reshape_apply_fn(corr,
pert_vals[:,:,:,0],preds_diff,
axis_a=(0,), axis_b=(0))
pert_corrs[l] += pert_corrs_tmp
pert_corrs = [pert_corrs[l]/n_iterations for l in range(n_layers)] #mean over iterations
return pert_corrs
def compute_amplitude_prediction_correlations(pred_fn, examples, n_iterations,
perturb_fn=gaussian_perturbation,
perturb_fn=amp_perturbation_additive,
batch_size=30,
seed=((2017, 7, 10))):
"""
@@ -71,46 +283,9 @@ def compute_amplitude_prediction_correlations(pred_fn, examples, n_iterations,
visualization.
arXiv preprint arXiv:1703.05051.
"""
inds_per_batch = get_balanced_batches(
n_trials=len(examples), rng=None, shuffle=False, batch_size=batch_size)
log.info("Compute original predictions...")
orig_preds = [pred_fn(examples[example_inds])
for example_inds in inds_per_batch]
orig_preds_arr = np.concatenate(orig_preds)
rng = RandomState(seed)
fft_input = np.fft.rfft(examples, axis=2)
amps = np.abs(fft_input)
phases = np.angle(fft_input)
pred_fn_new = lambda x: [pred_fn(x)]
pred_corrs = spectral_perturbation_correlation(perturb_fn, mean_diff_feature_maps,
pred_fn_new, 1, examples, n_iterations,
batch_size=batch_size, seed=seed)
amp_pred_corrs = []
for i_iteration in range(n_iterations):
log.info("Iteration {:d}...".format(i_iteration))
log.info("Sample perturbation...")
perturbation = perturb_fn(amps, rng)
log.info("Compute new amplitudes...")
# do not allow perturbation to make amplitudes go below
# zero
perturbation = np.maximum(-amps, perturbation)
new_amps = amps + perturbation
log.info("Compute new complex inputs...")
new_complex = _amplitude_phase_to_complex(new_amps, phases)
log.info("Compute new real inputs...")
new_in = np.fft.irfft(new_complex, axis=2).astype(np.float32)
log.info("Compute new predictions...")
new_preds = [pred_fn(new_in[example_inds])
for example_inds in inds_per_batch]
new_preds_arr = np.concatenate(new_preds)
diff_preds = new_preds_arr - orig_preds_arr
log.info("Compute correlation...")
amp_pred_corr = wrap_reshape_apply_fn(corr, perturbation[:, :, :, 0],
diff_preds,
axis_a=(0,), axis_b=(0))
amp_pred_corrs.append(amp_pred_corr)
return amp_pred_corrs
def _amplitude_phase_to_complex(amplitude, phase):
return amplitude * np.cos(phase) + amplitude * np.sin(phase) * 1j
return pred_corrs[0]
+89
Ver Arquivo
@@ -0,0 +1,89 @@
from scipy.optimize import leastsq
import numpy as np
def err_fn_sin(p,x,y):
return (y-fit_fn_sin(x,*p)).flat
def err_fn_lin(p,x,y):
return (y-fit_fn_lin(x,*p)).flat
def fit_fn_lin(x,*kwargs):
return kwargs[0]+kwargs[1]*x
def fit_fn_sin(x,*kwargs):
freqs = kwargs[0]
amps = kwargs[1]
phases = kwargs[2]
offset = kwargs[3]
sig = np.zeros((len(x)))+offset
sig += amps*np.cos(x*freqs+phases)
return sig
def signal_fit(signals,fs):
"""Fits sinusoid and linear function to signals
see sinfit.fit_fn_sin and sinfit.fit_fn_lin
Parameters
----------
signals : numpy array
[FxCxTx1] Filters x Channels x Time x 1
fs : float
Sampling frequency
Returns
-------
params_sin : numpy array
[FxCx4] Parameters of sinusoid fit
Parameters are: Frequency,Amplitude,Phase,DCOffset
params_lin : numpy array
[FxCx2] Parameters of sinusoid fit
Parameters are: Frequency,Amplitude,Phase,DCOffset
err_sin : numpy array
[FxCx1] MSE for sinusoid fit
err_lin : numpy array
[FxCx1] MSE for linear fit
"""
params_sin = []
params_lin = []
err_sin = []
err_lin = []
freqs = np.fft.rfftfreq(signals.shape[2], d=1.0/fs)[1:]
x = np.linspace(0,signals.shape[2]/fs,signals.shape[2])*2*np.pi
for filt in range(signals.shape[0]):
params_sin_tmp = []
params_lin_tmp = []
err_sin_tmp = []
err_lin_tmp = []
for ch in range(signals.shape[1]):
X_tmp = signals[filt,ch].squeeze()
fft_X = np.fft.rfft(X_tmp,axis=0)
amps_mean = np.abs(fft_X)[1:]
phases_mean = np.angle(fft_X)[1:]
offset = X_tmp.mean()
sort = np.argsort(amps_mean)[::-1][0]
p0 = [freqs[sort],amps_mean[sort],phases_mean[sort],offset]
fit_sin_ch = leastsq(err_fn_sin, p0, args=(x, X_tmp),maxfev=100000)
fit_lin_ch = leastsq(err_fn_lin, [0,0], args=(x, X_tmp),maxfev=100000)
err_sin_ch = np.square(fit_fn_sin(x,*fit_sin_ch[0]) - X_tmp).mean()
err_lin_ch = np.square(fit_fn_lin(x,*fit_lin_ch[0]) - X_tmp).mean()
params_sin_tmp.append(fit_sin_ch[0])
params_lin_tmp.append(fit_lin_ch[0])
err_sin_tmp.append(err_sin_ch)
err_lin_tmp.append(err_lin_ch)
params_sin.append(params_sin_tmp)
params_lin.append(params_lin_tmp)
err_sin.append(err_sin_tmp)
err_lin.append(err_lin_tmp)
params_sin = np.asarray(params_sin)
params_lin = np.asarray(params_lin)
err_sin = np.asarray(err_sin)
err_lin = np.asarray(err_lin)
return params_sin,params_lin,err_sin,err_lin
+5 -23
Ver Arquivo
@@ -553,24 +553,6 @@
"## Plot correlations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the mean correlations across iterations."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"mean_corr = np.mean(amp_pred_corrs, axis=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -738,21 +720,21 @@
"keep_outputs": true
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 2",
"language": "python",
"name": "python3"
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,