# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras Input Tensor used to track functional API Topology."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec as type_spec_module
from tensorflow.python.keras.utils import object_identity
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_operators  # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest

# pylint: disable=g-classes-have-attributes

# Global on/off switch for KerasTensors; read by `keras_tensors_enabled()`.
_KERAS_TENSORS_ENABLED = True


def _set_keras_tensors_enabled(enabled):
  """Store `enabled` as the new value of the module-level KerasTensors switch."""
  global _KERAS_TENSORS_ENABLED
  _KERAS_TENSORS_ENABLED = enabled


def enable_keras_tensors():
  """Turn on the use of KerasTensors in Keras's functional API."""
  _set_keras_tensors_enabled(True)


def disable_keras_tensors():
  """Turn off the use of KerasTensors in Keras's functional API."""
  _set_keras_tensors_enabled(False)
def keras_tensors_enabled():
  """Return a bool specifying if KerasTensors are enabled.

  True only when the module-level switch is on (see `enable_keras_tensors` /
  `disable_keras_tensors`) and `ops.executing_eagerly_outside_functions()`
  also returns True.
  """
  return _KERAS_TENSORS_ENABLED and ops.executing_eagerly_outside_functions()


# Tensorflow tensors have a maximum rank of 254
# (See `MaxDimensions()` in //tensorflow/core/framework/tensor_shape.h )
# So we do not try to infer values for int32 tensors larger than this,
# As they cannot represent shapes.
_MAX_TENSOR_RANK = 254


