training.py _weighted_masked_objective fix crash when weights is None (#7068)

* training.py _weighted_masked_objective fix crash when weights is None

* unit test _weighted_masked_objective function
Esse commit está contido em:
Andrew Hundt
2017-06-21 16:46:55 -04:00
commit de François Chollet
commit 219d6ee5be
2 arquivos alterados com 19 adições e 6 exclusões
+4 -5
Ver Arquivo
@@ -444,13 +444,12 @@ def _weighted_masked_objective(fn):
# to the number of unmasked samples.
score_array /= K.mean(mask)
# reduce score_array to same ndim as weight array
ndim = K.ndim(score_array)
weight_ndim = K.ndim(weights)
score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))
# apply sample weighting
if weights is not None:
# reduce score_array to same ndim as weight array
ndim = K.ndim(score_array)
weight_ndim = K.ndim(weights)
score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))
score_array *= weights
score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))
return K.mean(score_array)
+15 -1
Ver Arquivo
@@ -6,7 +6,9 @@ import scipy.sparse as sparse
from keras.layers import Dense, Dropout
from keras.engine.topology import Input
from keras.engine.training import Model, _check_loss_and_target_compatibility
from keras.engine.training import Model
from keras.engine.training import _check_loss_and_target_compatibility
from keras.engine.training import _weighted_masked_objective
from keras.models import Sequential
from keras import backend as K
from keras.utils import Sequence
@@ -27,6 +29,18 @@ class RandomSequence(Sequence):
np.random.random((self.batch_size, 3))]
@keras_test
def test_weighted_masked_objective():
a = Input(shape=(3,), name='input_a')
# weighted_masked_objective
def mask_dummy(y_true=None, y_pred=None, weight=None):
return K.placeholder(y_true.shape)
weighted_function = _weighted_masked_objective(K.categorical_crossentropy)
weighted_function(a, a, None)
@keras_test
def test_model_methods():
a = Input(shape=(3,), name='input_a')