# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines an input type specification for tf.function."""

import functools
import itertools
import weakref

import numpy as np
import six

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

# Sentinel value used by ConcreteFunction's structured signature to
# indicate that a non-tensor parameter should use the value that was
# specified when the concrete function was created.
BOUND_VALUE = object()


# TODO(b/214462107): Clean up and migrate to core/function when unblocked.
class FunctionSpec(object):
  """Specification of how to bind arguments to a function."""

  @classmethod
  def from_function_and_signature(cls, python_function,
                                  input_signature,
                                  is_pure=False,
                                  experimental_follow_type_hints=False,
                                  jit_compile=None):
    """Creates a FunctionSpec instance given a python function and signature.

    Args:
      python_function: a function to inspect
      input_signature: a signature of the function (None, if variable)
      is_pure: if True all input arguments (including variables and constants)
        will be converted to tensors and no variable changes allowed.
      experimental_follow_type_hints: see `tf.function`
      jit_compile: see `tf.function`

    Returns:
      instance of FunctionSpec

    Raises:
      TypeError: If `python_function` is not callable or `input_signature` is
        not a (possibly nested) list/tuple of `DenseSpec` objects.
      ValueError: If `input_signature` is given and `python_function` has
        keyword-only arguments without default values.
    """
    _validate_signature(input_signature)
    _validate_python_function(python_function, input_signature)

    fullargspec = tf_inspect.getfullargspec(python_function)
    # Checks if the `fullargspec` contains self or cls as its first argument.
    is_method = tf_inspect.isanytargetmethod(python_function)

    # Treat a wrapped partial function as a special case. For all arguments that
    # were overridden with keywords in the partial:
    #   - remove the corresponding arguments,
    #   - remove the corresponding keywords.
    _, unwrapped = tf_decorator.unwrap(python_function)
    if isinstance(unwrapped, functools.partial):
      # Also consider the Python3 case with kwonlydefaults.
      if fullargspec.defaults or fullargspec.kwonlydefaults:
        new_defaults = fullargspec.defaults
        new_args = fullargspec.args
        if fullargspec.defaults:
          # To be able to canonicalize the function properly, we want to ignore
          # default values that are overridden via a partial kwarg. For example:
          #
          #   def func(a, b, c, d=5, e=7):
          #     return a, b, c, d, e
          #   p_func = tf.function(functools.partial(func, 10, e=9))
          #
          # Here we want to drop from the defaults the parameter `e`. If we
          # forwarded the call to the partial function with a default for `e`
          # we would get an error for passing two values for one parameter.
          #
          # Note that this has a limitation: we can only override parameters at
          # the end of the parameter list.
          #
          # In this case we want to end up with 3 arguments (b, c, d) and 1
          # default value (5). We do this by constructing a mask where 0 stands
          # for a value that was overridden by a partial kwarg. The seemingly
          # complicated logic below does just that - for arguments (b, c, d, e)
          # we would get a mask (1, 1, 1, 0).
          old_args = fullargspec.args
          old_defaults = fullargspec.defaults

          no_default = object()
          num_args_without_defaults = len(old_args) - len(old_defaults)
          left_padding = tuple([no_default] * num_args_without_defaults)

          args_with_defaults = zip(old_args, left_padding + old_defaults)

          # Create a mask where 0 stands for args that had a partial kwarg
          # defined.
          non_keyword_defaults_mask = [
              0 if key in unwrapped.keywords else 1 for key in old_args
          ]
          # Keep only arguments and defaults that were not kwargs of partial.
          new_args_with_defaults = list(
              itertools.compress(args_with_defaults, non_keyword_defaults_mask))
          # Keep all args.
          new_args = [arg for arg, _ in new_args_with_defaults]
          # Keep only real default values.
          new_defaults = [
              default for _, default in new_args_with_defaults
              if default is not no_default
          ]
        fullargspec = tf_inspect.FullArgSpec(
            args=new_args,
            varargs=fullargspec.varargs,
            varkw=fullargspec.varkw,
            defaults=new_defaults,
            kwonlyargs=[],
            kwonlydefaults={},
            annotations=fullargspec.annotations)

    # Get the function's name.  Remove functools.partial wrappers if necessary.
    while isinstance(python_function, functools.partial):
      python_function = python_function.func
    name = getattr(python_function, "__name__", "f")

    return FunctionSpec(
        fullargspec,
        is_method,
        input_signature,
        is_pure=is_pure,
        jit_compile=jit_compile,
        experimental_follow_type_hints=experimental_follow_type_hints,
        name=name)

  def __init__(self,
               fullargspec,
               is_method,
               input_signature,
               is_pure=False,
               experimental_follow_type_hints=False,
               name=None,
               jit_compile=None):
    """Constructs a FunctionSpec describing a python function.

    Args:
      fullargspec: `tf_inspect.FullArgSpec` object describing the function.
      is_method: True if the function is a method.
      input_signature: a signature of the function (None, if variable)
      is_pure: if True all input arguments (including variables and constants)
        will be converted to tensors and no variable changes allowed.
      experimental_follow_type_hints: see `tf.function`.
      name: Name of the function
      jit_compile: see `tf.function`.

    Raises:
      TypeError: If `input_signature` is malformed, or if it does not cover
        the required positional arguments (see
        `validate_input_signature_with_argspec`).
    """
    self._fullargspec = fullargspec
    self._is_method = is_method
    self._is_pure = is_pure
    self._jit_compile = jit_compile
    self._experimental_follow_type_hints = experimental_follow_type_hints

    # TODO(edloper): Include name when serializing for SavedModel?
    self._name = name or "f"

    if self._is_method:
      # Remove `self`: default arguments shouldn't be matched to it.
      # TODO(b/127938157): Should this error out if there is no arg to
      # be removed?
      args = fullargspec.args[1:]
    else:
      args = fullargspec.args

    # A cache mapping from argument name to index, for canonicalizing
    # arguments that are called in a keyword-like fashion.
    self._args_to_indices = {arg: i for i, arg in enumerate(args)}
    self._arg_names = args

    # A cache mapping from arg index to default value, for canonicalization.
    # Defaults always belong to the trailing positional arguments, hence the
    # offset.
    default_values = fullargspec.defaults
    offset = len(args) - len(default_values or [])
    self._arg_indices_to_default_values = {
        offset + index: default
        for index, default in enumerate(default_values or [])
    }
    self._arg_indices_no_default_values = set(range(len(args))) - set(
        self._arg_indices_to_default_values)

    _validate_signature(input_signature)
    if input_signature is None:
      # NOTE(review): `_flat_input_signature` is intentionally left unset
      # here, exactly as before; `flat_input_signature` is only usable when a
      # signature was provided.
      self._input_signature = None
    else:
      self._input_signature = tuple(input_signature)
      self._flat_input_signature = tuple(nest.flatten(input_signature,
                                                      expand_composites=True))
    self.validate_input_signature_with_argspec()

  @property
  def fullargspec(self):
    return self._fullargspec

  @property
  def is_method(self):
    return self._is_method

  @property
  def args_to_indices(self):
    return self._args_to_indices

  @property
  def kwargs_to_include(self):
    # NOTE(review): `_kwargs_to_include` is never assigned anywhere in this
    # class, so this raises AttributeError unless something outside this file
    # sets it. Confirm whether this property is still needed.
    return self._kwargs_to_include

  @property
  def input_signature(self):
    return self._input_signature

  @property
  def flat_input_signature(self):
    return self._flat_input_signature

  @property
  def is_pure(self):
    return self._is_pure

  @property
  def jit_compile(self):
    return self._jit_compile

  @property
  def arg_names(self):
    return self._arg_names

  @property
  def vararg_name(self):
    return self._fullargspec.varargs

  @property
  def varkw_name(self):
    return self._fullargspec.varkw

  def signature_summary(self, default_values=False):
    """Returns a string summarizing this function's signature.

    Args:
      default_values: If true, then include default values in the signature.

    Returns:
      A `string`.
    """
    args = list(self._arg_names)
    if default_values:
      for (i, default) in self._arg_indices_to_default_values.items():
        args[i] += "={}".format(default)
    if self._fullargspec.kwonlyargs:
      # Mirror Python syntax: a bare `*` separates keyword-only arguments.
      args.append("*")
      for arg_name in self._fullargspec.kwonlyargs:
        args.append(arg_name)
        if default_values and arg_name in self._fullargspec.kwonlydefaults:
          args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name])
    return f"{self._name}({', '.join(args)})"

  def validate_input_signature_with_argspec(self):
    """Checks the python_function's args to be valid against input_signature.

    Raises:
      TypeError: If the input signature is shorter than the number of required
        (default-less) positional arguments.
    """
    if self.input_signature is not None:
      arglen = len(self.input_signature)
      arg_names_len = len(self.arg_names)
      defaults = self.fullargspec.defaults or ()
      # A function whose first argument is literally named `self` but which is
      # not flagged as a method still has that argument excluded from the
      # required count.
      unbound_self_arg = 1 if (not self.is_method and arg_names_len > 0 and
                               self.arg_names[0] == "self") else 0
      # NOTE(review): `all()` over an empty `defaults` is True, so this check
      # is also skipped when the function has no default values at all.
      # Behavior kept as-is; confirm that this is intended.
      if not all(d is BOUND_VALUE for d in defaults):
        default_arg_len = len(defaults)
        required_arg_len = arg_names_len - default_arg_len - unbound_self_arg
        # The input signature must cover all required function arguments.
        if arglen < required_arg_len:
          missing_tensor_specs = self.arg_names[
              arglen:required_arg_len]
          raise TypeError(
              f"The decorated tf.function has {required_arg_len} "
              f"required argument(s), but tf.function was only passed an "
              f"input_signature of length {arglen}. This covers {arglen} "
              f"required argument(s): {self.arg_names[:arglen]}, "
              f"but TensorSpecs are still required for the remaining "
              f"{len(missing_tensor_specs)} argument(s):"
              f" {missing_tensor_specs}.")

  def _convert_annotated_args_to_tensors(self, args, kwargs):
    """Attempts to autobox arguments annotated as tf.Tensor.

    Args:
      args: Positional arguments the function was called with.
      kwargs: Dict of keyword arguments; converted values are written back
        into this dict.

    Returns:
      A `(args, kwargs)` pair. When an input signature is present the inputs
      are returned unchanged; otherwise `args` is a new tuple in which every
      argument annotated as `ops.Tensor` has been converted.
    """
    if self.input_signature is not None:
      # Callers unpack the result as `args, kwargs = ...`, so the inputs must
      # be handed back untouched rather than returning None.
      return args, kwargs

    args = list(args)
    for i, arg in enumerate(args):
      # See
      # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
      if i < len(self._fullargspec.args):
        annotation_key = self._fullargspec.args[i]
      else:
        annotation_key = self._fullargspec.varargs
      arg_annotation = self._fullargspec.annotations.get(annotation_key, None)

      # TODO(rahulkamat): Change to TensorLike (here and below)
      if arg_annotation == ops.Tensor:
        args[i] = _to_tensor_or_tensor_spec(arg)

    for kw, v in kwargs.items():
      if kw in self._fullargspec.kwonlyargs or kw in self._fullargspec.args:
        annotation_key = kw
      else:
        annotation_key = self._fullargspec.varkw
      kwarg_annotation = self._fullargspec.annotations.get(annotation_key, None)
      if kwarg_annotation == ops.Tensor:
        # Only existing keys are reassigned, so mutating during iteration is
        # safe here.
        kwargs[kw] = _to_tensor_or_tensor_spec(v)
    return tuple(args), kwargs

  def _validate_inputs(self, flat_inputs):
    """Raises an error if inputs contain illegal values."""
    for inp in flat_inputs:
      # TODO(b/183107079): Allow these once they're handled properly.
      if isinstance(inp, weakref.ref):
        raise ValueError(
            f"weakref input {inp} not supported for function {self._name}")

  def validate_inputs_with_signature(self, args, kwargs):
    """Checks args and kwargs against the specified input_signature.

    Args:
      args: Positional arguments to check.
      kwargs: Keyword arguments; must be empty.

    Raises:
      ValueError: If `kwargs` is non-empty, or if `args` does not match the
        input signature in structure, element type, or spec compatibility.
    """
    if kwargs:
      raise ValueError("Cannot define a TensorFlow function from a Python "
                       "function with keyword arguments when "
                       "input_signature is provided, got keyword arguments "
                       f"({kwargs}) with input_signature "
                       f"({self.input_signature}).")
    if args:
      # If args are provided, they must match the input signature.
      if not is_same_structure(self.input_signature, args):
        raise ValueError("Structure of Python function inputs does not match "
                         f"input_signature: inputs ({args}), "
                         f"input_signature ({self.input_signature}).")
      flat_inputs = nest.flatten(args, expand_composites=True)
      if any(not isinstance(arg, (ops.Tensor, tensor_spec.DenseSpec,
                                  resource_variable_ops.BaseResourceVariable))
             for arg in flat_inputs):
        raise ValueError("When input_signature is provided, all inputs to "
                         "the Python function must be Tensors, Variables, "
                         "tf.TensorSpec or tf.VariableSpec objects.")
      if any(not spec.is_compatible_with(other)
             for spec, other in zip(self.flat_input_signature, flat_inputs)):
        raise ValueError("Python inputs incompatible with input_signature: "
                         f"inputs ({args}), input_signature "
                         f"({self.input_signature}).")

  def canonicalize_function_inputs(self, args, kwargs):
    """Canonicalizes `args` and `kwargs`.

    Canonicalize the inputs to the Python function using a `FunctionSpec`
    instance. In particular, we parse the varargs and kwargs that the
    original function was called with into a tuple corresponding to the
    Python function's positional (named) arguments and a dictionary
    corresponding to its kwargs.  Missing default arguments are added.

    If this `FunctionSpec` has an input signature, then it is used to convert
    arguments to tensors; otherwise, any inputs containing numpy arrays are
    converted to tensors.

    Args:
      args: The varargs this object was called with.
      kwargs: The keyword args this function was called with.

    Returns:
      A tuple `(args, kwargs, filtered_flat_args)`. Here: `args` is a full
      list of bound arguments, `kwargs` contains only true keyword arguments,
      as opposed to named arguments called in a keyword-like fashion, and
      `filtered_flat_args` is the flattened list of inputs restricted to
      Tensors and Variables.

    Raises:
      TypeError: If too many positional arguments are given for the input
        signature, if a keyword in `kwargs` cannot be matched with a
        positional argument covered by the input signature, if an argument is
        given twice, or if required arguments are missing.
      ValueError: If the inputs do not conform to the input signature, or if
        they contain unsupported values such as `weakref.ref` objects.
    """
    # Work on a shallow copy so the caller's dict is never mutated.
    kwargs = {key: kwargs[key] for key in kwargs}
    if self._is_pure:
      args, kwargs = _convert_variables_to_tensors(args, kwargs)
    if self._experimental_follow_type_hints:
      args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
    # Pre-calculate to reduce overhead
    arglen = len(args)
    if self._input_signature is not None:
      if arglen > len(self._input_signature):
        raise TypeError(f"{self.signature_summary()} has an input_signature "
                        f"{self._input_signature} which specifies "
                        f"{len(self._input_signature)} positional arguments, "
                        f"but got {arglen}.")
      for arg in six.iterkeys(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is None:
          raise TypeError(f"{self.signature_summary()} got unexpected keyword "
                          f"argument `{arg}`.")
        if index >= len(self._input_signature):
          raise TypeError(
              f"{self.signature_summary()} got keyword argument `{arg}` that "
              "was not included in input_signature.")

    if not kwargs:
      inputs = args
      if self._arg_indices_to_default_values:
        try:
          inputs += tuple(self._arg_indices_to_default_values[i]
                          for i in range(arglen, len(self._arg_names)))
        except KeyError:
          missing_args = [
              self._arg_names[i]
              for i in range(arglen, len(self._arg_names))
              if i not in self._arg_indices_to_default_values
          ]
          raise TypeError(f"{self.signature_summary()} missing required "
                          f"arguments: {', '.join(missing_args)}.")

      if self._fullargspec.kwonlydefaults:
        kwargs.update(self._fullargspec.kwonlydefaults)
    else:
      # Maps from index of arg to its corresponding value, according to `args`
      # and `kwargs`; seeded with the default values for the named args that
      # aren't in `args`.
      arg_indices_to_values = {
          index: default for index, default in six.iteritems(
              self._arg_indices_to_default_values) if index >= arglen
      }
      consumed_args = []
      missing_arg_indices = self._arg_indices_no_default_values - set(
          range(arglen))
      for arg, value in six.iteritems(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is not None:
          if index < arglen:
            raise TypeError(f"{self.signature_summary()} got two values for "
                            f"{arg!r}.")
          arg_indices_to_values[index] = value
          # These arguments in 'kwargs' might also belong to
          # positional arguments
          missing_arg_indices.discard(index)
          consumed_args.append(arg)
      for arg in consumed_args:
        # After this loop, `kwargs` will only contain keyword_only arguments,
        # and all positional_or_keyword arguments have been moved to `inputs`.
        kwargs.pop(arg)
      inputs = args + _deterministic_dict_values(arg_indices_to_values)
      # Exclude positional args with values
      if missing_arg_indices:
        missing_args = [self._arg_names[i] for i in sorted(missing_arg_indices)]
        if len(missing_args) == 1:
          raise TypeError(f"{self.signature_summary()} missing 1 required "
                          f"argument: {missing_args[0]}.")
        else:
          raise TypeError(f"{self.signature_summary()} missing required "
                          f"arguments: {', '.join(missing_args)}.")

      if kwargs and self._input_signature is not None:
        raise TypeError("Keyword arguments are not supported when "
                        "input_signature is provided. Signature: "
                        f"{self.signature_summary()}. Keyword arguments: "
                        f"{kwargs}.")

      if self._fullargspec.kwonlydefaults:
        for (kwarg, default) in self._fullargspec.kwonlydefaults.items():
          kwargs.setdefault(kwarg, default)

    if self._input_signature is None:
      inputs, flat_inputs, filtered_flat_inputs = _convert_numpy_inputs(inputs)
      kwargs, flat_kwargs, filtered_flat_kwargs = _convert_numpy_inputs(kwargs)
      flat_inputs += flat_kwargs
      filtered_flat_inputs += filtered_flat_kwargs
    else:
      inputs, flat_inputs, filtered_flat_inputs = convert_inputs_to_signature(
          inputs, self._input_signature, self._flat_input_signature)

    self._validate_inputs(flat_inputs)

    return inputs, kwargs, filtered_flat_inputs


def _validate_signature(signature):
  """Checks the input_signature to be valid.

  Args:
    signature: `None`, or a possibly nested tuple/list of
      `tensor_spec.DenseSpec` objects.

  Raises:
    TypeError: If `signature` is neither a tuple nor a list, or if any
      flattened element is not a `tensor_spec.DenseSpec`.
  """
  if signature is None:
    return

  if not isinstance(signature, (tuple, list)):
    raise TypeError("input_signature must be either a tuple or a list, got "
                    f"{type(signature)}.")

  # Flatten once and collect every offending element for the error message.
  bad_args = [arg for arg in nest.flatten(signature, expand_composites=True)
              if not isinstance(arg, tensor_spec.DenseSpec)]
  if bad_args:
    raise TypeError("input_signature must be a possibly nested sequence of "
                    f"TensorSpec objects, got invalid args {bad_args} with "
                    f"types {list(six.moves.map(type, bad_args))}.")


def _validate_python_function(python_function, input_signature):
  """Checks the python_function to be valid against the input_signature.

  Args:
    python_function: The object to validate; must be callable.
    input_signature: `None`, or the input signature the function is paired
      with.

  Raises:
    TypeError: If `python_function` is not callable.
    ValueError: If `input_signature` is given and `python_function` has
      keyword-only arguments without default values.
  """
  if not callable(python_function):
    raise TypeError(f"{python_function} is not a callable object.")

  if input_signature is not None:
    fullargspec = tf_inspect.getfullargspec(python_function)
    nodefault_kwonlyargs = set(fullargspec.kwonlyargs) - set(
        fullargspec.kwonlydefaults or ())
    if nodefault_kwonlyargs:
      # Callables such as functools.partial objects have no `__name__`; fall
      # back to the object itself so the intended ValueError is still raised
      # instead of an AttributeError.
      function_name = getattr(python_function, "__name__", python_function)
      raise ValueError("Cannot build TF function from "
                       f"{function_name}: keyword-only arguments "
                       "must have default values when input_signature is "
                       "provided. Got keyword-only arguments without default "
                       f"values: {sorted(nodefault_kwonlyargs)}.")


def is_same_structure(structure1, structure2, check_values=False):
  """Check two structures for equality, optionally of types and of values.

  Args:
    structure1: First nested structure.
    structure2: Second nested structure.
    check_values: If True, additionally require the flattened leaves to have
      identical types and compare equal.

  Returns:
    True if the structures match (and, when `check_values` is set, the leaf
    types and values match too); False otherwise.
  """
  try:
    nest.assert_same_structure(structure1, structure2, expand_composites=True)
  except (ValueError, TypeError):
    return False
  if check_values:
    flattened1 = nest.flatten(structure1, expand_composites=True)
    flattened2 = nest.flatten(structure2, expand_composites=True)
    # First check the types to avoid AttributeErrors.
    if any(type(f1) is not type(f2) for f1, f2 in zip(flattened1, flattened2)):
      return False
    return flattened1 == flattened2
  return True


def _to_tensor_or_tensor_spec(x):
  """Returns `x` if it is a Tensor or TensorSpec, else converts it to a tensor."""
  return (x if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) else
          ops.convert_to_tensor(x))


def _deterministic_dict_values(dictionary):
  """Returns the dict's values as a tuple, ordered by sorted key."""
  return tuple(dictionary[key] for key in sorted(dictionary))


def _convert_variables_to_tensors(args, kwargs):
  """Applies `_to_tensor_or_tensor_spec` to every positional and keyword arg.

  Args:
    args: Iterable of positional arguments.
    kwargs: Dict of keyword arguments.

  Returns:
    A `(args, kwargs)` pair: a new tuple and a new dict holding the converted
    values.
  """
  args = [_to_tensor_or_tensor_spec(x) for x in args]
  kwargs = {kw: _to_tensor_or_tensor_spec(x)
            for kw, x in kwargs.items()}
  return tuple(args), kwargs


def _convert_numpy_inputs(inputs):
  """Converts numpy array inputs to tensors.

  Args:
    inputs: A nested structure of function inputs.

  Returns:
    A tuple `(inputs, flat_inputs, filtered_flat_inputs)`: the structure
    (repacked only if something was converted), its flattened list, and the
    sublist of flattened values that are Tensors or resource variables
    (including freshly converted arrays).

  Raises:
    TypeError: If an input's `__array__()` does not return an `np.ndarray`.
  """
  # We assume that any CompositeTensors have already converted their components
  # from numpy arrays to Tensors, so we don't need to expand composites here for
  # the numpy array conversion. Instead, we do so because the flattened inputs
  # are eventually passed to ConcreteFunction()._call_flat, which requires
  # expanded composites.
  flat_inputs = nest.flatten(inputs, expand_composites=True)

  # Check for NumPy arrays in arguments and convert them to Tensors.
  # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
  # finding a way to store them directly in the cache key (currently not
  # possible since ndarrays are not hashable).
  need_packing = False
  filtered_flat_inputs = []
  for index, value in enumerate(flat_inputs):
    if isinstance(value,
                  (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
      filtered_flat_inputs.append(value)
    elif hasattr(value, "__array__") and not (
        hasattr(value, "_should_act_as_resource_variable") or
        isinstance(value, (np.str_, type, composite_tensor.CompositeTensor))):
      # This case is equivalent to _is_ndarray(value) == True
      a = value.__array__()
      if not isinstance(a, np.ndarray):
        raise TypeError(f"The output of __array__ must be an np.ndarray, "
                        f"got {type(a)} from {value}.")
      flat_inputs[index] = constant_op.constant(a)
      filtered_flat_inputs.append(flat_inputs[index])
      need_packing = True
  if need_packing:
    return (nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs,
        expand_composites=True), flat_inputs, filtered_flat_inputs)
  else:
    return inputs, flat_inputs, filtered_flat_inputs


def convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
  """Converts inputs to pass into a function with an explicit signature.

  Args:
    inputs: Sequence of positional inputs; only the first
      `len(input_signature)` entries are matched against the signature.
    input_signature: The (nested) input signature.
    flat_input_signature: `input_signature` flattened with composites
      expanded.

  Returns:
    A tuple `(inputs, flat_inputs, filtered_flat_inputs)`: the inputs
    (repacked in the signature's structure if anything was converted), their
    flattened list, and the sublist that are Tensors or resource variables.

  Raises:
    ValueError: If the inputs' structure does not match the signature, a value
      cannot be converted to a tensor, or a value is incompatible with its
      spec.
  """

  def format_error_message(inputs, input_signature):
    return (" inputs: (\n" + " " + ",\n ".join(str(i) for i in inputs) +
            ")\n" + " input_signature: (\n" + " " +
            ",\n ".join(str(i) for i in input_signature) + ")")

  try:
    flatten_inputs = nest.flatten_up_to(
        input_signature,
        inputs[:len(input_signature)],
        expand_composites=True,
        check_types=False)  # lists are converted to tuples for `tf.data`.
  except ValueError:
    raise ValueError("Structure of Python function inputs does not match "
                     "input_signature:\n"
                     f"{format_error_message(inputs, input_signature)}.")

  need_packing = False
  for index, (value, spec) in enumerate(zip(flatten_inputs,
                                            flat_input_signature)):
    # Only plain Python/numpy values need conversion; tensors and specs are
    # passed through untouched.
    if (isinstance(spec, tensor_spec.TensorSpec) and
        not isinstance(value, tensor_spec.TensorSpec) and
        not _pywrap_utils.IsTensor(value)):
      try:
        flatten_inputs[index] = ops.convert_to_tensor(
            value, dtype_hint=spec.dtype)
        need_packing = True
      except ValueError:
        raise ValueError("When input_signature is provided, all inputs to "
                         "the Python function must be convertible to "
                         "tensors:\n"
                         f"{format_error_message(inputs, input_signature)}.")

  if any(not spec.is_compatible_with(other) for spec, other in zip(
      flat_input_signature,
      flatten_inputs)):
    raise ValueError("Python inputs incompatible with input_signature:\n"
                     f"{format_error_message(inputs, input_signature)}.")

  if need_packing:
    inputs = nest.pack_sequence_as(
        structure=input_signature,
        flat_sequence=flatten_inputs,
        expand_composites=True)

  flat_inputs = nest.flatten(inputs, expand_composites=True)

  return (inputs, flat_inputs, [
      t for t in flat_inputs
      if isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
  ])