class KerasTensor(object):
  """A representation of a Keras in/output during Functional API construction.

  `KerasTensor`s are tensor-like objects that represent the symbolic inputs
  and outputs of Keras layers during Functional model construction. They are
  comprised of the `tf.TypeSpec` of the (Composite)Tensor that will be
  consumed/produced in the corresponding location of the Functional model.

  KerasTensors are intended as a private API, so users should never need to
  directly instantiate `KerasTensor`s.

  **Building Functional Models with KerasTensors**
  `tf.keras.Input` produces `KerasTensor`s that represent the symbolic inputs
  to your model.

  Passing a `KerasTensor` to a `tf.keras.Layer` `__call__` lets the layer know
  that you are building a Functional model. The layer __call__ will
  infer the output signature and return `KerasTensor`s with `tf.TypeSpec`s
  corresponding to the symbolic outputs of that layer call. These output
  `KerasTensor`s will have all of the internal KerasHistory metadata attached
  to them that Keras needs to construct a Functional Model.

  Currently, layers infer the output signature by:
    * creating a scratch `FuncGraph`
    * making placeholders in the scratch graph that match the input typespecs
    * Calling `layer.call` on these placeholders
    * extracting the signatures of the outputs before clearing the scratch graph

  (Note: names assigned to KerasTensors by this process are not guaranteed to
  be unique, and are subject to implementation details).

  `tf.nest` methods are used to ensure all of the inputs/output data
  structures get maintained, with elements swapped between KerasTensors and
  placeholders.

  In rare cases (such as when directly manipulating shapes using Keras layers),
  the layer may be able to partially infer the value of the output in addition
  to just inferring the signature.
  When this happens, the returned KerasTensor will also contain the inferred
  value information. Follow-on layers can use this information
  during their own output signature inference.
  E.g. if one layer produces a symbolic `KerasTensor` that the next layer uses
  as the shape of its outputs, partially knowing the value helps infer the
  output shape.

  **Automatically converting TF APIs to layers**:
  If you pass a `KerasTensor` to a TF API that supports dispatching,
  Keras will automatically turn that API call into a lambda
  layer in the Functional model, and return KerasTensors representing the
  symbolic outputs.

  Most TF APIs that take only tensors as input and produce output tensors
  will support dispatching.

  Calling a `tf.function` does not support dispatching, so you cannot pass
  `KerasTensor`s as inputs to a `tf.function`.

  Higher-order APIs that take methods which produce tensors (e.g.
  `tf.while_loop`, `tf.map_fn`, `tf.cond`) also do not currently support
  dispatching. So, you cannot directly pass KerasTensors as inputs to these
  APIs either. If you want to use these APIs inside of a Functional model, you
  must put them inside of a custom layer.

  Args:
    type_spec: The `tf.TypeSpec` for the symbolic input created by
      `tf.keras.Input`, or symbolically inferred for the output
      during a symbolic layer `__call__`.
    inferred_value: (Optional) a non-symbolic static value, possibly partially
      specified, that could be symbolically inferred for the outputs during
      a symbolic layer `__call__`. This will generally only happen when
      grabbing and manipulating `tf.int32` shapes directly as tensors.
      Statically inferring values in this way and storing them in the
      KerasTensor allows follow-on layers to infer output signatures
      more effectively. (e.g. when using a symbolic shape tensor to later
      construct a tensor with that shape).
    name: (optional) string name for this KerasTensor. Names automatically
      generated by symbolic layer `__call__`s are not guaranteed to be unique,
      and are subject to implementation details.
  """

  def __init__(self, type_spec, inferred_value=None, name=None):
    """Constructs a KerasTensor.

    Raises:
      ValueError: If `type_spec` is not a `tf.TypeSpec`.
    """
    if not isinstance(type_spec, type_spec_module.TypeSpec):
      raise ValueError('KerasTensors must be constructed with a `tf.TypeSpec`.')

    self._type_spec = type_spec
    self._inferred_value = inferred_value
    self._name = name

  @property
  def type_spec(self):
    """Returns the `tf.TypeSpec` symbolically inferred for this Keras output."""
    return self._type_spec

  @property
  def shape(self):
    """Returns the `TensorShape` symbolically inferred for this Keras output."""
    # TODO(kaftan): This is only valid for normal/sparse/ragged tensors.
    # may need to raise an error when it's not valid for a type_spec,
    # but some keras code (e.g. build-related stuff) will likely fail when
    # it can't access shape or dtype
    return self._type_spec._shape  # pylint: disable=protected-access

  @classmethod
  def from_tensor(cls, tensor):
    """Convert a traced (composite)tensor to a representative KerasTensor.

    Args:
      tensor: A tensor or composite tensor produced while tracing a layer in
        a scratch graph.

    Returns:
      A `cls` instance holding the `TypeSpec` and (optional) name of `tensor`.
      For a `tf.Tensor` of dtype int32 and rank 0 or 1, any value information
      TensorFlow's shape handling can extract is stored as `inferred_value`.
    """
    name = getattr(tensor, 'name', None)
    type_spec = type_spec_module.type_spec_from_value(tensor)
    inferred_value = None
    if (isinstance(tensor, ops.Tensor) and type_spec.dtype == dtypes.int32
        and type_spec.shape.rank is not None and type_spec.shape.rank < 2):
      # If this tensor might be representing shape information,
      # (dtype=int32, rank of 0 or 1, not too large to represent a shape)
      # we attempt to capture any value information tensorflow's
      # shape handling can extract from the current scratch graph.
      #
      # Even though keras layers each trace in their own scratch
      # graph, this shape value info extraction allows us to capture
      # a sizable and useful subset of the C++ shape value inference TF can do
      # if all tf ops appear in the same graph when using shape ops.
      #
      # Examples of things this cannot infer concrete dimensions for
      # that the full single-graph C++ shape inference sometimes can are:
      # * cases where the shape tensor is cast out of int32 before being
      #   manipulated w/ floating point numbers then converted back
      # * cases where int32 tensors w/ rank >= 2 are manipulated before being
      #   used as a shape tensor
      # * cases where int32 tensors too large to represent shapes are
      #   manipulated to a smaller size before being used as a shape tensor
      inferred_value = array_ops.ones(shape=tensor).shape
      if inferred_value.dims:
        inferred_value = inferred_value.as_list()
        if len(inferred_value) > _MAX_TENSOR_RANK:
          inferred_value = None
      else:
        inferred_value = None

    # Use `cls` (not a hard-coded `KerasTensor`) so that specializations
    # inheriting this constructor produce instances of their own type.
    return cls(type_spec, inferred_value=inferred_value, name=name)

  @classmethod
  def from_type_spec(cls, type_spec, name=None):
    """Build a KerasTensor of type `cls` directly from a `tf.TypeSpec`."""
    return cls(type_spec=type_spec, name=name)

  def _to_placeholder(self):
    """Convert this KerasTensor to a placeholder in a graph."""
    # If there is an inferred value for this tensor, inject the inferred value
    if self._inferred_value is not None:
      # If we suspect this KerasTensor might be representing a shape tensor,
      # and we were able to extract value information with TensorFlow's shape
      # handling when making the KerasTensor, we construct the placeholder by
      # re-injecting the inferred value information into the graph. We
      # do this injection through the shape of a placeholder, because that
      # allows us to specify partially-unspecified shape values.
      #
      # See the comment on value extraction inside `from_tensor` for more info.
      inferred_value = array_ops.shape(
          array_ops.placeholder(
              shape=self._inferred_value, dtype=dtypes.int32))
      if self.type_spec.shape.rank == 0:
        # `tf.shape` always returns a rank-1, we may need to turn it back to a
        # scalar.
        inferred_value = inferred_value[0]
      return inferred_value

    # Use the generic conversion from typespec to a placeholder.
    def component_to_placeholder(component):
      return array_ops.placeholder(component.dtype, component.shape)

    return nest.map_structure(
        component_to_placeholder, self.type_spec, expand_composites=True)

  def get_shape(self):
    """Alias of the `shape` property, mirroring `tf.Tensor.get_shape()`."""
    return self.shape

  def __len__(self):
    raise TypeError('Keras symbolic inputs/outputs do not '
                    'implement `__len__`. You may be '
                    'trying to pass Keras symbolic inputs/outputs '
                    'to a TF API that does not register dispatching, '
                    'preventing Keras from automatically '
                    'converting the API call to a lambda layer '
                    'in the Functional Model. This error will also get raised '
                    'if you try asserting a symbolic input/output directly.')

  @property
  def op(self):
    raise TypeError('Keras symbolic inputs/outputs do not '
                    'implement `op`. You may be '
                    'trying to pass Keras symbolic inputs/outputs '
                    'to a TF API that does not register dispatching, '
                    'preventing Keras from automatically '
                    'converting the API call to a lambda layer '
                    'in the Functional Model.')

  def __hash__(self):
    # Like tensors, KerasTensors are not hashable because `==` is meant to
    # produce a tensor; `ref()` provides a hashable key instead.
    raise TypeError('Tensors are unhashable. (%s) '
                    'Instead, use tensor.ref() as the key.' % self)

  # Note: This enables the KerasTensor's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Tensor class higher priority than an ndarray, or a
  # numpy matrix.
  # In the future explore changing this to using numpy's `__array_ufunc__`
  # mechanism, which allows more control over how Tensors interact
  # with ndarrays.
  __array_priority__ = 100

  def __array__(self, dtype=None):
    # `dtype` is accepted (and ignored) because NumPy may pass it; without it
    # callers would get an arity TypeError instead of this explanation.
    del dtype
    raise TypeError(
        'Cannot convert a symbolic Keras input/output to a numpy array. '
        'This error may indicate that you\'re trying to pass a symbolic value '
        'to a NumPy call, which is not supported. Or, '
        'you may be trying to pass Keras symbolic inputs/outputs '
        'to a TF API that does not register dispatching, '
        'preventing Keras from automatically '
        'converting the API call to a lambda layer '
        'in the Functional Model.')

  @property
  def is_tensor_like(self):
    return True

  def set_shape(self, shape):
    """Updates the shape of this KerasTensor. Mimics `tf.Tensor.set_shape()`.

    The supplied shape is merged with the currently known shape: a dimension
    unknown in one shape takes its value from the other, and a fully-unknown
    `shape` leaves the current shape unchanged.

    Args:
      shape: A `TensorShape`, or anything convertible to one.

    Raises:
      ValueError: If `shape` is not compatible with the current shape
        (including a rank mismatch).
    """
    if not isinstance(shape, tensor_shape.TensorShape):
      shape = tensor_shape.TensorShape(shape)
    # Check compatibility first so a rank mismatch surfaces as ValueError
    # rather than an IndexError while combining dimensions.
    if not self.shape.is_compatible_with(shape):
      raise ValueError(
          "Keras symbolic input/output's shape %s is not "
          "compatible with supplied shape %s" %
          (self.shape, shape))
    self._type_spec._shape = self.shape.merge_with(shape)  # pylint: disable=protected-access

  def __str__(self):
    symbolic_description = ''
    inferred_value_string = ''
    name_string = ''

    if hasattr(self, '_keras_history'):
      layer = self._keras_history.layer
      symbolic_description = (
          ', description="created by layer \'%s\'"' % (layer.name,))
    if self._inferred_value is not None:
      inferred_value_string = (
          ', inferred_value=%s' % self._inferred_value)
    if self.name is not None:
      name_string = ', name=\'%s\'' % self._name
    return 'KerasTensor(type_spec=%s%s%s%s)' % (
        self.type_spec, inferred_value_string,
        name_string, symbolic_description)

  def __repr__(self):
    symbolic_description = ''
    inferred_value_string = ''
    if isinstance(self.type_spec, tensor_spec.TensorSpec):
      type_spec_string = 'shape=%s dtype=%s' % (self.shape, self.dtype.name)
    else:
      type_spec_string = 'type_spec=%s' % self.type_spec

    if hasattr(self, '_keras_history'):
      layer = self._keras_history.layer
      symbolic_description = ' (created by layer \'%s\')' % (layer.name,)
    if self._inferred_value is not None:
      inferred_value_string = (
          ' inferred_value=%s' % self._inferred_value)
    return '<KerasTensor: %s%s%s>' % (
        type_spec_string, inferred_value_string, symbolic_description)

  @property
  def dtype(self):
    """Returns the `dtype` symbolically inferred for this Keras output."""
    # TODO(kaftan): This is only valid for normal/sparse/ragged tensors.
    # may need to raise an error when it's not valid for a type_spec,
    # but some keras code (e.g. build-related stuff) will likely fail when
    # it can't access shape or dtype
    return self._type_spec._dtype  # pylint: disable=protected-access

  def ref(self):
    """Returns a hashable reference object to this KerasTensor.

    The primary use case for this API is to put KerasTensors in a
    set/dictionary. We can't put tensors in a set/dictionary as
    `tensor.__hash__()` is not available and tensor equality (`==`) is supposed
    to produce a tensor representing if the two inputs are equal.

    See the documentation of `tf.Tensor.ref()` for more info.
    """
    return object_identity.Reference(self)

  def __iter__(self):
    shape = None
    if self.shape.ndims is not None:
      shape = [dim.value for dim in self.shape.dims]

    if shape is None:
      raise TypeError('Cannot iterate over a Tensor with unknown shape.')
    if not shape:
      raise TypeError('Cannot iterate over a scalar.')
    if shape[0] is None:
      raise TypeError(
          'Cannot iterate over a Tensor with unknown first dimension.')
    return _KerasTensorIterator(self, shape[0])

  @property
  def name(self):
    """Returns the (non-unique, optional) name of this symbolic Keras value."""
    return self._name

  @classmethod
  def _overload_all_operators(cls, tensor_class):  # pylint: disable=invalid-name
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      cls._overload_operator(tensor_class, operator)

    # We include `experimental_ref` for versions of TensorFlow that
    # still include the deprecated method in Tensors.
    if hasattr(tensor_class, 'experimental_ref'):
      cls._overload_operator(tensor_class, 'experimental_ref')

  @classmethod
  def _overload_operator(cls, tensor_class, operator):  # pylint: disable=invalid-name
    """Overload an operator with the same implementation as a base Tensor class.

    We pull the operator out of the class dynamically to avoid ordering issues.

    Args:
      tensor_class: The (Composite)Tensor to get the method from.
      operator: string. The operator name.
    """
    tensor_oper = getattr(tensor_class, operator)

    # Compatibility with Python 2:
    # Python 2 unbound methods have type checks for the first arg,
    # so we need to extract the underlying function
    tensor_oper = getattr(tensor_oper, '__func__', tensor_oper)

    setattr(cls, operator, tensor_oper)


