• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Defines an input type specification for tf.function."""
16
17import functools
18import itertools
19import weakref
20
21import numpy as np
22import six
23
24from tensorflow.python.framework import composite_tensor
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.ops import resource_variable_ops
29from tensorflow.python.util import _pywrap_utils
30from tensorflow.python.util import nest
31from tensorflow.python.util import tf_decorator
32from tensorflow.python.util import tf_inspect
33
# Sentinel value used by ConcreteFunction's structured signature to
# indicate that a non-tensor parameter should use the value that was
# specified when the concrete function was created.
BOUND_VALUE = object()
38
39
# TODO(b/214462107): Clean up and migrate to core/function when unblocked.
class FunctionSpec(object):
  """Specification of how to bind arguments to a function."""

  @classmethod
  def from_function_and_signature(cls, python_function,
                                  input_signature,
                                  is_pure=False,
                                  experimental_follow_type_hints=False,
                                  jit_compile=None):
    """Creates a FunctionSpec instance given a python function and signature.

    Args:
      python_function: a function to inspect
      input_signature: a signature of the function (None, if variable)
      is_pure: if True all input arguments (including variables and constants)
        will be converted to tensors and no variable changes allowed.
      experimental_follow_type_hints: see `tf.function`
      jit_compile: see `tf.function`

    Returns:
      instance of FunctionSpec
    """
    _validate_signature(input_signature)
    _validate_python_function(python_function, input_signature)

    fullargspec = tf_inspect.getfullargspec(python_function)
    # Checks if the `fullargspec` contains self or cls as its first argument.
    is_method = tf_inspect.isanytargetmethod(python_function)

    # Treat a wrapped partial function as a special case. For all arguments that
    # were overridden with keywords in the partial:
    #   - remove the corresponding arguments,
    #   - remove the corresponding keywords.
    _, unwrapped = tf_decorator.unwrap(python_function)
    if isinstance(unwrapped, functools.partial):
      # Also consider the Python3 case with kwonlydefaults.
      if fullargspec.defaults or fullargspec.kwonlydefaults:
        new_defaults = fullargspec.defaults
        new_args = fullargspec.args
        if fullargspec.defaults:
          # To be able to canonicalize the function properly, we want to ignore
          # default values that are overridden via a partial kwarg. For example:
          #
          #   def func(a, b, c, d=5, e=7):
          #     return a, b, c, d, e
          #   p_func = tf.function(functools.partial(func, 10, e=9))
          #
          # Here we want to drop from the defaults the parameter `e`. If we
          # forwarded the call to the partial function with a default for `e`
          # we would get an error for passing two values for one parameter.
          #
          # Note that this has a limitation: we can only override parameters at
          # the end of the parameter list.
          #
          # In this case we want to end up with 3 arguments (b, c, d) and 1
          # default value (5). We do this by constructing a mask where 0 stands
          # for a value that was overridden by a partial kwarg. The seemingly
          # complicated logic below does just that - for arguments (b, c, d, e)
          # we would get a mask (1, 1, 1, 0).
          old_args = fullargspec.args
          old_defaults = fullargspec.defaults

          no_default = object()
          num_args_without_defaults = len(old_args) - len(old_defaults)
          left_padding = tuple([no_default] * num_args_without_defaults)

          args_with_defaults = zip(old_args, left_padding + old_defaults)

          # Create a mask where 0 stands for args that had a partial kwarg
          # defined.
          non_keyword_defaults_mask = [
              0 if key in unwrapped.keywords else 1 for key in old_args
          ]
          # Keep only arguments and defaults that were not kwargs of partial.
          new_args_with_defaults = list(
              itertools.compress(args_with_defaults, non_keyword_defaults_mask))
          # Keep all args.
          new_args = [arg for arg, _ in new_args_with_defaults]
          # Keep only real default values.
          new_defaults = [
              default for _, default in new_args_with_defaults
              if default is not no_default
          ]
        fullargspec = tf_inspect.FullArgSpec(
            args=new_args,
            varargs=fullargspec.varargs,
            varkw=fullargspec.varkw,
            defaults=new_defaults,
            kwonlyargs=[],
            kwonlydefaults={},
            annotations=fullargspec.annotations)

    # Get the function's name.  Remove functools.partial wrappers if necessary.
    while isinstance(python_function, functools.partial):
      python_function = python_function.func
    name = getattr(python_function, "__name__", "f")

    return FunctionSpec(
        fullargspec,
        is_method,
        input_signature,
        is_pure=is_pure,
        jit_compile=jit_compile,
        experimental_follow_type_hints=experimental_follow_type_hints,
        name=name)

  def __init__(self,
               fullargspec,
               is_method,
               input_signature,
               is_pure=False,
               experimental_follow_type_hints=False,
               name=None,
               jit_compile=None):
    """Constructs a FunctionSpec describing a python function.

    Args:
      fullargspec: `tf_inspect.FullArgSpec` object describing the function.
      is_method: True if the function is a method.
      input_signature: a signature of the function (None, if variable)
      is_pure: if True all input arguments (including variables and constants)
        will be converted to tensors and no variable changes allowed.
      experimental_follow_type_hints: see `tf.function`.
      name: Name of the function
      jit_compile: see `tf.function`.
    """
    self._fullargspec = fullargspec
    self._is_method = is_method
    self._is_pure = is_pure
    self._jit_compile = jit_compile
    self._experimental_follow_type_hints = experimental_follow_type_hints

    # TODO(edloper): Include name when serializing for SavedModel?
    self._name = name or "f"

    if self._is_method:
      # Remove `self`: default arguments shouldn't be matched to it.
      # TODO(b/127938157): Should this error out if there is no arg to
      # be removed?
      args = fullargspec.args[1:]
    else:
      args = fullargspec.args

    # A cache mapping from argument name to index, for canonicalizing
    # arguments that are called in a keyword-like fashion.
    self._args_to_indices = {arg: i for i, arg in enumerate(args)}
    self._arg_names = args

    # A cache mapping from arg index to default value, for canonicalization.
    default_values = fullargspec.defaults
    offset = len(args) - len(default_values or [])
    self._arg_indices_to_default_values = {
        offset + index: default
        for index, default in enumerate(default_values or [])
    }
    self._arg_indices_no_default_values = set(range(len(args))) - set(
        self._arg_indices_to_default_values)

    _validate_signature(input_signature)
    # NOTE: `_flat_input_signature` is only set when an input_signature is
    # provided; `flat_input_signature` raises AttributeError otherwise.
    if input_signature is None:
      self._input_signature = None
    else:
      self._input_signature = tuple(input_signature)
      self._flat_input_signature = tuple(nest.flatten(input_signature,
                                                      expand_composites=True))
    self.validate_input_signature_with_argspec()

  @property
  def fullargspec(self):
    return self._fullargspec

  @property
  def is_method(self):
    return self._is_method

  @property
  def args_to_indices(self):
    return self._args_to_indices

  @property
  def kwargs_to_include(self):
    # NOTE(review): `self._kwargs_to_include` is never assigned anywhere in
    # this class, so accessing this property raises AttributeError — confirm
    # whether this is dead code that can be removed.
    return self._kwargs_to_include

  @property
  def input_signature(self):
    return self._input_signature

  @property
  def flat_input_signature(self):
    return self._flat_input_signature

  @property
  def is_pure(self):
    return self._is_pure

  @property
  def jit_compile(self):
    return self._jit_compile

  @property
  def arg_names(self):
    return self._arg_names

  @property
  def vararg_name(self):
    return self._fullargspec.varargs

  @property
  def varkw_name(self):
    return self._fullargspec.varkw

  def signature_summary(self, default_values=False):
    """Returns a string summarizing this function's signature.

    Args:
      default_values: If true, then include default values in the signature.

    Returns:
      A `string`.
    """
    args = list(self._arg_names)
    if default_values:
      for (i, default) in self._arg_indices_to_default_values.items():
        args[i] += "={}".format(default)
    if self._fullargspec.kwonlyargs:
      # Keyword-only arguments are rendered after a bare "*" marker, as in a
      # Python def statement.
      args.append("*")
      for arg_name in self._fullargspec.kwonlyargs:
        args.append(arg_name)
        if default_values and arg_name in self._fullargspec.kwonlydefaults:
          args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name])
    return f"{self._name}({', '.join(args)})"

  def validate_input_signature_with_argspec(self):
    """Checks the python_function's args to be valid against input_signature."""
    if self.input_signature is not None:
      arglen = len(self.input_signature)
      arg_names_len = len(self.arg_names)
      defaults = self.fullargspec.defaults or ()
      # An unbound method keeps a leading `self` in arg_names; it does not
      # need to be covered by the input_signature.
      unbound_self_arg = 1 if (not self.is_method and arg_names_len > 0 and
                               self.arg_names[0] == "self") else 0
      if not all(d is BOUND_VALUE for d in defaults):
        default_arg_len = len(defaults)
        required_arg_len = arg_names_len - default_arg_len - unbound_self_arg
        # The input signature must cover all required function arguments.
        if arglen < required_arg_len:
          missing_tensor_specs = self.arg_names[
              arglen:required_arg_len]
          raise TypeError(
              f"The decorated tf.function has {required_arg_len} "
              f"required argument(s), but tf.function was only passed an "
              f"input_signature of length {arglen}. This covers {arglen} "
              f"required argument(s): {self.arg_names[:arglen]}, "
              f"but TensorSpecs are still required for the remaining "
              f"{len(missing_tensor_specs)} argument(s):"
              f" {missing_tensor_specs}.")

  def _convert_annotated_args_to_tensors(self, args, kwargs):
    """Attempts to autobox arguments annotated as tf.Tensor.

    Args:
      args: positional arguments (a tuple).
      kwargs: keyword arguments (a dict).

    Returns:
      A `(args, kwargs)` pair with annotated values converted to tensors.
    """
    if self.input_signature is not None:
      # BUG FIX: this early-out previously returned None, which crashed the
      # caller in `canonicalize_function_inputs` (it unpacks the result into
      # `args, kwargs`) whenever both an input_signature and
      # experimental_follow_type_hints were in use. Return inputs unchanged:
      # the input_signature path performs its own conversion later.
      return args, kwargs

    args = list(args)
    for i, arg in enumerate(args):
      # See
      # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
      if i < len(self._fullargspec.args):
        annotation_key = self._fullargspec.args[i]
      else:
        # Extra positional arguments fall under the *varargs annotation.
        annotation_key = self._fullargspec.varargs
      arg_annotation = self._fullargspec.annotations.get(annotation_key, None)

      # TODO(rahulkamat): Change to TensorLike (here and below)
      if arg_annotation == ops.Tensor:
        args[i] = _to_tensor_or_tensor_spec(arg)

    for kw, v in kwargs.items():
      if kw in self._fullargspec.kwonlyargs or kw in self._fullargspec.args:
        annotation_key = kw
      else:
        # Unknown keywords fall under the **varkw annotation.
        annotation_key = self._fullargspec.varkw
      kwarg_annotation = self._fullargspec.annotations.get(annotation_key, None)
      if kwarg_annotation == ops.Tensor:
        kwargs[kw] = _to_tensor_or_tensor_spec(v)
    return tuple(args), kwargs

  def _validate_inputs(self, flat_inputs):
    """Raises an error if inputs contain illegal values."""
    for inp in flat_inputs:
      # TODO(b/183107079): Allow these once they're handled properly.
      if isinstance(inp, weakref.ref):
        raise ValueError(
            f"weakref input {inp} not supported for function {self._name}")

  def validate_inputs_with_signature(self, args, kwargs):
    """Checks args and kwargs against the specified input_signature."""
    if kwargs:
      raise ValueError("Cannot define a TensorFlow function from a Python "
                       "function with keyword arguments when "
                       "input_signature is provided, got keyword arguments "
                       f"({kwargs}) with input_signature "
                       f"({self.input_signature}).")
    if args:
      # If args are provided, they must match the input signature.
      if not is_same_structure(self.input_signature, args):
        raise ValueError("Structure of Python function inputs does not match "
                         f"input_signature: inputs ({args}), "
                         f"input_signature ({self.input_signature}).")
      flat_inputs = nest.flatten(args, expand_composites=True)
      if any(not isinstance(arg, (ops.Tensor, tensor_spec.DenseSpec,
                                  resource_variable_ops.BaseResourceVariable))
             for arg in flat_inputs):
        raise ValueError("When input_signature is provided, all inputs to "
                         "the Python function must be Tensors, Variables, "
                         "tf.TensorSpec or tf.VariableSpec objects.")
      if any(not spec.is_compatible_with(other)
             for spec, other in zip(self.flat_input_signature, flat_inputs)):
        raise ValueError("Python inputs incompatible with input_signature: "
                         f"inputs ({args}), input_signature "
                         f"({self.input_signature}).")

  def canonicalize_function_inputs(self, args, kwargs):
    """Canonicalizes `args` and `kwargs`.

    Canonicalize the inputs to the Python function using a `FunctionSpec`
    instance. In particular, we parse the varargs and kwargs that the
    original function was called with into a tuple corresponding to the
    Python function's positional (named) arguments and a dictionary
    corresponding to its kwargs.  Missing default arguments are added.

    If this `FunctionSpec` has an input signature, then it is used to convert
    arguments to tensors; otherwise, any inputs containing numpy arrays are
    converted to tensors.

    Args:
      args: The varargs this object was called with.
      kwargs: The keyword args this function was called with.

    Returns:
      A canonicalized ordering of the inputs, represented by a tuple in the
      form (args, kwargs, filtered_flat_args). Here: `args` is a full list of
      bound arguments, `kwargs` contains only true keyword arguments (as
      opposed to named arguments called in a keyword-like fashion), and
      `filtered_flat_args` is the flattened (composite-expanded)
      concatenation of args and kwargs restricted to Tensors and Variables.

    Raises:
      ValueError: If a keyword in `kwargs` cannot be matched with a positional
        argument when an input signature is specified, or when the inputs
        do not conform to the input signature.
    """
    # Defensive copy so mutations below don't leak into the caller's dict.
    kwargs = {key: kwargs[key] for key in kwargs}
    if self._is_pure:
      args, kwargs = _convert_variables_to_tensors(args, kwargs)
    if self._experimental_follow_type_hints:
      args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
    # Pre-calculate to reduce overhead
    arglen = len(args)
    if self._input_signature is not None:
      if arglen > len(self._input_signature):
        raise TypeError(f"{self.signature_summary()} has an input_signature "
                        f"{self._input_signature} which specifies "
                        f"{len(self._input_signature)} positional arguments, "
                        f"but got {arglen}.")
      for arg in six.iterkeys(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is None:
          raise TypeError(f"{self.signature_summary()} got unexpected keyword "
                          f"argument `{arg}`.")
        if index >= len(self._input_signature):
          raise TypeError(
              f"{self.signature_summary()} got keyword argument `{arg}` that "
              "was not included in input_signature.")

    if not kwargs:
      inputs = args
      if self._arg_indices_to_default_values:
        try:
          inputs += tuple(self._arg_indices_to_default_values[i]
                          for i in range(arglen, len(self._arg_names)))
        except KeyError:
          missing_args = [
              self._arg_names[i]
              for i in range(arglen, len(self._arg_names))
              if i not in self._arg_indices_to_default_values
          ]
          raise TypeError(f"{self.signature_summary()} missing required "
                          f"arguments: {', '.join(missing_args)}.")

      if self._fullargspec.kwonlydefaults:
        kwargs.update(self._fullargspec.kwonlydefaults)
    else:
      # Maps from index of arg to its corresponding value, according to `args`
      # and `kwargs`; seeded with the default values for the named args that
      # aren't in `args`.
      arg_indices_to_values = {
          index: default for index, default in six.iteritems(
              self._arg_indices_to_default_values) if index >= arglen
      }
      consumed_args = []
      missing_arg_indices = self._arg_indices_no_default_values - set(
          range(arglen))
      for arg, value in six.iteritems(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is not None:
          if index < arglen:
            raise TypeError(f"{self.signature_summary()} got two values for "
                            f"{arg!r}.")
          arg_indices_to_values[index] = value
          # These arguments in 'kwargs' might also belong to
          # positional arguments
          missing_arg_indices.discard(index)
          consumed_args.append(arg)
      for arg in consumed_args:
        # After this loop, `kwargs` will only contain keyword_only arguments,
        # and all positional_or_keyword arguments have been moved to `inputs`.
        kwargs.pop(arg)
      inputs = args + _deterministic_dict_values(arg_indices_to_values)
      # Exclude positional args with values
      if missing_arg_indices:
        missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
        if len(missing_args) == 1:
          raise TypeError(f"{self.signature_summary()} missing 1 required "
                          f"argument: {missing_args[0]}.")
        else:
          raise TypeError(f"{self.signature_summary()} missing required "
                          f"arguments: {', '.join(missing_args)}.")

      if kwargs and self._input_signature is not None:
        raise TypeError("Keyword arguments are not supported when "
                        "input_signature is provided. Signature: "
                        f"{self.signature_summary()}. Keyword arguments: "
                        f"{kwargs}.")

      if self._fullargspec.kwonlydefaults:
        for (kwarg, default) in self._fullargspec.kwonlydefaults.items():
          kwargs.setdefault(kwarg, default)

    if self._input_signature is None:
      inputs, flat_inputs, filtered_flat_inputs = _convert_numpy_inputs(inputs)
      kwargs, flat_kwargs, filtered_flat_kwargs = _convert_numpy_inputs(kwargs)
      flat_inputs += flat_kwargs
      filtered_flat_inputs += filtered_flat_kwargs
    else:
      inputs, flat_inputs, filtered_flat_inputs = convert_inputs_to_signature(
          inputs, self._input_signature, self._flat_input_signature)

    self._validate_inputs(flat_inputs)

    return inputs, kwargs, filtered_flat_inputs
492
493
494def _validate_signature(signature):
495  """Checks the input_signature to be valid."""
496  if signature is None:
497    return
498
499  if not isinstance(signature, (tuple, list)):
500    raise TypeError("input_signature must be either a tuple or a list, got "
501                    f"{type(signature)}.")
502
503  if any(not isinstance(arg, tensor_spec.DenseSpec)
504         for arg in nest.flatten(signature, expand_composites=True)):
505    bad_args = [arg for arg in nest.flatten(signature, expand_composites=True)
506                if not isinstance(arg, tensor_spec.DenseSpec)]
507    raise TypeError("input_signature must be a possibly nested sequence of "
508                    f"TensorSpec objects, got invalid args {bad_args} with "
509                    f"types {list(six.moves.map(type, bad_args))}.")
510
511
512def _validate_python_function(python_function, input_signature):
513  """Checks the python_function to be valid against the input_signature."""
514  if not callable(python_function):
515    raise TypeError(f"{python_function} is not a callable object.")
516
517  if input_signature is not None:
518    fullargspec = tf_inspect.getfullargspec(python_function)
519    if set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()):
520      nodefault_kwonlyargs = set(fullargspec.kwonlyargs)
521      if fullargspec.kwonlydefaults is not None:
522        nodefault_kwonlyargs -= set(fullargspec.kwonlydefaults)
523      raise ValueError("Cannot build TF function from "
524                       f"{python_function.__name__}: keyword-only arguments "
525                       "must have default values when input_signature is "
526                       "provided. Got keyword-only arguments without default "
527                       f"values: {sorted(nodefault_kwonlyargs)}.")
528
529
def is_same_structure(structure1, structure2, check_values=False):
  """Check two structures for equality, optionally of types and of values."""
  # Structural mismatch short-circuits everything else.
  try:
    nest.assert_same_structure(structure1, structure2, expand_composites=True)
  except (ValueError, TypeError):
    return False

  if not check_values:
    return True

  flat1 = nest.flatten(structure1, expand_composites=True)
  flat2 = nest.flatten(structure2, expand_composites=True)
  # Compare types first so the equality check below can't hit an
  # AttributeError on mismatched leaf types.
  for leaf1, leaf2 in zip(flat1, flat2):
    if type(leaf1) is not type(leaf2):
      return False
  return flat1 == flat2
544
545
def _to_tensor_or_tensor_spec(x):
  """Returns `x` unchanged if it is already a Tensor/TensorSpec, else converts it."""
  if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)):
    return x
  return ops.convert_to_tensor(x)
