monitors for binary case, experiment works with lists of outputs also

Esse commit está contido em:
Robin Tibor Schirrmeister
2017-11-03 17:19:32 +01:00
commit 3f809ddbeb
3 arquivos alterados com 24 adições e 9 exclusões
+5 -1
Ver Arquivo
@@ -320,7 +320,11 @@ class Experiment(object):
target_vars = target_vars.cuda()
outputs = self.model(input_vars)
loss = self.loss_function(outputs, target_vars)
outputs = outputs.cpu().data.numpy()
if hasattr(outputs, 'cpu'):
outputs = outputs.cpu().data.numpy()
else:
# assume it is iterable
outputs = [o.cpu().data.numpy() for o in outputs]
loss = loss.cpu().data.numpy()
return outputs, loss
+19 -6
Ver Arquivo
@@ -10,10 +10,15 @@ class MisclassMonitor(object):
----------
col_suffix: str, optional
Name of the column in the monitoring output.
threshold_for_binary_case: bool, optional
In case of binary classification with only one output prediction
per target, define the threshold for separating the classes, i.e.
0.5 for sigmoid outputs, or np.log(0.5) for log sigmoid outputs
"""
def __init__(self, col_suffix='misclass'):
def __init__(self, col_suffix='misclass', threshold_for_binary_case=None):
self.col_suffix = col_suffix
self.threshold_for_binary_case = threshold_for_binary_case
def monitor_epoch(self, ):
return
@@ -28,11 +33,19 @@ class MisclassMonitor(object):
# or just
# examples x classes
# make sure not to remove first dimension if it only has size one
only_one_row = preds.shape[0] == 1
pred_labels = np.argmax(preds, axis=1).squeeze()
# add first dimension again if needed
if only_one_row:
pred_labels = pred_labels[None]
if preds.ndim > 1:
only_one_row = preds.shape[0] == 1
pred_labels = np.argmax(preds, axis=1).squeeze()
# add first dimension again if needed
if only_one_row:
pred_labels = pred_labels[None]
else:
assert self.threshold_for_binary_case is not None, (
"In case of only one output, please supply the "
"threshold_for_binary_case parameter")
# binary classification case... assume logits
pred_labels = np.int32(preds > self.threshold_for_binary_case)
# now examples x time or examples
all_pred_labels.extend(pred_labels)
targets = all_targets[i_batch]
-2
Ver Arquivo
@@ -115,8 +115,6 @@ def ax_scalp(v, channels,
image = ax.imshow(zz, vmin=vmin, vmax=vmax, cmap=cmap,
extent=[min(x), max(x), min(y), max(y)], origin='lower',
interpolation=interpolation)
# image = ax.contourf(xx, yy, zz, 100, vmin=vmin, vmax=vmax,
# cmap=colormap)
if scalp_line_width > 0:
# paint the head
ax.add_artist(plt.Circle((0, 0), 1, linestyle=scalp_line_style,