training.py _weighted_masked_objective fix crash when weights is None (#7068)
* training.py _weighted_masked_objective fix crash when weights is None * unit test _weighted_masked_objective function
Esse commit está contido em:
@@ -444,13 +444,12 @@ def _weighted_masked_objective(fn):
|
||||
# to the number of unmasked samples.
|
||||
score_array /= K.mean(mask)
|
||||
|
||||
# reduce score_array to same ndim as weight array
|
||||
ndim = K.ndim(score_array)
|
||||
weight_ndim = K.ndim(weights)
|
||||
score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))
|
||||
|
||||
# apply sample weighting
|
||||
if weights is not None:
|
||||
# reduce score_array to same ndim as weight array
|
||||
ndim = K.ndim(score_array)
|
||||
weight_ndim = K.ndim(weights)
|
||||
score_array = K.mean(score_array, axis=list(range(weight_ndim, ndim)))
|
||||
score_array *= weights
|
||||
score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))
|
||||
return K.mean(score_array)
|
||||
|
||||
@@ -6,7 +6,9 @@ import scipy.sparse as sparse
|
||||
|
||||
from keras.layers import Dense, Dropout
|
||||
from keras.engine.topology import Input
|
||||
from keras.engine.training import Model, _check_loss_and_target_compatibility
|
||||
from keras.engine.training import Model
|
||||
from keras.engine.training import _check_loss_and_target_compatibility
|
||||
from keras.engine.training import _weighted_masked_objective
|
||||
from keras.models import Sequential
|
||||
from keras import backend as K
|
||||
from keras.utils import Sequence
|
||||
@@ -27,6 +29,18 @@ class RandomSequence(Sequence):
|
||||
np.random.random((self.batch_size, 3))]
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_weighted_masked_objective():
|
||||
a = Input(shape=(3,), name='input_a')
|
||||
|
||||
# weighted_masked_objective
|
||||
def mask_dummy(y_true=None, y_pred=None, weight=None):
|
||||
return K.placeholder(y_true.shape)
|
||||
|
||||
weighted_function = _weighted_masked_objective(K.categorical_crossentropy)
|
||||
weighted_function(a, a, None)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_model_methods():
|
||||
a = Input(shape=(3,), name='input_a')
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário