Avoid DeprecationWarning from inspect.getargspec (rebased) (#7035)
* Utility function to check if a callable has a given keyword argument * Added unit tests for the has_arg function * Replace uses of getargspec with the new has_arg function Not changing keras.backend, because that gives ImportErrors due to a circular import (conv_utils uses the backend, and is imported before generic_utils in utils/__init__.py) Not changing keras.utils.test_utils, because that change exposes (what looks to me like) a latent bug * Replace incorrect use of getargspec in test_utils.py The previous code would always fail to detect the 'weights' argument. Simply replacing getargspec would cause the tests for some of the legacy layers to fail because the passed 'weights' argument is bad. Instead, I have added a check for whether the passed `weights` array is empty, this avoids tripping the bug. * Replacing getargspec with has_arg in the backend modules This requires reordering imports to avoid errors caused by conv_utils trying to import the backend, the backend wanting to import generic_utils, and utils/__init__.py listing conv_utils before generic_utils. * Removed getargspec from legacy wrapping function Instead save the wrapped function in an attribute and call getargspec on this attribute during documentation generation.
Esse commit está contido em:
+4
-2
@@ -320,9 +320,11 @@ def get_classes_ancestors(classes):
|
||||
|
||||
|
||||
def get_function_signature(function, method=True):
|
||||
signature = getattr(function, '_legacy_support_signature', None)
|
||||
if signature is None:
|
||||
wrapped = getattr(function, '_original_function', None)
|
||||
if wrapped is None:
|
||||
signature = inspect.getargspec(function)
|
||||
else:
|
||||
signature = inspect.getargspec(wrapped)
|
||||
defaults = signature.defaults
|
||||
if method:
|
||||
args = signature.args[1:]
|
||||
|
||||
+1
-1
@@ -1,5 +1,6 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
from . import utils
|
||||
from . import activations
|
||||
from . import applications
|
||||
from . import backend
|
||||
@@ -7,7 +8,6 @@ from . import datasets
|
||||
from . import engine
|
||||
from . import layers
|
||||
from . import preprocessing
|
||||
from . import utils
|
||||
from . import wrappers
|
||||
from . import callbacks
|
||||
from . import constraints
|
||||
|
||||
@@ -7,13 +7,14 @@ from tensorflow.python.ops import ctc_ops as ctc
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
|
||||
from collections import defaultdict
|
||||
import inspect
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from .common import floatx
|
||||
from .common import _EPSILON
|
||||
from .common import image_data_format
|
||||
from ..utils.generic_utils import has_arg
|
||||
|
||||
# Legacy functions
|
||||
from .common import set_image_dim_ordering
|
||||
@@ -2285,8 +2286,7 @@ def function(inputs, outputs, updates=None, **kwargs):
|
||||
"""
|
||||
if kwargs:
|
||||
for key in kwargs:
|
||||
if (key not in inspect.getargspec(tf.Session.run)[0] and
|
||||
key not in inspect.getargspec(Function.__init__)[0]):
|
||||
if not (has_arg(tf.Session.run, key, True) or has_arg(Function.__init__, key, True)):
|
||||
msg = 'Invalid argument "%s" passed to K.function with Tensorflow backend' % key
|
||||
raise ValueError(msg)
|
||||
return Function(inputs, outputs, updates=updates, **kwargs)
|
||||
|
||||
@@ -14,9 +14,10 @@ try:
|
||||
from theano.tensor.nnet.nnet import softsign as T_softsign
|
||||
except ImportError:
|
||||
from theano.sandbox.softsign import softsign as T_softsign
|
||||
import inspect
|
||||
|
||||
import numpy as np
|
||||
from .common import _FLOATX, floatx, _EPSILON, image_data_format
|
||||
from ..utils.generic_utils import has_arg
|
||||
# Legacy functions
|
||||
from .common import set_image_dim_ordering, image_dim_ordering
|
||||
|
||||
@@ -1194,9 +1195,8 @@ class Function(object):
|
||||
|
||||
def function(inputs, outputs, updates=[], **kwargs):
|
||||
if len(kwargs) > 0:
|
||||
function_args = inspect.getargspec(theano.function)[0]
|
||||
for key in kwargs.keys():
|
||||
if key not in function_args:
|
||||
if not has_arg(theano.function, key, True):
|
||||
msg = 'Invalid argument "%s" passed to K.function with Theano backend' % key
|
||||
raise ValueError(msg)
|
||||
return Function(inputs, outputs, updates=updates, **kwargs)
|
||||
|
||||
@@ -10,13 +10,13 @@ import warnings
|
||||
import copy
|
||||
import os
|
||||
import re
|
||||
import inspect
|
||||
from six.moves import zip
|
||||
|
||||
from .. import backend as K
|
||||
from .. import initializers
|
||||
from ..utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from ..utils.layer_utils import print_summary as print_layer_summary
|
||||
from ..utils.generic_utils import has_arg
|
||||
from ..utils import conv_utils
|
||||
from ..legacy import interfaces
|
||||
|
||||
@@ -584,7 +584,7 @@ class Layer(object):
|
||||
user_kwargs = copy.copy(kwargs)
|
||||
if not _is_all_none(previous_mask):
|
||||
# The previous layer generated a mask.
|
||||
if 'mask' in inspect.getargspec(self.call).args:
|
||||
if has_arg(self.call, 'mask'):
|
||||
if 'mask' not in kwargs:
|
||||
# If mask is explicitly passed to __call__,
|
||||
# we should override the default mask.
|
||||
@@ -2206,7 +2206,7 @@ class Container(Layer):
|
||||
kwargs = {}
|
||||
if len(computed_data) == 1:
|
||||
computed_tensor, computed_mask = computed_data[0]
|
||||
if 'mask' in inspect.getargspec(layer.call).args:
|
||||
if has_arg(layer.call, 'mask'):
|
||||
if 'mask' not in kwargs:
|
||||
kwargs['mask'] = computed_mask
|
||||
output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
|
||||
@@ -2217,7 +2217,7 @@ class Container(Layer):
|
||||
else:
|
||||
computed_tensors = [x[0] for x in computed_data]
|
||||
computed_masks = [x[1] for x in computed_data]
|
||||
if 'mask' in inspect.getargspec(layer.call).args:
|
||||
if has_arg(layer.call, 'mask'):
|
||||
if 'mask' not in kwargs:
|
||||
kwargs['mask'] = computed_masks
|
||||
output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import division
|
||||
import numpy as np
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import types as python_types
|
||||
import warnings
|
||||
|
||||
@@ -19,6 +18,7 @@ from ..engine import Layer
|
||||
from ..utils.generic_utils import func_dump
|
||||
from ..utils.generic_utils import func_load
|
||||
from ..utils.generic_utils import deserialize_keras_object
|
||||
from ..utils.generic_utils import has_arg
|
||||
from ..legacy import interfaces
|
||||
|
||||
|
||||
@@ -642,8 +642,7 @@ class Lambda(Layer):
|
||||
|
||||
def call(self, inputs, mask=None):
|
||||
arguments = self.arguments
|
||||
arg_spec = inspect.getargspec(self.function)
|
||||
if 'mask' in arg_spec.args:
|
||||
if has_arg(self.function, 'mask'):
|
||||
arguments['mask'] = mask
|
||||
return self.function(inputs, **arguments)
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from ..engine import Layer
|
||||
from ..engine import InputSpec
|
||||
from ..utils.generic_utils import has_arg
|
||||
from .. import backend as K
|
||||
|
||||
|
||||
@@ -155,8 +155,7 @@ class TimeDistributed(Wrapper):
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
kwargs = {}
|
||||
func_args = inspect.getargspec(self.layer.call).args
|
||||
if 'training' in func_args:
|
||||
if has_arg(self.layer.call, 'training'):
|
||||
kwargs['training'] = training
|
||||
uses_learning_phase = False
|
||||
|
||||
@@ -272,10 +271,9 @@ class Bidirectional(Wrapper):
|
||||
|
||||
def call(self, inputs, training=None, mask=None):
|
||||
kwargs = {}
|
||||
func_args = inspect.getargspec(self.layer.call).args
|
||||
if 'training' in func_args:
|
||||
if has_arg(self.layer.call, 'training'):
|
||||
kwargs['training'] = training
|
||||
if 'mask' in func_args:
|
||||
if has_arg(self.layer.call, 'mask'):
|
||||
kwargs['mask'] = mask
|
||||
|
||||
y = self.forward_layer.call(inputs, **kwargs)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import six
|
||||
import warnings
|
||||
import functools
|
||||
import inspect
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -86,7 +85,7 @@ def generate_legacy_interface(allowed_positional_args=None,
|
||||
warnings.warn('Update your `' + object_name +
|
||||
'` call to the Keras 2 API: ' + signature, stacklevel=2)
|
||||
return func(*args, **kwargs)
|
||||
wrapper._legacy_support_signature = inspect.getargspec(func)
|
||||
wrapper._original_function = func
|
||||
return wrapper
|
||||
return legacy_support
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import inspect
|
||||
import types as python_types
|
||||
import warnings
|
||||
|
||||
from ..engine.topology import Layer, InputSpec
|
||||
from .. import backend as K
|
||||
from ..utils.generic_utils import func_dump, func_load
|
||||
from ..utils.generic_utils import func_dump, func_load, has_arg
|
||||
from .. import regularizers
|
||||
from .. import constraints
|
||||
from .. import activations
|
||||
@@ -197,8 +196,7 @@ class Merge(Layer):
|
||||
# Case: "mode" is a lambda or function.
|
||||
if callable(self.mode):
|
||||
arguments = self.arguments
|
||||
arg_spec = inspect.getargspec(self.mode)
|
||||
if 'mask' in arg_spec.args:
|
||||
if has_arg(self.mode, 'mask'):
|
||||
arguments['mask'] = mask
|
||||
return self.mode(inputs, **arguments)
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import absolute_import
|
||||
from . import np_utils
|
||||
from . import conv_utils
|
||||
from . import data_utils
|
||||
from . import generic_utils
|
||||
from . import data_utils
|
||||
from . import io_utils
|
||||
from . import conv_utils
|
||||
|
||||
# Globally-importable utils.
|
||||
from .io_utils import HDF5Matrix
|
||||
|
||||
@@ -132,9 +132,8 @@ def deserialize_keras_object(identifier, module_objects=None,
|
||||
raise ValueError('Unknown ' + printable_module_name +
|
||||
': ' + class_name)
|
||||
if hasattr(cls, 'from_config'):
|
||||
arg_spec = inspect.getargspec(cls.from_config)
|
||||
custom_objects = custom_objects or {}
|
||||
if 'custom_objects' in arg_spec.args:
|
||||
if has_arg(cls.from_config, 'custom_objects'):
|
||||
return cls.from_config(config['config'],
|
||||
custom_objects=dict(list(_GLOBAL_CUSTOM_OBJECTS.items()) +
|
||||
list(custom_objects.items())))
|
||||
@@ -207,6 +206,48 @@ def func_load(code, defaults=None, closure=None, globs=None):
|
||||
closure=closure)
|
||||
|
||||
|
||||
def has_arg(fn, name, accept_all=False):
|
||||
"""Checks if a callable accepts a given keyword argument.
|
||||
|
||||
For Python 2, checks if there is an argument with the given name.
|
||||
|
||||
For Python 3, checks if there is an argument with the given name, and
|
||||
also whether this argument can be called with a keyword (i.e. if it is
|
||||
not a positional-only argument).
|
||||
|
||||
# Arguments
|
||||
fn: Callable to inspect.
|
||||
name: Check if `fn` can be called with `name` as a keyword argument.
|
||||
accept_all: What to return if there is no parameter called `name`
|
||||
but the function accepts a `**kwargs` argument.
|
||||
|
||||
# Returns
|
||||
bool, whether `fn` accepts a `name` keyword argument.
|
||||
"""
|
||||
if sys.version_info < (3,):
|
||||
arg_spec = inspect.getargspec(fn)
|
||||
if accept_all and arg_spec.keywords is not None:
|
||||
return True
|
||||
return (name in arg_spec.args)
|
||||
elif sys.version_info < (3, 3):
|
||||
arg_spec = inspect.getfullargspec(fn)
|
||||
if accept_all and arg_spec.varkw is not None:
|
||||
return True
|
||||
return (name in arg_spec.args or
|
||||
name in arg_spec.kwonlyargs)
|
||||
else:
|
||||
signature = inspect.signature(fn)
|
||||
parameter = signature.parameters.get(name)
|
||||
if parameter is None:
|
||||
if accept_all:
|
||||
for param in signature.parameters.values():
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
return True
|
||||
return False
|
||||
return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
inspect.Parameter.KEYWORD_ONLY))
|
||||
|
||||
|
||||
class Progbar(object):
|
||||
"""Displays a progress bar.
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Utilities related to Keras unit tests."""
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
import inspect
|
||||
import six
|
||||
|
||||
from .generic_utils import has_arg
|
||||
from ..engine import Model, Input
|
||||
from ..models import Sequential
|
||||
from ..models import model_from_json
|
||||
@@ -71,7 +71,9 @@ def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
|
||||
layer.set_weights(weights)
|
||||
|
||||
# test and instantiation from weights
|
||||
if 'weights' in inspect.getargspec(layer_cls.__init__):
|
||||
# Checking for empty weights array to avoid a problem where some
|
||||
# legacy layers return bad values from get_weights()
|
||||
if has_arg(layer_cls.__init__, 'weights') and len(weights):
|
||||
kwargs['weights'] = weights
|
||||
layer = layer_cls(**kwargs)
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import types
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils.np_utils import to_categorical
|
||||
from ..utils.generic_utils import has_arg
|
||||
from ..models import Sequential
|
||||
|
||||
|
||||
@@ -75,13 +75,11 @@ class BaseWrapper(object):
|
||||
else:
|
||||
legal_params_fns.append(self.build_fn)
|
||||
|
||||
legal_params = []
|
||||
for fn in legal_params_fns:
|
||||
legal_params += inspect.getargspec(fn)[0]
|
||||
legal_params = set(legal_params)
|
||||
|
||||
for params_name in params:
|
||||
if params_name not in legal_params:
|
||||
for fn in legal_params_fns:
|
||||
if has_arg(fn, params_name):
|
||||
break
|
||||
else:
|
||||
if params_name != 'nb_epoch':
|
||||
raise ValueError(
|
||||
'{} is not a legal parameter'.format(params_name))
|
||||
@@ -163,9 +161,8 @@ class BaseWrapper(object):
|
||||
"""
|
||||
override = override or {}
|
||||
res = {}
|
||||
fn_args = inspect.getargspec(fn)[0]
|
||||
for name, value in self.sk_params.items():
|
||||
if name in fn_args:
|
||||
if has_arg(fn, name):
|
||||
res.update({name: value})
|
||||
res.update(override)
|
||||
return res
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import sys
|
||||
import pytest
|
||||
from keras.utils.generic_utils import custom_object_scope
|
||||
from keras.utils.generic_utils import custom_object_scope, has_arg
|
||||
from keras import activations
|
||||
from keras import regularizers
|
||||
|
||||
@@ -20,5 +21,46 @@ def test_custom_objects_scope():
|
||||
assert cl.__class__ == CustomClass
|
||||
|
||||
|
||||
@pytest.mark.parametrize('fn, name, accept_all, expected', [
|
||||
('f(x)', 'x', False, True),
|
||||
('f(x)', 'y', False, False),
|
||||
('f(x)', 'y', True, False),
|
||||
('f(x, y)', 'y', False, True),
|
||||
('f(x, y=1)', 'y', False, True),
|
||||
('f(x, **kwargs)', 'x', False, True),
|
||||
('f(x, **kwargs)', 'y', False, False),
|
||||
('f(x, **kwargs)', 'y', True, True),
|
||||
('f(x, y=1, **kwargs)', 'y', False, True),
|
||||
# Keyword-only arguments (Python 3 only)
|
||||
('f(x, *args, y=1)', 'y', False, True),
|
||||
('f(x, *args, y=1)', 'z', True, False),
|
||||
('f(x, *, y=1)', 'x', False, True),
|
||||
('f(x, *, y=1)', 'y', False, True),
|
||||
# lambda
|
||||
(lambda x: x, 'x', False, True),
|
||||
(lambda x: x, 'y', False, False),
|
||||
(lambda x: x, 'y', True, False),
|
||||
])
|
||||
def test_has_arg(fn, name, accept_all, expected):
|
||||
if isinstance(fn, str):
|
||||
context = dict()
|
||||
try:
|
||||
exec('def {}: pass'.format(fn), context)
|
||||
except SyntaxError:
|
||||
if sys.version_info >= (3,):
|
||||
raise
|
||||
pytest.skip('Function is not compatible with Python 2')
|
||||
context.pop('__builtins__', None) # Sometimes exec adds builtins to the context
|
||||
fn, = context.values()
|
||||
|
||||
assert has_arg(fn, name, accept_all) is expected
|
||||
|
||||
|
||||
@pytest.mark.xfail(sys.version_info < (3, 3),
|
||||
reason='inspect API does not reveal positional-only arguments')
|
||||
def test_has_arg_positional_only():
|
||||
assert has_arg(pow, 'x') is False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário