fix dtype handling in 2 optimizers and 1 layer (#7088)
* fix dtype handling in 2 optimizers and 1 layer * fix zeros * Add base_dtype argument * Fix base_dtype * remove base_dtype
Esse commit está contido em:
@@ -550,7 +550,7 @@ def dtype(x):
|
||||
'float32_ref'
|
||||
```
|
||||
"""
|
||||
return x.dtype.name
|
||||
return x.dtype.base_dtype.name
|
||||
|
||||
|
||||
def eval(x):
|
||||
|
||||
@@ -75,7 +75,6 @@ class Embedding(Layer):
|
||||
mask_zero=False,
|
||||
input_length=None,
|
||||
**kwargs):
|
||||
kwargs['dtype'] = 'int32'
|
||||
if 'input_shape' not in kwargs:
|
||||
if input_length:
|
||||
kwargs['input_shape'] = (input_length,)
|
||||
@@ -98,7 +97,8 @@ class Embedding(Layer):
|
||||
initializer=self.embeddings_initializer,
|
||||
name='embeddings',
|
||||
regularizer=self.embeddings_regularizer,
|
||||
constraint=self.embeddings_constraint)
|
||||
constraint=self.embeddings_constraint,
|
||||
dtype=self.dtype)
|
||||
self.built = True
|
||||
|
||||
def compute_mask(self, inputs, mask=None):
|
||||
|
||||
@@ -219,8 +219,7 @@ class RMSprop(Optimizer):
|
||||
|
||||
def get_updates(self, params, constraints, loss):
|
||||
grads = self.get_gradients(loss, params)
|
||||
shapes = [K.get_variable_shape(p) for p in params]
|
||||
accumulators = [K.zeros(shape) for shape in shapes]
|
||||
accumulators = [K.zeros(K.get_variable_shape(p), dtype=K.dtype(p)) for p in params]
|
||||
self.weights = accumulators
|
||||
self.updates = []
|
||||
|
||||
@@ -413,9 +412,8 @@ class Adam(Optimizer):
|
||||
lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
|
||||
(1. - K.pow(self.beta_1, t)))
|
||||
|
||||
shapes = [K.get_variable_shape(p) for p in params]
|
||||
ms = [K.zeros(shape) for shape in shapes]
|
||||
vs = [K.zeros(shape) for shape in shapes]
|
||||
ms = [K.zeros(K.get_variable_shape(p), dtype=K.dtype(p)) for p in params]
|
||||
vs = [K.zeros(K.get_variable_shape(p), dtype=K.dtype(p)) for p in params]
|
||||
self.weights = [self.iterations] + ms + vs
|
||||
|
||||
for p, g, m, v in zip(params, grads, ms, vs):
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário