monitors for binary case, experiment works with lists of outputs also
Esse commit está contido em:
@@ -320,7 +320,11 @@ class Experiment(object):
|
||||
target_vars = target_vars.cuda()
|
||||
outputs = self.model(input_vars)
|
||||
loss = self.loss_function(outputs, target_vars)
|
||||
outputs = outputs.cpu().data.numpy()
|
||||
if hasattr(outputs, 'cpu'):
|
||||
outputs = outputs.cpu().data.numpy()
|
||||
else:
|
||||
# assume it is iterable
|
||||
outputs = [o.cpu().data.numpy() for o in outputs]
|
||||
loss = loss.cpu().data.numpy()
|
||||
return outputs, loss
|
||||
|
||||
|
||||
@@ -10,10 +10,15 @@ class MisclassMonitor(object):
|
||||
----------
|
||||
col_suffix: str, optional
|
||||
Name of the column in the monitoring output.
|
||||
threshold_for_binary_case: bool, optional
|
||||
In case of binary classification with only one output prediction
|
||||
per target, define the threshold for separating the classes, i.e.
|
||||
0.5 for sigmoid outputs, or np.log(0.5) for log sigmoid outputs
|
||||
"""
|
||||
|
||||
def __init__(self, col_suffix='misclass'):
|
||||
def __init__(self, col_suffix='misclass', threshold_for_binary_case=None):
|
||||
self.col_suffix = col_suffix
|
||||
self.threshold_for_binary_case = threshold_for_binary_case
|
||||
|
||||
def monitor_epoch(self, ):
|
||||
return
|
||||
@@ -28,11 +33,19 @@ class MisclassMonitor(object):
|
||||
# or just
|
||||
# examples x classes
|
||||
# make sure not to remove first dimension if it only has size one
|
||||
only_one_row = preds.shape[0] == 1
|
||||
pred_labels = np.argmax(preds, axis=1).squeeze()
|
||||
# add first dimension again if needed
|
||||
if only_one_row:
|
||||
pred_labels = pred_labels[None]
|
||||
if preds.ndim > 1:
|
||||
only_one_row = preds.shape[0] == 1
|
||||
|
||||
pred_labels = np.argmax(preds, axis=1).squeeze()
|
||||
# add first dimension again if needed
|
||||
if only_one_row:
|
||||
pred_labels = pred_labels[None]
|
||||
else:
|
||||
assert self.threshold_for_binary_case is not None, (
|
||||
"In case of only one output, please supply the "
|
||||
"threshold_for_binary_case parameter")
|
||||
# binary classification case... assume logits
|
||||
pred_labels = np.int32(preds > self.threshold_for_binary_case)
|
||||
# now examples x time or examples
|
||||
all_pred_labels.extend(pred_labels)
|
||||
targets = all_targets[i_batch]
|
||||
|
||||
@@ -115,8 +115,6 @@ def ax_scalp(v, channels,
|
||||
image = ax.imshow(zz, vmin=vmin, vmax=vmax, cmap=cmap,
|
||||
extent=[min(x), max(x), min(y), max(y)], origin='lower',
|
||||
interpolation=interpolation)
|
||||
# image = ax.contourf(xx, yy, zz, 100, vmin=vmin, vmax=vmax,
|
||||
# cmap=colormap)
|
||||
if scalp_line_width > 0:
|
||||
# paint the head
|
||||
ax.add_artist(plt.Circle((0, 0), 1, linestyle=scalp_line_style,
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário