• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper classes for tensor shape inference."""
16import functools
17import operator
18from typing import Optional, Sequence, Type
19
20from tensorflow.core.framework import tensor_shape_pb2
21from tensorflow.core.function import trace_type
22from tensorflow.python import tf2
23from tensorflow.python.eager import monitoring
24from tensorflow.python.platform import tf_logging as logging
25from tensorflow.python.types import trace
26from tensorflow.python.util.tf_export import tf_export
27
# Tri-state override for TensorShape's V2 behavior:
#   None  -> no explicit choice made; defer to tf2.enabled() (see
#            TensorShape._v2_behavior below).
#   True  -> V2 behavior forced via enable_v2_tensorshape().
#   False -> V1 behavior forced via disable_v2_tensorshape().
_TENSORSHAPE_V2_OVERRIDE = None

# Telemetry gauge recording the most recent enable/disable call.
_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/v2_tensorshape",
    "Whether tensor_shape.enable_v2_tensorshape() is called.")
34
@tf_export(v1=["enable_v2_tensorshape"])
def enable_v2_tensorshape():
  """In TensorFlow 2.0, iterating over a TensorShape instance returns values.

  This enables the new behavior.

  Concretely, `tensor_shape[i]` returned a Dimension instance in V1, but
  in V2 it returns either an integer, or None.

  Examples:

  ```
  #######################
  # If you had this in V1:
  value = tensor_shape[i].value

  # Do this in V2 instead:
  value = tensor_shape[i]

  #######################
  # If you had this in V1:
  for dim in tensor_shape:
    value = dim.value
    print(value)

  # Do this in V2 instead:
  for value in tensor_shape:
    print(value)

  #######################
  # If you had this in V1:
  dim = tensor_shape[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # Do this in V2 instead:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # The V2 suggestion above is more explicit, which will save you from
  # the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be.
  ```
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = True
  logging.vlog(1, "Enabling v2 tensorshape")
  _api_usage_gauge.get_cell().set(True)
87
@tf_export(v1=["disable_v2_tensorshape"])
def disable_v2_tensorshape():
  """Disables the V2 TensorShape behavior and reverts to V1 behavior.

  See docstring for `enable_v2_tensorshape` for details about the new behavior.
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  # Record the explicit V1 choice, then report it to logging and telemetry.
  _TENSORSHAPE_V2_OVERRIDE = False
  _api_usage_gauge.get_cell().set(False)
  logging.vlog(1, "Disabling v2 tensorshape")
98
99
@tf_export(
    "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"])
def dimension_value(dimension):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  When accessing the value of a TensorShape dimension,
  use this utility, like this:

  ```
  # If you had this in your V1 code:
  value = tensor_shape[i].value

  # Use `dimension_value` as direct replacement compatible with both V1 & V2:
  value = dimension_value(tensor_shape[i])

  # This would be the V2 equivalent:
  value = tensor_shape[i]  # Warning: this will return the dim value in V2!
  ```

  Args:
    dimension: Either a `Dimension` instance, an integer, or None.

  Returns:
    A plain value, i.e. an integer or None.
  """
  # V1 hands us a Dimension wrapper; V2 hands us the plain value directly.
  return dimension.value if isinstance(dimension, Dimension) else dimension
132
@tf_export(
    "compat.dimension_at_index",
    v1=["dimension_at_index", "compat.dimension_at_index"])
def dimension_at_index(shape, index):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  If you want to retrieve the Dimension instance corresponding to a certain
  index in a TensorShape instance, use this utility, like this:

  ```
  # If you had this in your V1 code:
  dim = tensor_shape[i]

  # Use `dimension_at_index` as direct replacement compatible with both V1 & V2:
  dim = dimension_at_index(tensor_shape, i)

  # Another possibility would be this, but WARNING: it only works if the
  # tensor_shape instance has a defined rank.
  dim = tensor_shape.dims[i]  # `dims` may be None if the rank is undefined!

  # In native V2 code, we recommend instead being more explicit:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]

  # Being more explicit will save you from the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be (as the Dimension object was
  # instantiated on the fly.
  ```

  Args:
    shape: A TensorShape instance.
    index: An integer index.

  Returns:
    A dimension object.
  """
  assert isinstance(shape, TensorShape)
  # An unknown-rank shape has no dims list, so every index maps to an
  # unknown Dimension.
  if shape.rank is None:
    return Dimension(None)
  return shape.dims[index]
181
@tf_export(v1=["Dimension"])
class Dimension(object):
  """Represents the value of one dimension in a TensorShape.

  @compatibility(TF2)
  In TF2, members of a `TensorShape` object are integers. The `Dimension` class
  is not part of TF2's data model.

  Please refer to the [TensorShape section of the migration guide]
  (https://www.tensorflow.org/guide/migrate/index#tensorshape) on common code
  patterns adapting Dimension objects to a TF2 syntax.
  @end_compatibility
  """

  # Single slot keeps the many per-shape instances small; `_value` is a
  # non-negative int, or None for an unknown dimension.
  __slots__ = ["_value"]

  def __init__(self, value):
    """Creates a new Dimension with the given value.

    Args:
      value: A non-negative int, None (unknown), another `Dimension`, or any
        object providing an `__index__` method.

    Raises:
      TypeError: If `value` cannot be interpreted as an integer or None.
      ValueError: If `value` is negative.
    """
    if isinstance(value, int):  # Most common case.
      if value < 0:
        raise ValueError("Dimension %d must be >= 0" % value)
      self._value = value
    elif value is None:
      self._value = None
    elif isinstance(value, Dimension):
      self._value = value._value
    else:
      try:
        # int(...) compensates for the int/long dichotomy on Python 2.X.
        # TODO(b/143206389): Remove once we fully migrate to 3.X.
        self._value = int(value.__index__())
      except AttributeError:
        raise TypeError(
            "Dimension value must be integer or None or have "
            "an __index__ method, got value '{0!r}' with type '{1!r}'".format(
                value, type(value))) from None
      if self._value < 0:
        raise ValueError("Dimension %d must be >= 0" % self._value)

  def __repr__(self):
    return "Dimension(%s)" % repr(self._value)

  def __str__(self):
    # Unknown dimensions print as "?" (e.g. inside TensorShape's str).
    value = self._value
    return "?" if value is None else str(value)

  def __eq__(self, other):
    """Returns true if `other` has the same known value as this Dimension."""
    # NOTE: deliberately three-valued — returns None (not False) when either
    # side is unknown, matching the legacy V1 comparison semantics.
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value == other.value

  def __ne__(self, other):
    """Returns true if `other` has a different known value from `self`."""
    # Same three-valued semantics as __eq__: None when either side is unknown.
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value != other.value

  def __bool__(self):
    """Equivalent to `bool(self.value)`."""
    return bool(self._value)

  def __int__(self):
    # Returns the raw value; for an unknown dimension this is None.
    return self._value

  # This is needed for Windows.
  # See https://github.com/tensorflow/tensorflow/pull/9780
  def __long__(self):
    return self._value

  def __index__(self):
    # Allow use in Python 3 range
    return self._value

  @property
  def value(self):
    """The value of this dimension, or None if it is unknown."""
    return self._value

  # TODO(b/225058047): Reconsider semantics.
  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this Dimension.

    Two known Dimensions are compatible if they have the same value.
    An unknown Dimension is compatible with all other Dimensions.

    Args:
      other: Another Dimension.

    Returns:
      True if this Dimension and `other` are compatible.
    """
    other = as_dimension(other)
    return (self._value is None or other.value is None or
            self._value == other.value)

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.

    Args:
      other: Another Dimension.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    if not self.is_compatible_with(other):
      raise ValueError("Dimensions %s and %s are not compatible" %
                       (self, other))

  def merge_with(self, other):
    """Returns a Dimension that combines the information in `self` and `other`.

    Dimensions are combined as follows:

    ```python
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(n))     ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(None))  ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(n))     ==
    tf.compat.v1.Dimension(n)
    # equivalent to tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(None))

    # raises ValueError for n != m
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(m))
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    other = as_dimension(other)
    self.assert_is_compatible_with(other)
    # After the compatibility check, whichever side is known (if any) wins.
    if self._value is None:
      return Dimension(other.value)
    else:
      return Dimension(self._value)

  def __add__(self, other):
    """Returns the sum of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.compat.v1.Dimension(m)    + tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m + n)
    tf.compat.v1.Dimension(m)    + tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    # Any unknown operand makes the result unknown.
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value + other.value)

  def __radd__(self, other):
    """Returns the sum of `other` and `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    return self + other

  def __sub__(self, other):
    """Returns the subtraction of `other` from `self`.

    Dimensions are subtracted as follows:

    ```python
    tf.compat.v1.Dimension(m)    - tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m - n)
    tf.compat.v1.Dimension(m)    - tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `other` from `self`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value - other.value)

  def __rsub__(self, other):
    """Returns the subtraction of `self` from `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `self` from `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value - self._value)

  def __mul__(self, other):
    """Returns the product of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.compat.v1.Dimension(m)    * tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m * n)
    tf.compat.v1.Dimension(m)    * tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented

    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value * other.value)

  def __rmul__(self, other):
    """Returns the product of `self` and `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    return self * other

  def __floordiv__(self, other):
    """Returns the quotient of `self` and `other` rounded down.

    Dimensions are divided as follows:

    ```python
    tf.compat.v1.Dimension(m)    // tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m // n)
    tf.compat.v1.Dimension(m)    // tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value // other.value)

  def __rfloordiv__(self, other):
    """Returns the quotient of `other` and `self` rounded down.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value // self._value)

  def __div__(self, other):
    """DEPRECATED: Use `__floordiv__` via `x // y` instead.

    This function exists only for backwards compatibility purposes; new code
    should use `__floordiv__` via the syntax `x // y`.  Using `x // y`
    communicates clearly that the result rounds down, and is forward compatible
    to Python 3.

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    # Python 2 legacy operator; delegates to the floor-division behavior.
    return self // other

  def __rdiv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __truediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', "
                    "please use // instead".format(type(other).__name__))

  def __rtruediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __mod__(self, other):
    """Returns `self` modulo `other`.

    Dimension modulo are computed as follows:

    ```python
    tf.compat.v1.Dimension(m)    % tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m % n)
    tf.compat.v1.Dimension(m)    % tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `self` modulo `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value % other.value)

  def __rmod__(self, other):
    """Returns `other` modulo `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `other` modulo `self`.
    """
    other = as_dimension(other)
    return other % self

  def __lt__(self, other):
    """Returns True if `self` is known to be less than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    < tf.compat.v1.Dimension(n))    == (m < n)
    (tf.compat.v1.Dimension(m)    < tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value < other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value < other.value

  def __le__(self, other):
    """Returns True if `self` is known to be less than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    <= tf.compat.v1.Dimension(n))    == (m <= n)
    (tf.compat.v1.Dimension(m)    <= tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value <= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value <= other.value

  def __gt__(self, other):
    """Returns True if `self` is known to be greater than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    > tf.compat.v1.Dimension(n))    == (m > n)
    (tf.compat.v1.Dimension(m)    > tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value > other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value > other.value

  def __ge__(self, other):
    """Returns True if `self` is known to be greater than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    >= tf.compat.v1.Dimension(n))    == (m >= n)
    (tf.compat.v1.Dimension(m)    >= tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value >= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value >= other.value

  def __reduce__(self):
    # Pickle support: reconstruct from the raw value.
    return Dimension, (self._value,)
721
def as_dimension(value):
  """Converts the given value to a Dimension.

  A Dimension input will be returned unmodified.
  An input of `None` will be converted to an unknown Dimension.
  An integer input will be converted to a Dimension with that value.

  Args:
    value: The value to be converted.

  Returns:
    A Dimension corresponding to the given value.
  """
  # Pass existing Dimension instances through untouched; wrap anything else.
  return value if isinstance(value, Dimension) else Dimension(value)
740
741@tf_export("TensorShape")
742class TensorShape(trace.TraceType, trace_type.Serializable):
743  """Represents the shape of a `Tensor`.
744
745  A `TensorShape` represents a possibly-partial shape specification for a
746  `Tensor`. It may be one of the following:
747
748  * *Fully-known shape:* has a known number of dimensions and a known size
749    for each dimension. e.g. `TensorShape([16, 256])`
750  * *Partially-known shape:* has a known number of dimensions, and an unknown
751    size for one or more dimension. e.g. `TensorShape([None, 256])`
752  * *Unknown shape:* has an unknown number of dimensions, and an unknown
753    size in all dimensions. e.g. `TensorShape(None)`
754
755  If a tensor is produced by an operation of type `"Foo"`, its shape
756  may be inferred if there is a registered shape function for
757  `"Foo"`. See [Shape
758  functions](https://www.tensorflow.org/guide/create_op#shape_functions_in_c)
759  for details of shape functions and how to register them. Alternatively,
760  you may set the shape explicitly using `tf.Tensor.set_shape`.
761  """
762  __slots__ = ["_dims"]
763
  def __init__(self, dims):
    """Creates a new TensorShape with the given dimensions.

    Internally the shape is stored in `self._dims` as either None (unknown
    rank) or a tuple whose entries are ints or None (unknown dimension).

    Args:
      dims: A list of Dimensions, or None if the shape is unspecified.

    Raises:
      TypeError: If dims cannot be converted to a list of dimensions.
    """
    if isinstance(dims, (tuple, list)):  # Most common case.
      self._dims = tuple(as_dimension(d).value for d in dims)
    elif dims is None:
      self._dims = None
    elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
      if dims.unknown_rank:
        self._dims = None
      else:
        self._dims = tuple(
            # Protos store variable-size dimensions as -1
            dim.size if dim.size != -1 else None
            for dim in dims.dim
            )
    elif isinstance(dims, TensorShape):
      # Safe to share: `_dims` is an immutable tuple (or None).
      self._dims = dims._dims
    else:
      try:
        dims_iter = iter(dims)
      except TypeError:
        # Treat as a singleton dimension
        self._dims = (as_dimension(dims).value,)
      else:
        # Generic iterable: convert element by element so a failure can
        # report exactly which element was not convertible.
        self._dims = []
        for d in dims_iter:
          try:
            self._dims.append(as_dimension(d).value)
          except TypeError as e:
            raise TypeError(
                "Failed to convert '{0!r}' to a shape: '{1!r}'"
                "could not be converted to a dimension. A shape should "
                "either be single dimension (e.g. 10), or an iterable of "
                "dimensions (e.g. [1, 10, None]).".format(dims, d)) from e
        self._dims = tuple(self._dims)
806
807  @property
808  def _v2_behavior(self):
809    if _TENSORSHAPE_V2_OVERRIDE is None:
810      return tf2.enabled()
811    return _TENSORSHAPE_V2_OVERRIDE
812
813  def __repr__(self):
814    if self._v2_behavior:
815      if self._dims is not None:
816        return f"TensorShape({list(self._dims)})"
817      else:
818        return "TensorShape(None)"
819    else:
820      return f"TensorShape({self.dims})"
821
822  def __str__(self):
823    if self.rank is None:
824      return "<unknown>"
825    elif self.rank == 1:
826      if self._v2_behavior:
827        return "(%s,)" % self._dims[0]
828      else:
829        return "(%s,)" % self.dims[0]
830    else:
831      if self._v2_behavior:
832        return "(%s)" % ", ".join(str(d) for d in self._dims)
833      else:
834        return "(%s)" % ", ".join(str(d) for d in self.dims)
835
836  @property
837  def rank(self):
838    """Returns the rank of this shape, or None if it is unspecified."""
839    if self._dims is not None:
840      return len(self._dims)
841    return None
842
843  @property
844  def dims(self):
845    """Deprecated.  Returns list of dimensions for this shape.
846
847    Suggest `TensorShape.as_list` instead.
848
849    Returns:
850      A list containing `tf.compat.v1.Dimension`s, or None if the shape is
851      unspecified.
852    """
853    if self._dims is None:
854      return None
855    return [as_dimension(d) for d in self._dims]
856
857  @property
858  def ndims(self):
859    """Deprecated accessor for `rank`."""
860    return self.rank
861
862  def __len__(self):
863    """Returns the rank of this shape, or raises ValueError if unspecified."""
864    if self._dims is None:
865      raise ValueError("Cannot take the length of shape with unknown rank.")
866    return len(self._dims)
867
868  def __bool__(self):
869    """Returns True if this shape contains non-zero information."""
870    return self._dims is not None
871
872  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
873  __nonzero__ = __bool__
874
875  def __iter__(self):
876    """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
877    if self._dims is None:
878      raise ValueError("Cannot iterate over a shape with unknown rank.")
879    else:
880      if self._v2_behavior:
881        return iter(d for d in self._dims)
882      else:
883        return iter(d for d in self.dims)
884
  def __getitem__(self, key):
    """Returns the value of a dimension or a shape, depending on the key.

    Args:
      key: If `key` is an integer, returns the dimension at that index;
        otherwise if `key` is a slice, returns a TensorShape whose dimensions
        are those selected by the slice from `self`.

    Returns:
      An integer if `key` is an integer, or a `TensorShape` if `key` is a
      slice.

    Raises:
      ValueError: If `key` is a slice and `self` is completely unknown and
        the step is set.
    """
    if self._dims is not None:
      # Known rank: delegate to tuple indexing/slicing.
      if isinstance(key, slice):
        return TensorShape(self._dims[key])
      else:
        # V2 returns the plain int/None; V1 wraps it in a Dimension.
        if self._v2_behavior:
          return self._dims[key]
        else:
          return self.dims[key]
    else:
      # Unknown rank: indexing any position yields an unknown dimension;
      # slicing yields a shape whose rank can sometimes be inferred from
      # the slice bounds.
      if isinstance(key, slice):
        start = key.start if key.start is not None else 0
        stop = key.stop

        if key.step is not None:
          # TODO(mrry): Handle these maybe.
          raise ValueError("Steps are not yet handled")
        if stop is None:
          # NOTE(mrry): This implies that TensorShape(None) is compatible with
          # TensorShape(None)[1:], which is obviously not true. It would be
          # possible to track the number of dimensions symbolically,
          # and perhaps we should do that.
          return unknown_shape()
        elif start < 0 or stop < 0:
          # TODO(mrry): Handle this better, as it will be useful for handling
          # suffixes of otherwise unknown shapes.
          return unknown_shape()
        else:
          # Both bounds known and non-negative: the slice has a fixed length
          # even though the underlying rank is unknown.
          return unknown_shape(rank=stop - start)
      else:
        if self._v2_behavior:
          return None
        else:
          return Dimension(None)
934
935  def num_elements(self):
936    """Returns the total number of elements, or none for incomplete shapes."""
937    if self.is_fully_defined():
938      return functools.reduce(operator.mul, self.as_list(), 1)
939    else:
940      return None
941
942  def merge_with(self, other):
943    """Returns a `TensorShape` combining the information in `self` and `other`.
944
945    The dimensions in `self` and `other` are merged element-wise,
946    according to the rules below:
947
948    ```python
949    Dimension(n).merge_with(Dimension(None)) == Dimension(n)
950    Dimension(None).merge_with(Dimension(n)) == Dimension(n)
951    Dimension(None).merge_with(Dimension(None)) == Dimension(None)
952    # raises ValueError for n != m
953    Dimension(n).merge_with(Dimension(m))
954    ```
955    >> ts = tf.TensorShape([1,2])
956    >> ot1 = tf.TensorShape([1,2])
957    >> ts.merge_with(ot).as_list()
958    [1,2]
959
960    >> ot2 = tf.TensorShape([1,None])
961    >> ts.merge_with(ot2).as_list()
962    [1,2]
963
964    >> ot3 = tf.TensorShape([None, None])
965    >> ot3.merge_with(ot2).as_list()
966    [1, None]
967
968    Args:
969      other: Another `TensorShape`.
970
971    Returns:
972      A `TensorShape` containing the combined information of `self` and
973      `other`.
974
975    Raises:
976      ValueError: If `self` and `other` are not compatible.
977    """
978    other = as_shape(other)
979    if self.dims is None:
980      return other
981    if other.dims is None:
982      return self
983    else:
984      try:
985        self.assert_same_rank(other)
986        new_dims = [
987            dim.merge_with(other_dim)
988            for dim, other_dim in zip(self.dims, other.dims)
989        ]
990        return TensorShape(new_dims)
991      except ValueError:
992        raise ValueError("Shapes %s and %s are not compatible" % (self, other))
993
  def __add__(self, other):
    """Returns the concatenation of `self` and `other` (see `concatenate`)."""
    return self.concatenate(other)
996
997  def __radd__(self, other):
998    if not isinstance(other, TensorShape):
999      other = TensorShape(other)
1000    return other.concatenate(self)
1001
1002  def concatenate(self, other):
1003    """Returns the concatenation of the dimension in `self` and `other`.
1004
1005    *N.B.* If either `self` or `other` is completely unknown,
1006    concatenation will discard information about the other shape. In
1007    future, we might support concatenation that preserves this
1008    information for use with slicing.
1009
1010    Args:
1011      other: Another `TensorShape`.
1012
1013    Returns:
1014      A `TensorShape` whose dimensions are the concatenation of the
1015      dimensions in `self` and `other`.
1016    """
1017    # TODO(mrry): Handle the case where we concatenate a known shape with a
1018    # completely unknown shape, so that we can use the partial information.
1019    other = as_shape(other)
1020    if self.dims is None or other.dims is None:
1021      return unknown_shape()
1022    else:
1023      return TensorShape(self.dims + other.dims)
1024
1025  def assert_same_rank(self, other):
1026    """Raises an exception if `self` and `other` do not have compatible ranks.
1027
1028    Args:
1029      other: Another `TensorShape`.
1030
1031    Raises:
1032      ValueError: If `self` and `other` do not represent shapes with the
1033        same rank.
1034    """
1035    other = as_shape(other)
1036    if self.rank is not None and other.rank is not None:
1037      if self.rank != other.rank:
1038        raise ValueError("Shapes %s and %s must have the same rank" %
1039                         (self, other))
1040
1041  def assert_has_rank(self, rank):
1042    """Raises an exception if `self` is not compatible with the given `rank`.
1043
1044    Args:
1045      rank: An integer.
1046
1047    Raises:
1048      ValueError: If `self` does not represent a shape with the given `rank`.
1049    """
1050    if self.rank not in (None, rank):
1051      raise ValueError("Shape %s must have rank %d" % (self, rank))
1052
1053  def with_rank(self, rank):
1054    """Returns a shape based on `self` with the given rank.
1055
1056    This method promotes a completely unknown shape to one with a
1057    known rank.
1058
1059    Args:
1060      rank: An integer.
1061
1062    Returns:
1063      A shape that is at least as specific as `self` with the given rank.
1064
1065    Raises:
1066      ValueError: If `self` does not represent a shape with the given `rank`.
1067    """
1068    try:
1069      return self.merge_with(unknown_shape(rank=rank))
1070    except ValueError:
1071      raise ValueError("Shape %s must have rank %d" % (self, rank))
1072
1073  def with_rank_at_least(self, rank):
1074    """Returns a shape based on `self` with at least the given rank.
1075
1076    Args:
1077      rank: An integer.
1078
1079    Returns:
1080      A shape that is at least as specific as `self` with at least the given
1081      rank.
1082
1083    Raises:
1084      ValueError: If `self` does not represent a shape with at least the given
1085        `rank`.
1086    """
1087    if self.rank is not None and self.rank < rank:
1088      raise ValueError("Shape %s must have rank at least %d" % (self, rank))
1089    else:
1090      return self
1091
1092  def with_rank_at_most(self, rank):
1093    """Returns a shape based on `self` with at most the given rank.
1094
1095    Args:
1096      rank: An integer.
1097
1098    Returns:
1099      A shape that is at least as specific as `self` with at most the given
1100      rank.
1101
1102    Raises:
1103      ValueError: If `self` does not represent a shape with at most the given
1104        `rank`.
1105    """
1106    if self.rank is not None and self.rank > rank:
1107      raise ValueError("Shape %s must have rank at most %d" % (self, rank))
1108    else:
1109      return self
1110
1111  def is_subtype_of(self, other: trace.TraceType) -> bool:
1112    """Returns True iff `self` is subtype of `other`.
1113
1114    Shape A is a subtype of shape B if shape B can successfully represent it:
1115
1116    * A `TensorShape` of any rank is a subtype of `TensorShape(None)`.
1117
1118    *  TensorShapes of equal ranks are covariant, i.e.
1119      `TensorShape([A1, A2, ..])` is a subtype of
1120      `TensorShape([B1, B2, ..])` iff An is a subtype of Bn.
1121
1122      An is subtype of Bn iff An == Bn or Bn is None.
1123
1124    * TensorShapes of different defined ranks have no subtyping relation.
1125
1126    The subtyping relation is reflexive and transitive, but not symmetric.
1127
1128    Some examples:
1129    * `TensorShape([32, 784])` is a subtype of `TensorShape(None)`, and
1130      `TensorShape([4, 4])` is also a subtype of `TensorShape(None)` but
1131      `TensorShape([32, 784])` and `TensorShape([4, 4])` are not subtypes of
1132      each other.
1133
1134    * All two-dimensional shapes are subtypes of `TensorShape([None, None])`,
1135      such as `TensorShape([32, 784])`. There is no subtype relationship with,
1136      for example, `TensorShape([None])` or `TensorShape([None, None, None])`.
1137
1138    * `TensorShape([32, None])` is also a subtype of `TensorShape([None, None])`
1139      and `TensorShape(None)`. It is not a subtype of, for example,
1140      `TensorShape([32])`, `TensorShape([32, None, 1])`,
1141      `TensorShape([64, None])` or `TensorShape([None, 32])`.
1142
1143    * `TensorShape([32, 784])` is a subtype of itself, and also
1144      `TensorShape([32, None])`, `TensorShape([None, 784])`,
1145      `TensorShape([None, None])` and `TensorShape(None)`.
1146      It has no subtype relation with, for example, `TensorShape([32, 1, 784])`
1147      or `TensorShape([None])`.
1148
1149    Args:
1150      other: Another `TensorShape`.
1151
1152    Returns:
1153      True iff `self` is subtype of `other`.
1154
1155    """
1156    if not isinstance(other, TensorShape):
1157      return False
1158
1159    # All Tensors are subtypes of a Tensor with no shape.
1160    if other.rank is None:
1161      return True
1162
1163    # Tensor with a defined shape can only be subtype of another with a defined
1164    # shape if they have the same number of dimensions.
1165    if self.rank != other.rank:
1166      return False
1167
1168    # A Tensor is a subtype if each corresponding dimension is a subtype.
1169    return all(o is None or s == o for s, o in zip(self._dims, other._dims))  # pylint: disable=protected-access
1170
1171  def most_specific_common_supertype(
1172      self, others: Sequence[trace.TraceType]) -> Optional["TensorShape"]:
1173    """Returns the most specific supertype `TensorShape` of self and others.
1174
1175    * `TensorShape([None, 1])` is the most specific `TensorShape` supertyping
1176      both `TensorShape([2, 1])` and `TensorShape([5, 1])`. Note that
1177      `TensorShape(None)` is also a supertype but it is not "most specific".
1178
1179    * `TensorShape([1, 2, 3])` is the most specific `TensorShape` supertyping
1180      both `TensorShape([1, 2, 3])` and `TensorShape([1, 2, 3]`). There are
1181      other less specific TensorShapes that supertype above mentioned
1182      TensorShapes, e.g. `TensorShape([1, 2, None])`, `TensorShape(None)`.
1183
1184     * `TensorShape([None, None])` is the most specific `TensorShape`
1185       supertyping both `TensorShape([2, None])` and `TensorShape([None, 3])`.
1186       As always, `TensorShape(None)` is also a supertype but not the most
1187       specific one.
1188
1189     * `TensorShape(None`) is the only `TensorShape` supertyping both
1190       `TensorShape([1, 2, 3])` and `TensorShape([1, 2])`. In general, any two
1191       shapes that have different ranks will only have `TensorShape(None)`
1192       as a common supertype.
1193
1194     * `TensorShape(None)` is the only `TensorShape` supertyping both
1195       `TensorShape([1, 2, 3])` and `TensorShape(None)`. In general, the common
1196       supertype of any shape with `TensorShape(None)` is `TensorShape(None)`.
1197
1198    Args:
1199      others: Sequence of `TensorShape`.
1200
1201    Returns:
1202      A `TensorShape` which is the most specific supertype shape of `self`
1203      and `others`. None if it does not exist.
1204    """
1205    if any(not isinstance(other, TensorShape) for other in others):
1206      return None
1207
1208    # A Rankless TensorShape is already a global supertype so we return another
1209    # instance of it.
1210    if self.rank is None:
1211      return unknown_shape()
1212
1213    # A Rankless TensorShape is the most specific supertype for shapes whose
1214    # ranks do not match.
1215    if any(other.dims is None or self.rank != other.rank for other in others):
1216      return unknown_shape()
1217
1218    # Retain the integer dimension if it is the same across all others, else
1219    # use an undefined dimension.
1220    dims = [
1221        dim if all(dim == other._dims[i]
1222                   for other in others) else None
1223        for i, dim in enumerate(self._dims)
1224    ]
1225    return TensorShape(dims)
1226
  @classmethod
  def experimental_type_proto(cls) -> Type[tensor_shape_pb2.TensorShapeProto]:
    """Returns the type of proto associated with TensorShape serialization.

    Returns:
      The `TensorShapeProto` class (the proto type produced by
      `experimental_as_proto` and accepted by `experimental_from_proto`).
    """
    return tensor_shape_pb2.TensorShapeProto
1231
  @classmethod
  def experimental_from_proto(
      cls, proto: tensor_shape_pb2.TensorShapeProto) -> "TensorShape":
    """Returns a TensorShape instance based on the serialized proto.

    Args:
      proto: A `TensorShapeProto`, e.g. one produced by
        `experimental_as_proto` or `as_proto`.

    Returns:
      A `TensorShape` equivalent to the shape encoded in `proto`.
    """
    return TensorShape(proto)
1237
  def experimental_as_proto(self) -> tensor_shape_pb2.TensorShapeProto:
    """Returns a proto representation of the TensorShape instance.

    Returns:
      The `TensorShapeProto` produced by `as_proto`.
    """
    return self.as_proto()
1241
1242  # TODO(b/216206374): Consider deprecation at TraceType release.
1243  def is_compatible_with(self, other):
1244    """Returns True iff `self` is compatible with `other`.
1245
1246    Two possibly-partially-defined shapes are compatible if there
1247    exists a fully-defined shape that both shapes can represent. Thus,
1248    compatibility allows the shape inference code to reason about
1249    partially-defined shapes. For example:
1250
1251    * TensorShape(None) is compatible with all shapes.
1252
1253    * TensorShape([None, None]) is compatible with all two-dimensional
1254      shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
1255      not compatible with, for example, TensorShape([None]) or
1256      TensorShape([None, None, None]).
1257
1258    * TensorShape([32, None]) is compatible with all two-dimensional shapes
1259      with size 32 in the 0th dimension, and also TensorShape([None, None])
1260      and TensorShape(None). It is not compatible with, for example,
1261      TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).
1262
1263    * TensorShape([32, 784]) is compatible with itself, and also
1264      TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
1265      None]) and TensorShape(None). It is not compatible with, for example,
1266      TensorShape([32, 1, 784]) or TensorShape([None]).
1267
1268    The compatibility relation is reflexive and symmetric, but not
1269    transitive. For example, TensorShape([32, 784]) is compatible with
1270    TensorShape(None), and TensorShape(None) is compatible with
1271    TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
1272    TensorShape([4, 4]).
1273
1274    Args:
1275      other: Another TensorShape.
1276
1277    Returns:
1278      True iff `self` is compatible with `other`.
1279
1280    """
1281    other = as_shape(other)
1282    if self.dims is not None and other.dims is not None:
1283      if self.rank != other.rank:
1284        return False
1285      for x_dim, y_dim in zip(self.dims, other.dims):
1286        if not x_dim.is_compatible_with(y_dim):
1287          return False
1288    return True
1289
1290  def assert_is_compatible_with(self, other):
1291    """Raises exception if `self` and `other` do not represent the same shape.
1292
1293    This method can be used to assert that there exists a shape that both
1294    `self` and `other` represent.
1295
1296    Args:
1297      other: Another TensorShape.
1298
1299    Raises:
1300      ValueError: If `self` and `other` do not represent the same shape.
1301    """
1302    if not self.is_compatible_with(other):
1303      raise ValueError("Shapes %s and %s are incompatible" % (self, other))
1304
1305  def most_specific_compatible_shape(self, other):
1306    """Returns the most specific TensorShape compatible with `self` and `other`.
1307
1308    * TensorShape([None, 1]) is the most specific TensorShape compatible with
1309      both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
1310      TensorShape(None) is also compatible with above mentioned TensorShapes.
1311
1312    * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
1313      both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more
1314      less specific TensorShapes compatible with above mentioned TensorShapes,
1315      e.g. TensorShape([1, 2, None]), TensorShape(None).
1316
1317    Args:
1318      other: Another `TensorShape`.
1319
1320    Returns:
1321      A `TensorShape` which is the most specific compatible shape of `self`
1322      and `other`.
1323    """
1324
1325    other = as_shape(other)
1326    if self.dims is None or other.dims is None or self.rank != other.rank:
1327      return unknown_shape()
1328
1329    dims = [
1330        d1 if d1 is not None and d2 is not None and d1 == d2 else None
1331        for d1, d2 in zip(self.dims, other.dims)
1332    ]
1333    return TensorShape(dims)
1334
1335  def is_fully_defined(self):
1336    """Returns True iff `self` is fully defined in every dimension."""
1337    return (self._dims is not None and
1338            all(dim is not None for dim in self._dims))
1339
1340  def assert_is_fully_defined(self):
1341    """Raises an exception if `self` is not fully defined in every dimension.
1342
1343    Raises:
1344      ValueError: If `self` does not have a known value for every dimension.
1345    """
1346    if not self.is_fully_defined():
1347      raise ValueError("Shape %s is not fully defined" % self)
1348
1349  def as_list(self):
1350    """Returns a list of integers or `None` for each dimension.
1351
1352    Returns:
1353      A list of integers or `None` for each dimension.
1354
1355    Raises:
1356      ValueError: If `self` is an unknown shape with an unknown rank.
1357    """
1358    if self._dims is None:
1359      raise ValueError("as_list() is not defined on an unknown TensorShape.")
1360    return list(self._dims)
1361
1362  def as_proto(self):
1363    """Returns this shape as a `TensorShapeProto`."""
1364    if self._dims is None:
1365      return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
1366    else:
1367      return tensor_shape_pb2.TensorShapeProto(dim=[
1368          tensor_shape_pb2.TensorShapeProto.Dim(
1369              size=-1 if d is None else d) for d in self._dims
1370      ])
1371
1372  def __eq__(self, other):
1373    """Returns True if `self` is equivalent to `other`.
1374
1375    It first tries to convert `other` to `TensorShape`. `TypeError` is thrown
1376    when the conversion fails. Otherwise, it compares each element in the
1377    TensorShape dimensions.
1378
1379    * Two *Fully known* shapes, return True iff each element is equal.
1380    >>> t_a = tf.TensorShape([1,2])
1381    >>> a = [1, 2]
1382    >>> t_b = tf.TensorShape([1,2])
1383    >>> t_c = tf.TensorShape([1,2,3])
1384    >>> t_a.__eq__(a)
1385    True
1386    >>> t_a.__eq__(t_b)
1387    True
1388    >>> t_a.__eq__(t_c)
1389    False
1390
1391    * Two *Partially-known* shapes, return True iff each element is equal.
1392    >>> p_a = tf.TensorShape([1,None])
1393    >>> p_b = tf.TensorShape([1,None])
1394    >>> p_c = tf.TensorShape([2,None])
1395    >>> p_a.__eq__(p_b)
1396    True
1397    >>> t_a.__eq__(p_a)
1398    False
1399    >>> p_a.__eq__(p_c)
1400    False
1401
1402    * Two *Unknown shape*, return True.
1403    >>> unk_a = tf.TensorShape(None)
1404    >>> unk_b = tf.TensorShape(None)
1405    >>> unk_a.__eq__(unk_b)
1406    True
1407    >>> unk_a.__eq__(t_a)
1408    False
1409
1410    Args:
1411      other: A `TensorShape` or type that can be converted to `TensorShape`.
1412
1413    Returns:
1414      True if the dimensions are all equal.
1415
1416    Raises:
1417      TypeError if `other` can not be converted to `TensorShape`.
1418    """
1419
1420    try:
1421      other = as_shape(other)
1422    except TypeError:
1423      return NotImplemented
1424
1425    return self._dims == other._dims
1426
  def __hash__(self):
    # Hash the stored dimensions directly so that shapes comparing equal via
    # `__eq__` (i.e. equal `_dims`) hash equally.
    return hash(self._dims)
1429
  def __reduce__(self):
    # Pickle support: reconstruct the shape as `TensorShape(self.dims)`.
    return TensorShape, (self.dims,)
1432
  def __concat__(self, other):
    # Same behavior as `concatenate`; presumably invoked by a generic
    # `__concat__` protocol elsewhere — confirm at call sites.
    return self.concatenate(other)
1435
# Register `TensorShape` with `trace_type`'s serializable-type registry
# (pairs with the experimental_*_proto methods above).
trace_type.register_serializable(TensorShape)
1437
1438
def as_shape(shape):
  """Converts the given object to a TensorShape.

  Args:
    shape: A `TensorShape`, or any value accepted by the `TensorShape`
      constructor.

  Returns:
    `shape` itself if it is already a `TensorShape`, otherwise
    `TensorShape(shape)`.
  """
  return shape if isinstance(shape, TensorShape) else TensorShape(shape)
1445
1446
def unknown_shape(rank=None, **kwargs):
  """Returns an unknown TensorShape, optionally with a known rank.

  Args:
    rank: (Optional) If specified, the number of dimensions in the shape.
    **kwargs: For backwards compatibility.

  Returns:
    An unknown TensorShape.

  Raises:
    TypeError: In case of invalid arguments.
  """
  # Accept the legacy keyword `ndims` as an alias for `rank`, but only when
  # `rank` itself was not given.
  if rank is None:
    rank = kwargs.pop("ndims", None)
  if kwargs:
    raise TypeError("Unknown argument: %s" % kwargs)
  if rank is None:
    return TensorShape(None)
  return TensorShape([Dimension(None)] * rank)
1468