549
550
551def _deterministic_dict_values(dictionary):
552  return tuple(dictionary[key] for key in sorted(dictionary))
553
554
def _convert_variables_to_tensors(args, kwargs):
  """Converts all positional and keyword argument values to tensors."""
  converted_args = tuple(_to_tensor_or_tensor_spec(value) for value in args)
  converted_kwargs = {
      name: _to_tensor_or_tensor_spec(value) for name, value in kwargs.items()
  }
  return converted_args, converted_kwargs
560
561
def _convert_numpy_inputs(inputs):
  """Converts numpy array inputs to tensors.

  Args:
    inputs: a possibly nested structure of Python values.

  Returns:
    A tuple `(inputs, flat_inputs, filtered_flat_inputs)`: the structure with
    ndarray-like leaves replaced by constant Tensors, its composite-expanded
    flattening, and the subset of flattened leaves that are Tensors or
    resource variables.
  """
  # We assume that any CompositeTensors have already converted their components
  # from numpy arrays to Tensors, so we don't need to expand composites here for
  # the numpy array conversion. Instead, we do so because the flattened inputs
  # are eventually passed to ConcreteFunction()._call_flat, which requires
  # expanded composites.
  flat_inputs = nest.flatten(inputs, expand_composites=True)

  # Check for NumPy arrays in arguments and convert them to Tensors.
  # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
  # finding a way to store them directly in the cache key (currently not
  # possible since ndarrays are not hashable).
  need_packing = False
  filtered_flat_inputs = []
  for index, value in enumerate(flat_inputs):
    if isinstance(value,
                  (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
      # Already a Tensor/Variable: keep as-is, include in the filtered list.
      filtered_flat_inputs.append(value)
    elif hasattr(value, "__array__") and not (
        hasattr(value, "_should_act_as_resource_variable") or
        isinstance(value, (np.str_, type, composite_tensor.CompositeTensor))):
      # This case is equivalent to _is_ndarray(value) == True
      a = value.__array__()
      if not isinstance(a, np.ndarray):
        raise TypeError(f"The output of __array__ must be an np.ndarray, "
                        f"got {type(a)} from {value}.")
      # Replace the leaf in place; the flat list is repacked below only if at
      # least one conversion actually happened.
      flat_inputs[index] = constant_op.constant(a)
      filtered_flat_inputs.append(flat_inputs[index])
      need_packing = True
  if need_packing:
    return (nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs,
        expand_composites=True), flat_inputs, filtered_flat_inputs)
  else:
    # Nothing converted: return the original structure untouched.
    return inputs, flat_inputs, filtered_flat_inputs
598
599
def convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
  """Converts inputs to pass into a function with an explicit signature.

  Args:
    inputs: the positional inputs to canonicalize.
    input_signature: the (possibly nested) signature structure of TensorSpecs.
    flat_input_signature: the pre-flattened `input_signature`.

  Returns:
    A tuple `(inputs, flat_inputs, filtered_flat_inputs)`: the inputs repacked
    to match the signature structure (with values converted to Tensors where
    needed), their composite-expanded flattening, and the subset of flattened
    leaves that are Tensors or resource variables.

  Raises:
    ValueError: if the inputs don't match the structure of, can't be converted
      to, or aren't compatible with the input_signature.
  """

  def format_error_message(inputs, input_signature):
    # Renders both structures, one element per line, for error messages.
    return ("  inputs: (\n" + "    " + ",\n    ".join(str(i) for i in inputs) +
            ")\n" + "  input_signature: (\n" + "    " +
            ",\n    ".join(str(i) for i in input_signature) + ")")

  try:
    flatten_inputs = nest.flatten_up_to(
        input_signature,
        inputs[:len(input_signature)],
        expand_composites=True,
        check_types=False)  # lists are converted to tuples for `tf.data`.
  except ValueError:
    raise ValueError("Structure of Python function inputs does not match "
                     "input_signature:\n"
                     f"{format_error_message(inputs, input_signature)}.")

  # Convert any non-Tensor value whose spec is a TensorSpec, tracking whether
  # a repack is needed at the end.
  need_packing = False
  for index, (value, spec) in enumerate(zip(flatten_inputs,
                                            flat_input_signature)):
    if (isinstance(spec, tensor_spec.TensorSpec) and
        not isinstance(value, tensor_spec.TensorSpec) and
        not _pywrap_utils.IsTensor(value)):
      try:
        flatten_inputs[index] = ops.convert_to_tensor(
            value, dtype_hint=spec.dtype)
        need_packing = True
      except ValueError:
        raise ValueError("When input_signature is provided, all inputs to "
                         "the Python function must be convertible to "
                         "tensors:\n"
                         f"{format_error_message(inputs, input_signature)}.")

  # Compatibility is checked after conversion so dtype/shape coercions above
  # are taken into account.
  if any(not spec.is_compatible_with(other) for spec, other in zip(
      flat_input_signature,
      flatten_inputs)):
    raise ValueError("Python inputs incompatible with input_signature:\n"
                     f"{format_error_message(inputs, input_signature)}.")

  if need_packing:
    inputs = nest.pack_sequence_as(
        structure=input_signature,
        flat_sequence=flatten_inputs,
        expand_composites=True)

  flat_inputs = nest.flatten(inputs, expand_composites=True)

  return (inputs, flat_inputs, [
      t for t in flat_inputs
      if isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
  ])
653