KerasTensor._overload_all_operators(ops.Tensor)  # pylint: disable=protected-access


class SparseKerasTensor(KerasTensor):
  """A specialized KerasTensor representation for `tf.sparse.SparseTensor`s.

  Specifically, it specializes the conversion to a placeholder in order
  to maintain dense shape information.
  """

  def _to_placeholder(self):
    spec = self.type_spec

    # nest.map_structure loses dense shape information for sparse tensors.
    # So, we special-case sparse placeholder creation.
    # This only preserves shape information for top-level sparse tensors;
    # not for sparse tensors that are nested inside another composite
    # tensor.
    return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape)


class RaggedKerasTensor(KerasTensor):
  """A specialized KerasTensor representation for `tf.RaggedTensor`s.

  Specifically, it:

  1. Specializes the conversion to a placeholder in order
  to maintain shape information for non-ragged dimensions.
  2. Overloads the KerasTensor's operators with the RaggedTensor versions
  when they don't match the `tf.Tensor` versions
  3. Exposes some of the instance method/attribute that are unique to
  the RaggedTensor API (such as ragged_rank).
  """

  def _to_placeholder(self):
    """Build a RaggedTensor placeholder that keeps known dimension sizes."""
    ragged_spec = self.type_spec
    if ragged_spec.ragged_rank == 0 or ragged_spec.shape.rank is None:
      # Nothing ragged to rebuild, or rank unknown: use the generic path.
      return super(RaggedKerasTensor, self)._to_placeholder()

    # Dense placeholder that becomes the innermost values of the result.
    flat_shape = ragged_spec.shape[ragged_spec.ragged_rank:]
    result = array_ops.placeholder(ragged_spec.dtype, flat_shape)

    # known_num_splits[i] is the product of the sizes of axes 0..i, or None
    # once any of those axes has an unknown size.
    known_num_splits = []
    prod = 1
    for axis_size in ragged_spec.shape:
      if prod is not None:
        if axis_size is None or (
            getattr(axis_size, 'value', True) is None):
          prod = None
        else:
          prod = prod * axis_size
      known_num_splits.append(prod)

    # Wrap `result` from the innermost ragged dimension outwards.
    for axis in range(ragged_spec.ragged_rank, 0, -1):
      axis_size = ragged_spec.shape[axis]
      if axis_size is None or (getattr(axis_size, 'value', True) is None):
        # Unknown axis size: use row_splits, whose length is the number of
        # rows plus one when that count is statically known.
        num_splits = known_num_splits[axis-1]
        if num_splits is not None:
          num_splits = num_splits + 1
        splits = array_ops.placeholder(
            ragged_spec.row_splits_dtype, [num_splits])
        result = ragged_tensor.RaggedTensor.from_row_splits(
            result, splits, validate=False)
      else:
        # Known axis size: encode it as a uniform row length.
        rowlen = constant_op.constant(axis_size, ragged_spec.row_splits_dtype)
        result = ragged_tensor.RaggedTensor.from_uniform_row_length(
            result, rowlen, validate=False)
    return result

  @property
  def ragged_rank(self):
    return self.type_spec.ragged_rank


RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__add__')  # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__radd__')  # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__mul__')  # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__rmul__')  # pylint: disable=protected-access


# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break Tensorflow probability
# while they finish migrating to composite tensors.
class UserRegisteredSpec(type_spec_module.TypeSpec):
  """TypeSpec to represent user-registered symbolic objects.

  Only `shape` and `dtype` are recorded; the component and serialization
  hooks all raise `NotImplementedError`.
  """

  def __init__(self, shape, dtype):
    self.shape = shape
    # `KerasTensor.shape` and `KerasTensor.dtype` read the private `_shape`
    # and `_dtype` attributes of their type spec, so populate those too.
    self._shape = shape
    self._dtype = dtype
    self.dtype = dtype

  def __repr__(self):
    # NOTE(review): the inherited `TypeSpec.__repr__` is expected to call
    # `_serialize()`, which raises below; this override keeps str()/repr() of
    # KerasTensors built on this spec from raising. Verify against TypeSpec.
    return 'UserRegisteredSpec(shape=%s, dtype=%s)' % (self.shape, self.dtype)

  def _component_specs(self):
    raise NotImplementedError

  def _from_components(self, components):
    raise NotImplementedError

  def _serialize(self):
    raise NotImplementedError

  def _to_components(self, value):
    raise NotImplementedError

  def value_type(self):
    raise NotImplementedError


# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break Tensorflow probability
# while they finish migrating to composite tensors.
class UserRegisteredTypeKerasTensor(KerasTensor):
  """KerasTensor that represents legacy register_symbolic_tensor_type."""

  def __init__(self, user_registered_symbolic_object):
    x = user_registered_symbolic_object
    self._user_registered_symbolic_object = x
    type_spec = UserRegisteredSpec(x.shape, x.dtype)
    name = getattr(x, 'name', None)

    # `name` must be passed by keyword: the second positional parameter of
    # `KerasTensor.__init__` is `inferred_value`.
    super(UserRegisteredTypeKerasTensor, self).__init__(type_spec, name=name)

  @classmethod
  def from_tensor(cls, tensor):
    return cls(tensor)

  @classmethod
  def from_type_spec(cls, type_spec, name=None):
    raise NotImplementedError('You cannot instantiate a KerasTensor '
                              'directly from TypeSpec: %s' % type_spec)

  def _to_placeholder(self):
    # The wrapped object itself is passed into the scratch graph unchanged.
    return self._user_registered_symbolic_object


class _KerasTensorIterator(object):
  """Iterates over the leading dim of a KerasTensor. Performs 0 error checks."""

  def __init__(self, tensor, dim0):
    self._tensor = tensor
    self._index = 0
    self._limit = dim0

  def __iter__(self):
    return self

  def __next__(self):
    if self._index == self._limit:
      raise StopIteration
    result = self._tensor[self._index]
    self._index += 1
    return result

  next = __next__  # python2.x compatibility.


