# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras Input Tensor used to track functional API Topology."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec as type_spec_module
from tensorflow.python.keras.utils import object_identity
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_operators  # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest

# pylint: disable=g-classes-have-attributes

_KERAS_TENSORS_ENABLED = True


def enable_keras_tensors():
  """Enable using KerasTensors in Keras's functional API."""
  global _KERAS_TENSORS_ENABLED
  _KERAS_TENSORS_ENABLED = True


def disable_keras_tensors():
  """Disable using KerasTensors in Keras's functional API."""
  global _KERAS_TENSORS_ENABLED
  _KERAS_TENSORS_ENABLED = False


def keras_tensors_enabled():
  """Return a bool specifying if KerasTensors are enabled."""
  return _KERAS_TENSORS_ENABLED and ops.executing_eagerly_outside_functions()


# TensorFlow tensors have a maximum rank of 254
# (see `MaxDimensions()` in //tensorflow/core/framework/tensor_shape.h),
# so we do not try to infer values for int32 tensors larger than this,
# as they cannot represent shapes.
_MAX_TENSOR_RANK = 254


class KerasTensor(object):
  """A representation of a Keras in/output during Functional API construction.

  `KerasTensor`s are tensor-like objects that represent the symbolic inputs
  and outputs of Keras layers during Functional model construction. They are
  comprised of the `tf.TypeSpec` of the (Composite)Tensor that will be
  consumed/produced in the corresponding location of the Functional model.

  KerasTensors are intended as a private API, so users should never need to
  directly instantiate `KerasTensor`s.

  **Building Functional Models with KerasTensors**
  `tf.keras.Input` produces `KerasTensor`s that represent the symbolic inputs
  to your model.

  Passing a `KerasTensor` to a `tf.keras.Layer` `__call__` lets the layer know
  that you are building a Functional model. The layer __call__ will
  infer the output signature and return `KerasTensor`s with `tf.TypeSpec`s
  corresponding to the symbolic outputs of that layer call. These output
  `KerasTensor`s will have all of the internal KerasHistory metadata attached
  to them that Keras needs to construct a Functional Model.
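
  For example, a minimal illustrative sketch of Functional construction (the
  layer sizes and names below are arbitrary):

  ```python
  inputs = tf.keras.Input(shape=(32,))        # `inputs` is a `KerasTensor`
  outputs = tf.keras.layers.Dense(4)(inputs)  # so is `outputs`
  model = tf.keras.Model(inputs, outputs)
  ```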

  Currently, layers infer the output signature by:
    * creating a scratch `FuncGraph`
    * making placeholders in the scratch graph that match the input typespecs
    * calling `layer.call` on these placeholders
    * extracting the signatures of the outputs before clearing the scratch graph

  (Note: names assigned to KerasTensors by this process are not guaranteed to
  be unique, and are subject to implementation details.)

  `tf.nest` methods are used to ensure all of the input/output data
  structures get maintained, with elements swapped between KerasTensors and
  placeholders.

  In rare cases (such as when directly manipulating shapes using Keras layers),
  the layer may be able to partially infer the value of the output in addition
  to just inferring the signature.
  When this happens, the returned KerasTensor will also contain the inferred
  value information. Follow-on layers can use this information
  during their own output signature inference.
  E.g. if one layer produces a symbolic `KerasTensor` that the next layer uses
  as the shape of its outputs, partially knowing the value helps infer the
  output shape.

  **Automatically converting TF APIs to layers**:
  If you pass a `KerasTensor` to a TF API that supports dispatching,
  Keras will automatically turn that API call into a lambda
  layer in the Functional model, and return KerasTensors representing the
  symbolic outputs.

  Most TF APIs that take only tensors as input and produce output tensors
  will support dispatching.
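
  For example, a sketch of what this enables (`tf.reduce_sum` here stands in
  for any dispatching-enabled TF API; the name of the generated layer is an
  implementation detail):

  ```python
  inputs = tf.keras.Input(shape=(10,))
  x = tf.reduce_sum(inputs, axis=-1)  # wrapped in an auto-generated layer
  model = tf.keras.Model(inputs, x)
  ```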

  Calling a `tf.function` does not support dispatching, so you cannot pass
  `KerasTensor`s as inputs to a `tf.function`.

  Higher-order APIs that take methods which produce tensors (e.g.
  `tf.while_loop`, `tf.map_fn`, `tf.cond`) also do not currently support
  dispatching. So, you cannot directly pass KerasTensors as inputs to these
  APIs either. If you want to use these APIs inside of a Functional model, you
  must put them inside of a custom layer.

  Args:
    type_spec: The `tf.TypeSpec` for the symbolic input created by
      `tf.keras.Input`, or symbolically inferred for the output
      during a symbolic layer `__call__`.
    inferred_value: (Optional) a non-symbolic static value, possibly partially
      specified, that could be symbolically inferred for the outputs during
      a symbolic layer `__call__`. This will generally only happen when
      grabbing and manipulating `tf.int32` shapes directly as tensors.
      Statically inferring values in this way and storing them in the
      KerasTensor allows follow-on layers to infer output signatures
      more effectively. (e.g. when using a symbolic shape tensor to later
      construct a tensor with that shape).
    name: (optional) string name for this KerasTensor. Names automatically
      generated by symbolic layer `__call__`s are not guaranteed to be unique,
      and are subject to implementation details.
  """

  def __init__(self, type_spec, inferred_value=None, name=None):
    """Constructs a KerasTensor."""
    if not isinstance(type_spec, type_spec_module.TypeSpec):
      raise ValueError('KerasTensors must be constructed with a `tf.TypeSpec`.')

    self._type_spec = type_spec
    self._inferred_value = inferred_value
    self._name = name

  @property
  def type_spec(self):
    """Returns the `tf.TypeSpec` symbolically inferred for this Keras output."""
    return self._type_spec

  @property
  def shape(self):
    """Returns the `TensorShape` symbolically inferred for this Keras output."""
    # TODO(kaftan): This is only valid for normal/sparse/ragged tensors.
    # may need to raise an error when it's not valid for a type_spec,
    # but some keras code (e.g. build-related stuff) will likely fail when
    # it can't access shape or dtype
    return self._type_spec._shape  # pylint: disable=protected-access

  @classmethod
  def from_tensor(cls, tensor):
    """Convert a traced (composite)tensor to a representative KerasTensor."""
    if isinstance(tensor, ops.Tensor):
      name = getattr(tensor, 'name', None)
      type_spec = type_spec_module.type_spec_from_value(tensor)
      inferred_value = None
      if (type_spec.dtype == dtypes.int32 and type_spec.shape.rank is not None
          and type_spec.shape.rank < 2):
        # If this tensor might be representing shape information,
        # (dtype=int32, rank of 0 or 1, not too large to represent a shape)
        # we attempt to capture any value information tensorflow's
        # shape handling can extract from the current scratch graph.
        #
        # Even though keras layers each trace in their own scratch
        # graph, this shape value info extraction allows us to capture
        # a sizable and useful subset of the C++ shape value inference TF can do
        # if all tf ops appear in the same graph when using shape ops.
        #
        # Examples of things this cannot infer concrete dimensions for
        # that the full single-graph C++ shape inference sometimes can are:
        # * cases where the shape tensor is cast out of int32 before being
        #   manipulated w/ floating point numbers then converted back
        # * cases where int32 tensors w/ rank >= 2 are manipulated before being
        #   used as a shape tensor
        # * cases where int32 tensors too large to represent shapes are
        #   manipulated to a smaller size before being used as a shape tensor
        inferred_value = array_ops.ones(shape=tensor).shape
        if inferred_value.dims:
          inferred_value = inferred_value.as_list()
          if len(inferred_value) > _MAX_TENSOR_RANK:
            inferred_value = None
        else:
          inferred_value = None

      return KerasTensor(type_spec, inferred_value=inferred_value, name=name)
    else:
      # Fallback to the generic arbitrary-typespec KerasTensor
      name = getattr(tensor, 'name', None)
      type_spec = type_spec_module.type_spec_from_value(tensor)
      return cls(type_spec, name=name)

  @classmethod
  def from_type_spec(cls, type_spec, name=None):
    return cls(type_spec=type_spec, name=name)

  def _to_placeholder(self):
    """Convert this KerasTensor to a placeholder in a graph."""
    # If there is an inferred value for this tensor, inject the inferred value
    if self._inferred_value is not None:
      # If we suspect this KerasTensor might be representing a shape tensor,
      # and we were able to extract value information with TensorFlow's shape
      # handling when making the KerasTensor, we construct the placeholder by
      # re-injecting the inferred value information into the graph. We
      # do this injection through the shape of a placeholder, because that
      # allows us to specify partially-unspecified shape values.
      #
      # See the comment on value extraction inside `from_tensor` for more info.
      inferred_value = array_ops.shape(
          array_ops.placeholder(
              shape=self._inferred_value, dtype=dtypes.int32))
      if self.type_spec.shape.rank == 0:
        # `tf.shape` always returns a rank-1 tensor, so we may need to turn
        # it back into a scalar.
        inferred_value = inferred_value[0]
      return inferred_value

    # Use the generic conversion from typespec to a placeholder.
    def component_to_placeholder(component):
      return array_ops.placeholder(component.dtype, component.shape)

    return nest.map_structure(
        component_to_placeholder, self.type_spec, expand_composites=True)

  def get_shape(self):
    return self.shape

  def __len__(self):
    raise TypeError('Keras symbolic inputs/outputs do not '
                    'implement `__len__`. You may be '
                    'trying to pass Keras symbolic inputs/outputs '
                    'to a TF API that does not register dispatching, '
                    'preventing Keras from automatically '
                    'converting the API call to a lambda layer '
                    'in the Functional Model. This error will also get raised '
                    'if you try asserting a symbolic input/output directly.')

  @property
  def op(self):
    raise TypeError('Keras symbolic inputs/outputs do not '
                    'implement `op`. You may be '
                    'trying to pass Keras symbolic inputs/outputs '
                    'to a TF API that does not register dispatching, '
                    'preventing Keras from automatically '
                    'converting the API call to a lambda layer '
                    'in the Functional Model.')

  def __hash__(self):
    raise TypeError('Tensors are unhashable (%s). '
                    'Instead, use tensor.ref() as the key.' % self)

  # Note: This enables the KerasTensor's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Tensor class higher priority than an ndarray, or a
  # numpy matrix.
  # In the future explore changing this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Tensors interact
  # with ndarrays.
  __array_priority__ = 100

  def __array__(self):
    raise TypeError(
        'Cannot convert a symbolic Keras input/output to a numpy array. '
        'This error may indicate that you\'re trying to pass a symbolic value '
        'to a NumPy call, which is not supported. Or, '
        'you may be trying to pass Keras symbolic inputs/outputs '
        'to a TF API that does not register dispatching, '
        'preventing Keras from automatically '
        'converting the API call to a lambda layer '
        'in the Functional Model.')

  @property
  def is_tensor_like(self):
    return True

  def set_shape(self, shape):
    """Updates the shape of this KerasTensor. Mimics `tf.Tensor.set_shape()`."""
    if not isinstance(shape, tensor_shape.TensorShape):
      shape = tensor_shape.TensorShape(shape)
    if shape.dims is not None:
      dim_list = [dim.value for dim in shape.dims]
      for dim in range(len(dim_list)):
        if dim_list[dim] is None and self.shape.dims is not None:
          dim_list[dim] = self.shape.dims[dim]
      shape = tensor_shape.TensorShape(dim_list)
    if not self.shape.is_compatible_with(shape):
      raise ValueError(
          "Keras symbolic input/output's shape %s is not "
          "compatible with supplied shape %s" %
          (self.shape, shape))
    else:
      self._type_spec._shape = shape  # pylint: disable=protected-access

  def __str__(self):
    symbolic_description = ''
    inferred_value_string = ''
    name_string = ''

    if hasattr(self, '_keras_history'):
      layer = self._keras_history.layer
      symbolic_description = (
          ', description="created by layer \'%s\'"' % (layer.name,))
    if self._inferred_value is not None:
      inferred_value_string = (
          ', inferred_value=%s' % self._inferred_value)
    if self.name is not None:
      name_string = ', name=\'%s\'' % self._name
    return 'KerasTensor(type_spec=%s%s%s%s)' % (
        self.type_spec, inferred_value_string,
        name_string, symbolic_description)

  def __repr__(self):
    symbolic_description = ''
    inferred_value_string = ''
    if isinstance(self.type_spec, tensor_spec.TensorSpec):
      type_spec_string = 'shape=%s dtype=%s' % (self.shape, self.dtype.name)
    else:
      type_spec_string = 'type_spec=%s' % self.type_spec

    if hasattr(self, '_keras_history'):
      layer = self._keras_history.layer
      symbolic_description = ' (created by layer \'%s\')' % (layer.name,)
    if self._inferred_value is not None:
      inferred_value_string = (
          ' inferred_value=%s' % self._inferred_value)
    return '<KerasTensor: %s%s%s>' % (
        type_spec_string, inferred_value_string, symbolic_description)

  @property
  def dtype(self):
    """Returns the `dtype` symbolically inferred for this Keras output."""
    # TODO(kaftan): This is only valid for normal/sparse/ragged tensors.
    # may need to raise an error when it's not valid for a type_spec,
    # but some keras code (e.g. build-related stuff) will likely fail when
    # it can't access shape or dtype
    return self._type_spec._dtype  # pylint: disable=protected-access

  def ref(self):
    """Returns a hashable reference object to this KerasTensor.

    The primary use case for this API is to put KerasTensors in a
    set/dictionary. We can't put tensors in a set/dictionary as
    `tensor.__hash__()` is not available and tensor equality (`==`) is supposed
    to produce a tensor representing if the two inputs are equal.

    See the documentation of `tf.Tensor.ref()` for more info.
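
    For example, a small sketch of the intended use (the behavior mirrors
    `tf.Tensor.ref()`):

    ```python
    kt = tf.keras.Input(shape=(3,))
    d = {kt.ref(): 'my_input'}  # KerasTensors themselves are unhashable.
    assert d[kt.ref()] == 'my_input'
    ```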
362    """
363    return object_identity.Reference(self)
364
365  def __iter__(self):
366    shape = None
367    if self.shape.ndims is not None:
368      shape = [dim.value for dim in self.shape.dims]
369
370    if shape is None:
371      raise TypeError('Cannot iterate over a Tensor with unknown shape.')
372    if not shape:
373      raise TypeError('Cannot iterate over a scalar.')
374    if shape[0] is None:
375      raise TypeError(
376          'Cannot iterate over a Tensor with unknown first dimension.')
377    return _KerasTensorIterator(self, shape[0])
378
379  @property
380  def name(self):
381    """Returns the (non-unique, optional) name of this symbolic Keras value."""
382    return self._name
383
384  @classmethod
385  def _overload_all_operators(cls, tensor_class):  # pylint: disable=invalid-name
386    """Register overloads for all operators."""
387    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
388      cls._overload_operator(tensor_class, operator)
389
390    # We include `experimental_ref` for versions of TensorFlow that
391    # still include the deprecated method in Tensors.
392    if hasattr(tensor_class, 'experimental_ref'):
393      cls._overload_operator(tensor_class, 'experimental_ref')
394
395  @classmethod
396  def _overload_operator(cls, tensor_class, operator):  # pylint: disable=invalid-name
397    """Overload an operator with the same implementation as a base Tensor class.
398
399    We pull the operator out of the class dynamically to avoid ordering issues.
400
401    Args:
402      tensor_class: The (Composite)Tensor to get the method from.
403      operator: string. The operator name.
404    """
405    tensor_oper = getattr(tensor_class, operator)
406
407    # Compatibility with Python 2:
408    # Python 2 unbound methods have type checks for the first arg,
409    # so we need to extract the underlying function
410    tensor_oper = getattr(tensor_oper, '__func__', tensor_oper)
411
412    setattr(cls, operator, tensor_oper)
413
414
415KerasTensor._overload_all_operators(ops.Tensor)  # pylint: disable=protected-access
416
417
418class SparseKerasTensor(KerasTensor):
419  """A specialized KerasTensor representation for `tf.sparse.SparseTensor`s.
420
421  Specifically, it specializes the conversion to a placeholder in order
422  to maintain dense shape information.
423  """
424
425  def _to_placeholder(self):
426    spec = self.type_spec
427
428    # nest.map_structure loses dense shape information for sparse tensors.
429    # So, we special-case sparse placeholder creation.
430    # This only preserves shape information for top-level sparse tensors;
431    # not for sparse tensors that are nested inside another composite
432    # tensor.
433    return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape)
434
435
436class RaggedKerasTensor(KerasTensor):
437  """A specialized KerasTensor representation for `tf.RaggedTensor`s.
438
  Specifically, it:

  1. Specializes the conversion to a placeholder in order
  to maintain shape information for non-ragged dimensions.
  2. Overloads the KerasTensor's operators with the RaggedTensor versions
  when they don't match the `tf.Tensor` versions.
  3. Exposes some of the instance methods/attributes that are unique to
  the RaggedTensor API (such as `ragged_rank`).
  """

  def _to_placeholder(self):
    ragged_spec = self.type_spec
    if ragged_spec.ragged_rank == 0 or ragged_spec.shape.rank is None:
      return super(RaggedKerasTensor, self)._to_placeholder()

    flat_shape = ragged_spec.shape[ragged_spec.ragged_rank:]
    result = array_ops.placeholder(ragged_spec.dtype, flat_shape)

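    # Walk the outer dimensions and track the running product of the
    # dimension sizes; `known_num_splits[i]` ends up as the number of rows
    # at ragged axis `i + 1` when every outer dimension is statically known,
    # and None as soon as any outer dimension is unknown.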
    known_num_splits = []
    prod = 1
    for axis_size in ragged_spec.shape:
      if prod is not None:
        if axis_size is None or (
            getattr(axis_size, 'value', True) is None):
          prod = None
        else:
          prod = prod * axis_size
      known_num_splits.append(prod)

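    # Reconstruct the ragged structure from the innermost ragged axis
    # outwards: unknown-length axes get a row-splits placeholder (sized
    # `num_rows + 1` when the row count is statically known), while
    # statically-sized axes are rebuilt with a uniform row length.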
    for axis in range(ragged_spec.ragged_rank, 0, -1):
      axis_size = ragged_spec.shape[axis]
      if axis_size is None or (getattr(axis_size, 'value', True) is None):
        num_splits = known_num_splits[axis-1]
        if num_splits is not None:
          num_splits = num_splits + 1
        splits = array_ops.placeholder(
            ragged_spec.row_splits_dtype, [num_splits])
        result = ragged_tensor.RaggedTensor.from_row_splits(
            result, splits, validate=False)
      else:
        rowlen = constant_op.constant(axis_size, ragged_spec.row_splits_dtype)
        result = ragged_tensor.RaggedTensor.from_uniform_row_length(
            result, rowlen, validate=False)
    return result

  @property
  def ragged_rank(self):
    return self.type_spec.ragged_rank

RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__add__')  # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__radd__')  # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__mul__')  # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__rmul__')  # pylint: disable=protected-access


# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break Tensorflow probability
# while they finish migrating to composite tensors.
class UserRegisteredSpec(type_spec_module.TypeSpec):
  """TypeSpec to represent user-registered symbolic objects."""

  def __init__(self, shape, dtype):
    self.shape = shape
    self._dtype = dtype
    self.dtype = dtype

  def _component_specs(self):
    raise NotImplementedError

  def _from_components(self, components):
    raise NotImplementedError

  def _serialize(self):
    raise NotImplementedError

  def _to_components(self, value):
    raise NotImplementedError

  def value_type(self):
    raise NotImplementedError


# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break Tensorflow probability
# while they finish migrating to composite tensors.
class UserRegisteredTypeKerasTensor(KerasTensor):
  """KerasTensor that represents legacy register_symbolic_tensor_type."""

  def __init__(self, user_registered_symbolic_object):
    x = user_registered_symbolic_object
    self._user_registered_symbolic_object = x
    type_spec = UserRegisteredSpec(x.shape, x.dtype)
    name = getattr(x, 'name', None)

    super(UserRegisteredTypeKerasTensor, self).__init__(type_spec, name=name)

  @classmethod
  def from_tensor(cls, tensor):
    return cls(tensor)

  @classmethod
  def from_type_spec(cls, type_spec, name=None):
    raise NotImplementedError('You cannot instantiate a KerasTensor '
                              'directly from TypeSpec: %s' % type_spec)

  def _to_placeholder(self):
    return self._user_registered_symbolic_object


class _KerasTensorIterator(object):
  """Iterates over the leading dim of a KerasTensor. Performs 0 error checks."""

  def __init__(self, tensor, dim0):
    self._tensor = tensor
    self._index = 0
    self._limit = dim0

  def __iter__(self):
    return self

  def __next__(self):
    if self._index == self._limit:
      raise StopIteration
    result = self._tensor[self._index]
    self._index += 1
    return result

  next = __next__  # python2.x compatibility.


# Specify the mappings of tensor class to KerasTensor class.
# This is specifically a list instead of a dict for now because
# 1. we do a check w/ isinstance because a key lookup based on class
#    would miss subclasses
# 2. a list allows us to control lookup ordering
# We include ops.Tensor -> KerasTensor in the first position as a fastpath,
# *and* include object -> KerasTensor at the end as a catch-all.
# We can re-visit these choices in the future as needed.
keras_tensor_classes = [
    (ops.Tensor, KerasTensor),
    (sparse_tensor.SparseTensor, SparseKerasTensor),
    (ragged_tensor.RaggedTensor, RaggedKerasTensor),
    (object, KerasTensor)
]


def register_keras_tensor_specialization(cls, keras_tensor_subclass):
  """Register a specialized KerasTensor subclass for a Tensor type."""
  # We always leave (object, KerasTensor) at the end as a generic fallback
  keras_tensor_classes.insert(-1, (cls, keras_tensor_subclass))
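
# For example, a hypothetical composite tensor type could opt into a
# specialized KerasTensor subclass roughly like this (the names below are
# purely illustrative, not an existing API surface):
#   register_keras_tensor_specialization(MyCompositeTensor,
#                                        MyCompositeKerasTensor)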


def keras_tensor_to_placeholder(x):
  """Construct a graph placeholder to represent a KerasTensor when tracing."""
  if isinstance(x, KerasTensor):
    return x._to_placeholder()  # pylint: disable=protected-access
  else:
    return x


def keras_tensor_from_tensor(tensor):
  """Convert a traced (composite)tensor to a representative KerasTensor."""
  # Create a specialized KerasTensor that supports instance methods,
  # operators, and additional value inference if possible
  keras_tensor_cls = None
  for tensor_type, cls in keras_tensor_classes:
    if isinstance(tensor, tensor_type):
      keras_tensor_cls = cls
      break

  out = keras_tensor_cls.from_tensor(tensor)

  if hasattr(tensor, '_keras_mask'):
    out._keras_mask = keras_tensor_from_tensor(tensor._keras_mask)  # pylint: disable=protected-access
  return out


def keras_tensor_from_type_spec(type_spec, name=None):
  """Convert a TypeSpec to a representative KerasTensor."""
  # Create a specialized KerasTensor that supports instance methods,
  # operators, and additional value inference if possible
  keras_tensor_cls = None
  value_type = type_spec.value_type
  for tensor_type, cls in keras_tensor_classes:
    if issubclass(value_type, tensor_type):
      keras_tensor_cls = cls
      break

  return keras_tensor_cls.from_type_spec(type_spec, name=name)