# Specify the mappings of tensor class to KerasTensor class.
# This is specifically a list instead of a dict for now because
# 1. we do a check w/ isinstance because a key lookup based on class
#    would miss subclasses
# 2. a list allows us to control lookup ordering
# We include ops.Tensor -> KerasTensor in the first position as a fastpath,
# *and* include object -> KerasTensor at the end as a catch-all.
# We can re-visit these choices in the future as needed.
keras_tensor_classes = [
    (ops.Tensor, KerasTensor),
    (sparse_tensor.SparseTensor, SparseKerasTensor),
    (ragged_tensor.RaggedTensor, RaggedKerasTensor),
    (object, KerasTensor)
]


def register_keras_tensor_specialization(cls, keras_tensor_subclass):
  """Register a specialized KerasTensor subclass for a Tensor type."""
  # We always leave (object, KerasTensor) at the end as a generic fallback
  keras_tensor_classes.insert(-1, (cls, keras_tensor_subclass))


def keras_tensor_to_placeholder(x):
  """Construct a graph placeholder to represent a KerasTensor when tracing."""
  if isinstance(x, KerasTensor):
    return x._to_placeholder()  # pylint: disable=protected-access
  else:
    return x


def keras_tensor_from_tensor(tensor):
  """Convert a traced (composite)tensor to a representative KerasTensor."""
  # Create a specialized KerasTensor that supports instance methods,
  # operators, and additional value inference if possible
  keras_tensor_cls = None
  for tensor_type, cls in keras_tensor_classes:
    if isinstance(tensor, tensor_type):
      keras_tensor_cls = cls
      break

  out = keras_tensor_cls.from_tensor(tensor)

  # Carry a Keras mask across as a KerasTensor as well.
  if hasattr(tensor, '_keras_mask'):
    out._keras_mask = keras_tensor_from_tensor(tensor._keras_mask)  # pylint: disable=protected-access
  return out


def keras_tensor_from_type_spec(type_spec, name=None):
  """Convert a TypeSpec to a representative KerasTensor."""
  # Create a specialized KerasTensor that supports instance methods,
  # operators, and additional value inference if possible
  keras_tensor_cls = None
  value_type = type_spec.value_type
  for tensor_type, cls in keras_tensor_classes:
    if issubclass(value_type, tensor_type):
      keras_tensor_cls = cls
      break

  return keras_tensor_cls.from_type_spec(type_spec, name=name)