• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper classes for tensor shape inference."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21import operator
22import six
23
24from tensorflow.core.framework import tensor_shape_pb2
25from tensorflow.python import tf2
26from tensorflow.python.eager import monitoring
27from tensorflow.python.platform import tf_logging as logging
28from tensorflow.python.util.tf_export import tf_export
29
# Tri-state override for the V2 TensorShape behavior:
#   None  -> defer to tf2.enabled() (the process-wide TF2 switch),
#   True  -> user explicitly called enable_v2_tensorshape(),
#   False -> user explicitly called disable_v2_tensorshape().
_TENSORSHAPE_V2_OVERRIDE = None

# Usage-monitoring gauge flipped by enable/disable_v2_tensorshape() so API
# telemetry can report whether the override was ever set.
_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/v2_tensorshape",
    "Whether tensor_shape.enable_v2_tensorshape() is called.")
35
36
@tf_export(v1=["enable_v2_tensorshape"])
def enable_v2_tensorshape():
  """In TensorFlow 2.0, iterating over a TensorShape instance returns values.

  This enables the new behavior.

  Concretely, `tensor_shape[i]` returned a Dimension instance in V1, but
  in V2 it returns either an integer, or None.

  Examples:

  ```
  #######################
  # If you had this in V1:
  value = tensor_shape[i].value

  # Do this in V2 instead:
  value = tensor_shape[i]

  #######################
  # If you had this in V1:
  for dim in tensor_shape:
    value = dim.value
    print(value)

  # Do this in V2 instead:
  for value in tensor_shape:
    print(value)

  #######################
  # If you had this in V1:
  dim = tensor_shape[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # Do this in V2 instead:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # The V2 suggestion above is more explicit, which will save you from
  # the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be.
  ```
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = True
  logging.vlog(1, "Enabling v2 tensorshape")
  # Record the explicit opt-in for API usage telemetry.
  _api_usage_gauge.get_cell().set(True)
88
89
@tf_export(v1=["disable_v2_tensorshape"])
def disable_v2_tensorshape():
  """Disables the V2 TensorShape behavior and reverts to V1 behavior.

  See docstring for `enable_v2_tensorshape` for details about the new behavior.
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = False
  logging.vlog(1, "Disabling v2 tensorshape")
  # Record the explicit opt-out for API usage telemetry.
  _api_usage_gauge.get_cell().set(False)
100
101
@tf_export(
    "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"])
def dimension_value(dimension):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  When accessing the value of a TensorShape dimension,
  use this utility, like this:

  ```
  # If you had this in your V1 code:
  value = tensor_shape[i].value

  # Use `dimension_value` as direct replacement compatible with both V1 & V2:
  value = dimension_value(tensor_shape[i])

  # This would be the V2 equivalent:
  value = tensor_shape[i]  # Warning: this will return the dim value in V2!
  ```

  Args:
    dimension: Either a `Dimension` instance, an integer, or None.

  Returns:
    A plain value, i.e. an integer or None.
  """
  # V1 hands us a Dimension wrapper; V2 already hands us the plain value.
  return dimension.value if isinstance(dimension, Dimension) else dimension
133
134
@tf_export(
    "compat.dimension_at_index",
    v1=["dimension_at_index", "compat.dimension_at_index"])
def dimension_at_index(shape, index):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  If you want to retrieve the Dimension instance corresponding to a certain
  index in a TensorShape instance, use this utility, like this:

  ```
  # If you had this in your V1 code:
  dim = tensor_shape[i]

  # Use `dimension_at_index` as direct replacement compatible with both V1 & V2:
  dim = dimension_at_index(tensor_shape, i)

  # Another possibility would be this, but WARNING: it only works if the
  # tensor_shape instance has a defined rank.
  dim = tensor_shape.dims[i]  # `dims` may be None if the rank is undefined!

  # In native V2 code, we recommend instead being more explicit:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]

  # Being more explicit will save you from the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be (as the Dimension object was
  # instantiated on the fly).
  ```

  Args:
    shape: A TensorShape instance.
    index: An integer index.

  Returns:
    A dimension object.
  """
  assert isinstance(shape, TensorShape)
  # An unknown rank means every individual dimension is unknown too.
  dims = shape.dims
  return Dimension(None) if dims is None else dims[index]
182
183
@tf_export(v1=["Dimension"])
class Dimension(object):
  """Represents the value of one dimension in a TensorShape.

  @compatibility(TF2)
  In TF2, members of a `TensorShape` object are integers. The `Dimension` class
  is not part of TF2's data model.

  Please refer to the [TensorShape section of the migration guide]
  (https://www.tensorflow.org/guide/migrate/index#tensorshape) on common code
  patterns adapting Dimension objects to a TF2 syntax.
  @end_compatibility
  """

  # Single slot keeps per-instance memory small; _value is an int or None
  # (None encodes an unknown dimension, which propagates through arithmetic).
  __slots__ = ["_value"]

  def __init__(self, value):
    """Creates a new Dimension with the given value.

    Args:
      value: A non-negative int, None (unknown), another `Dimension`, or any
        object exposing an `__index__` method.

    Raises:
      TypeError: If `value` cannot be interpreted as an integer dimension.
      ValueError: If `value` is a negative integer.
    """
    if isinstance(value, int):  # Most common case.
      if value < 0:
        raise ValueError("Dimension %d must be >= 0" % value)
      self._value = value
    elif value is None:
      self._value = None
    elif isinstance(value, Dimension):
      self._value = value._value
    else:
      try:
        # int(...) compensates for the int/long dichotomy on Python 2.X.
        # TODO(b/143206389): Remove once we fully migrate to 3.X.
        self._value = int(value.__index__())
      except AttributeError:
        six.raise_from(
            TypeError("Dimension value must be integer or None or have "
                      "an __index__ method, got value '{0!r}' with type '{1!r}'"
                      .format(value, type(value))), None)
      if self._value < 0:
        raise ValueError("Dimension %d must be >= 0" % self._value)

  def __repr__(self):
    return "Dimension(%s)" % repr(self._value)

  def __str__(self):
    # Unknown dimensions render as "?" (e.g. in TensorShape.__str__ output).
    value = self._value
    return "?" if value is None else str(value)

  def __eq__(self, other):
    """Returns true if `other` has the same known value as this Dimension.

    Returns None (not False) when either value is unknown, and NotImplemented
    when `other` cannot be converted to a Dimension.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value == other.value

  def __ne__(self, other):
    """Returns true if `other` has a different known value from `self`.

    Returns None (not False) when either value is unknown, and NotImplemented
    when `other` cannot be converted to a Dimension.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value != other.value

  def __bool__(self):
    """Equivalent to `bool(self.value)`."""
    return bool(self._value)

  def __int__(self):
    return self._value

  # This is needed for Windows.
  # See https://github.com/tensorflow/tensorflow/pull/9780
  def __long__(self):
    return self._value

  def __index__(self):
    # Allow use in Python 3 range
    return self._value

  @property
  def value(self):
    """The value of this dimension, or None if it is unknown."""
    return self._value

  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this Dimension.

    Two known Dimensions are compatible if they have the same value.
    An unknown Dimension is compatible with all other Dimensions.

    Args:
      other: Another Dimension.

    Returns:
      True if this Dimension and `other` are compatible.
    """
    other = as_dimension(other)
    return (self._value is None or other.value is None or
            self._value == other.value)

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.

    Args:
      other: Another Dimension.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    if not self.is_compatible_with(other):
      raise ValueError("Dimensions %s and %s are not compatible" %
                       (self, other))

  def merge_with(self, other):
    """Returns a Dimension that combines the information in `self` and `other`.

    Dimensions are combined as follows:

    ```python
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(n))     ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(None))  ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(n))     ==
    tf.compat.v1.Dimension(n)
    # equivalent to tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(None))

    # raises ValueError for n != m
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(m))
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    other = as_dimension(other)
    self.assert_is_compatible_with(other)
    # After the compatibility check, at most one of the two values is known;
    # prefer the known one.
    if self._value is None:
      return Dimension(other.value)
    else:
      return Dimension(self._value)

  def __add__(self, other):
    """Returns the sum of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.compat.v1.Dimension(m)    + tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m + n)
    tf.compat.v1.Dimension(m)    + tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value + other.value)

  def __radd__(self, other):
    """Returns the sum of `other` and `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    return self + other

  def __sub__(self, other):
    """Returns the subtraction of `other` from `self`.

    Dimensions are subtracted as follows:

    ```python
    tf.compat.v1.Dimension(m)    - tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m - n)
    tf.compat.v1.Dimension(m)    - tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `other` from `self`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value - other.value)

  def __rsub__(self, other):
    """Returns the subtraction of `self` from `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `self` from `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value - self._value)

  def __mul__(self, other):
    """Returns the product of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.compat.v1.Dimension(m)    * tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m * n)
    tf.compat.v1.Dimension(m)    * tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented

    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value * other.value)

  def __rmul__(self, other):
    """Returns the product of `self` and `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    return self * other

  def __floordiv__(self, other):
    """Returns the quotient of `self` and `other` rounded down.

    Dimensions are divided as follows:

    ```python
    tf.compat.v1.Dimension(m)    // tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m // n)
    tf.compat.v1.Dimension(m)    // tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value // other.value)

  def __rfloordiv__(self, other):
    """Returns the quotient of `other` and `self` rounded down.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value // self._value)

  def __div__(self, other):
    """DEPRECATED: Use `__floordiv__` via `x // y` instead.

    This function exists only for backwards compatibility purposes; new code
    should use `__floordiv__` via the syntax `x // y`.  Using `x // y`
    communicates clearly that the result rounds down, and is forward compatible
    to Python 3.

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    return self // other

  def __rdiv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __truediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', "
                    "please use // instead".format(type(other).__name__))

  def __rtruediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __mod__(self, other):
    """Returns `self` modulo `other`.

    Dimension modulo are computed as follows:

    ```python
    tf.compat.v1.Dimension(m)    % tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m % n)
    tf.compat.v1.Dimension(m)    % tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `self` modulo `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value % other.value)

  def __rmod__(self, other):
    """Returns `other` modulo `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `other` modulo `self`.
    """
    other = as_dimension(other)
    return other % self

  def __lt__(self, other):
    """Returns True if `self` is known to be less than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    < tf.compat.v1.Dimension(n))    == (m < n)
    (tf.compat.v1.Dimension(m)    < tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value < other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value < other.value

  def __le__(self, other):
    """Returns True if `self` is known to be less than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    <= tf.compat.v1.Dimension(n))    == (m <= n)
    (tf.compat.v1.Dimension(m)    <= tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value <= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value <= other.value

  def __gt__(self, other):
    """Returns True if `self` is known to be greater than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    > tf.compat.v1.Dimension(n))    == (m > n)
    (tf.compat.v1.Dimension(m)    > tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value > other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value > other.value

  def __ge__(self, other):
    """Returns True if `self` is known to be greater than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    >= tf.compat.v1.Dimension(n))    == (m >= n)
    (tf.compat.v1.Dimension(m)    >= tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value >= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value >= other.value

  def __reduce__(self):
    # Pickle support: reconstruct from the plain value.
    return Dimension, (self._value,)
721
722
def as_dimension(value):
  """Coerces `value` to a `Dimension`.

  A `Dimension` input is returned unmodified; `None` becomes an unknown
  Dimension; an integer becomes a Dimension with that value.

  Args:
    value: The value to be converted.

  Returns:
    A Dimension corresponding to the given value.
  """
  return value if isinstance(value, Dimension) else Dimension(value)
740
741
742@tf_export("TensorShape")
743class TensorShape(object):
744  """Represents the shape of a `Tensor`.
745
746  A `TensorShape` represents a possibly-partial shape specification for a
747  `Tensor`. It may be one of the following:
748
749  * *Fully-known shape:* has a known number of dimensions and a known size
750    for each dimension. e.g. `TensorShape([16, 256])`
751  * *Partially-known shape:* has a known number of dimensions, and an unknown
752    size for one or more dimension. e.g. `TensorShape([None, 256])`
753  * *Unknown shape:* has an unknown number of dimensions, and an unknown
754    size in all dimensions. e.g. `TensorShape(None)`
755
756  If a tensor is produced by an operation of type `"Foo"`, its shape
757  may be inferred if there is a registered shape function for
758  `"Foo"`. See [Shape
759  functions](https://tensorflow.org/extend/adding_an_op#shape_functions_in_c)
760  for details of shape functions and how to register them. Alternatively,
761  the shape may be set explicitly using `tf.Tensor.set_shape`.
762  """
763  __slots__ = ["_dims"]
764
765  def __init__(self, dims):
766    """Creates a new TensorShape with the given dimensions.
767
768    Args:
769      dims: A list of Dimensions, or None if the shape is unspecified.
770
771    Raises:
772      TypeError: If dims cannot be converted to a list of dimensions.
773    """
774    if isinstance(dims, (tuple, list)):  # Most common case.
775      self._dims = [Dimension(d) for d in dims]
776    elif dims is None:
777      self._dims = None
778    elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
779      if dims.unknown_rank:
780        self._dims = None
781      else:
782        self._dims = [
783            # Protos store variable-size dimensions as -1
784            as_dimension(dim.size if dim.size != -1 else None)
785            for dim in dims.dim
786        ]
787    elif isinstance(dims, TensorShape):
788      self._dims = dims.dims
789    else:
790      try:
791        dims_iter = iter(dims)
792      except TypeError:
793        # Treat as a singleton dimension
794        self._dims = [as_dimension(dims)]
795      else:
796        self._dims = []
797        for d in dims_iter:
798          try:
799            self._dims.append(as_dimension(d))
800          except TypeError as e:
801            six.raise_from(
802                TypeError(
803                    "Failed to convert '{0!r}' to a shape: '{1!r}'"
804                    "could not be converted to a dimension. A shape should "
805                    "either be single dimension (e.g. 10), or an iterable of "
806                    "dimensions (e.g. [1, 10, None])."
807                    .format(dims, d)), e)
808
809  @property
810  def _v2_behavior(self):
811    if _TENSORSHAPE_V2_OVERRIDE is None:
812      return tf2.enabled()
813    return _TENSORSHAPE_V2_OVERRIDE
814
815  def __repr__(self):
816    if self._v2_behavior:
817      if self._dims is not None:
818        return "TensorShape(%r)" % [dim.value for dim in self._dims]
819      else:
820        return "TensorShape(None)"
821    else:
822      return "TensorShape(%r)" % self._dims
823
824  def __str__(self):
825    if self.rank is None:
826      return "<unknown>"
827    elif self.rank == 1:
828      if self._v2_behavior:
829        return "(%s,)" % self._dims[0].value
830      else:
831        return "(%s,)" % self._dims[0]
832    else:
833      if self._v2_behavior:
834        return "(%s)" % ", ".join(str(d.value) for d in self._dims)
835      else:
836        return "(%s)" % ", ".join(str(d) for d in self._dims)
837
838  @property
839  def rank(self):
840    """Returns the rank of this shape, or None if it is unspecified."""
841    if self._dims is not None:
842      return len(self._dims)
843    return None
844
  @property
  def dims(self):
    """Deprecated.  Returns list of dimensions for this shape.

    Suggest `TensorShape.as_list` instead.

    Returns:
      A list containing `tf.compat.v1.Dimension`s, or None if the shape is
      unspecified.
    """
    # NOTE: returns the internal list itself, not a copy.
    return self._dims
856
  @property
  def ndims(self):
    """Deprecated accessor for `rank`; returns the rank or None."""
    return self.rank
861
862  def __len__(self):
863    """Returns the rank of this shape, or raises ValueError if unspecified."""
864    if self._dims is None:
865      raise ValueError("Cannot take the length of shape with unknown rank.")
866    return len(self._dims)
867
  def __bool__(self):
    """Returns True if this shape contains non-zero information.

    Truthiness reflects whether the rank is known, not whether the shape is
    non-empty: `TensorShape([])` is truthy, `TensorShape(None)` is falsy.
    """
    return self._dims is not None

  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
  __nonzero__ = __bool__
874
875  def __iter__(self):
876    """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
877    if self._dims is None:
878      raise ValueError("Cannot iterate over a shape with unknown rank.")
879    else:
880      if self._v2_behavior:
881        return iter(d.value for d in self._dims)
882      else:
883        return iter(d for d in self._dims)
884
  def __getitem__(self, key):
    """Returns the value of a dimension or a shape, depending on the key.

    Args:
      key: If `key` is an integer, returns the dimension at that index;
        otherwise if `key` is a slice, returns a TensorShape whose dimensions
        are those selected by the slice from `self`.

    Returns:
      An integer (in V1, a `Dimension`) if `key` is an integer, or a
      `TensorShape` if `key` is a slice.

    Raises:
      ValueError: If `key` is a slice and `self` is completely unknown and
        the step is set.
    """
    if self._dims is not None:
      if isinstance(key, slice):
        return TensorShape(self._dims[key])
      else:
        if self._v2_behavior:
          return self._dims[key].value
        else:
          return self._dims[key]
    else:
      # Rank is unknown: any single dimension is unknown; a slice can at best
      # pin down a rank when both bounds are known and non-negative.
      if isinstance(key, slice):
        start = key.start if key.start is not None else 0
        stop = key.stop

        if key.step is not None:
          # TODO(mrry): Handle these maybe.
          raise ValueError("Steps are not yet handled")
        if stop is None:
          # NOTE(mrry): This implies that TensorShape(None) is compatible with
          # TensorShape(None)[1:], which is obviously not true. It would be
          # possible to track the number of dimensions symbolically,
          # and perhaps we should do that.
          return unknown_shape()
        elif start < 0 or stop < 0:
          # TODO(mrry): Handle this better, as it will be useful for handling
          # suffixes of otherwise unknown shapes.
          return unknown_shape()
        else:
          return unknown_shape(rank=stop - start)
      else:
        if self._v2_behavior:
          return None
        else:
          return Dimension(None)
934
935  def num_elements(self):
936    """Returns the total number of elements, or none for incomplete shapes."""
937    if self.is_fully_defined():
938      return functools.reduce(operator.mul, self.as_list(), 1)
939    else:
940      return None
941
  def merge_with(self, other):
    """Returns a `TensorShape` combining the information in `self` and `other`.

    The dimensions in `self` and `other` are merged element-wise,
    according to the rules below:

    ```python
    Dimension(n).merge_with(Dimension(None)) == Dimension(n)
    Dimension(None).merge_with(Dimension(n)) == Dimension(n)
    Dimension(None).merge_with(Dimension(None)) == Dimension(None)
    # raises ValueError for n != m
    Dimension(n).merge_with(Dimension(m))
    ```
    >> ts = tf.TensorShape([1,2])
    >> ot1 = tf.TensorShape([1,2])
    >> ts.merge_with(ot1).as_list()
    [1,2]

    >> ot2 = tf.TensorShape([1,None])
    >> ts.merge_with(ot2).as_list()
    [1,2]

    >> ot3 = tf.TensorShape([None, None])
    >> ot3.merge_with(ot2).as_list()
    [1, None]

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible.
    """
    other = as_shape(other)
    # An unknown-rank shape carries no information; the other shape wins.
    if self._dims is None:
      return other
    if other.dims is None:
      return self
    else:
      try:
        self.assert_same_rank(other)
        new_dims = [
            dim.merge_with(other_dim)
            for dim, other_dim in zip(self._dims, other.dims)
        ]
        return TensorShape(new_dims)
      except ValueError:
        raise ValueError("Shapes %s and %s are not compatible" % (self, other))
993
  def __add__(self, other):
    """Returns `self` concatenated with `other`; see `concatenate`."""
    return self.concatenate(other)
996
997  def __radd__(self, other):
998    if not isinstance(other, TensorShape):
999      other = TensorShape(other)
1000    return other.concatenate(self)
1001
1002  def concatenate(self, other):
1003    """Returns the concatenation of the dimension in `self` and `other`.
1004
1005    *N.B.* If either `self` or `other` is completely unknown,
1006    concatenation will discard information about the other shape. In
1007    future, we might support concatenation that preserves this
1008    information for use with slicing.
1009
1010    Args:
1011      other: Another `TensorShape`.
1012
1013    Returns:
1014      A `TensorShape` whose dimensions are the concatenation of the
1015      dimensions in `self` and `other`.
1016    """
1017    # TODO(mrry): Handle the case where we concatenate a known shape with a
1018    # completely unknown shape, so that we can use the partial information.
1019    other = as_shape(other)
1020    if self._dims is None or other.dims is None:
1021      return unknown_shape()
1022    else:
1023      return TensorShape(self._dims + other.dims)
1024
1025  def assert_same_rank(self, other):
1026    """Raises an exception if `self` and `other` do not have compatible ranks.
1027
1028    Args:
1029      other: Another `TensorShape`.
1030
1031    Raises:
1032      ValueError: If `self` and `other` do not represent shapes with the
1033        same rank.
1034    """
1035    other = as_shape(other)
1036    if self.rank is not None and other.rank is not None:
1037      if self.rank != other.rank:
1038        raise ValueError("Shapes %s and %s must have the same rank" %
1039                         (self, other))
1040
1041  def assert_has_rank(self, rank):
1042    """Raises an exception if `self` is not compatible with the given `rank`.
1043
1044    Args:
1045      rank: An integer.
1046
1047    Raises:
1048      ValueError: If `self` does not represent a shape with the given `rank`.
1049    """
1050    if self.rank not in (None, rank):
1051      raise ValueError("Shape %s must have rank %d" % (self, rank))
1052
1053  def with_rank(self, rank):
1054    """Returns a shape based on `self` with the given rank.
1055
1056    This method promotes a completely unknown shape to one with a
1057    known rank.
1058
1059    Args:
1060      rank: An integer.
1061
1062    Returns:
1063      A shape that is at least as specific as `self` with the given rank.
1064
1065    Raises:
1066      ValueError: If `self` does not represent a shape with the given `rank`.
1067    """
1068    try:
1069      return self.merge_with(unknown_shape(rank=rank))
1070    except ValueError:
1071      raise ValueError("Shape %s must have rank %d" % (self, rank))
1072
1073  def with_rank_at_least(self, rank):
1074    """Returns a shape based on `self` with at least the given rank.
1075
1076    Args:
1077      rank: An integer.
1078
1079    Returns:
1080      A shape that is at least as specific as `self` with at least the given
1081      rank.
1082
1083    Raises:
1084      ValueError: If `self` does not represent a shape with at least the given
1085        `rank`.
1086    """
1087    if self.rank is not None and self.rank < rank:
1088      raise ValueError("Shape %s must have rank at least %d" % (self, rank))
1089    else:
1090      return self
1091
1092  def with_rank_at_most(self, rank):
1093    """Returns a shape based on `self` with at most the given rank.
1094
1095    Args:
1096      rank: An integer.
1097
1098    Returns:
1099      A shape that is at least as specific as `self` with at most the given
1100      rank.
1101
1102    Raises:
1103      ValueError: If `self` does not represent a shape with at most the given
1104        `rank`.
1105    """
1106    if self.rank is not None and self.rank > rank:
1107      raise ValueError("Shape %s must have rank at most %d" % (self, rank))
1108    else:
1109      return self
1110
1111  def is_compatible_with(self, other):
1112    """Returns True iff `self` is compatible with `other`.
1113
1114    Two possibly-partially-defined shapes are compatible if there
1115    exists a fully-defined shape that both shapes can represent. Thus,
1116    compatibility allows the shape inference code to reason about
1117    partially-defined shapes. For example:
1118
1119    * TensorShape(None) is compatible with all shapes.
1120
1121    * TensorShape([None, None]) is compatible with all two-dimensional
1122      shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
1123      not compatible with, for example, TensorShape([None]) or
1124      TensorShape([None, None, None]).
1125
1126    * TensorShape([32, None]) is compatible with all two-dimensional shapes
1127      with size 32 in the 0th dimension, and also TensorShape([None, None])
1128      and TensorShape(None). It is not compatible with, for example,
1129      TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).
1130
1131    * TensorShape([32, 784]) is compatible with itself, and also
1132      TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
1133      None]) and TensorShape(None). It is not compatible with, for example,
1134      TensorShape([32, 1, 784]) or TensorShape([None]).
1135
1136    The compatibility relation is reflexive and symmetric, but not
1137    transitive. For example, TensorShape([32, 784]) is compatible with
1138    TensorShape(None), and TensorShape(None) is compatible with
1139    TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
1140    TensorShape([4, 4]).
1141
1142    Args:
1143      other: Another TensorShape.
1144
1145    Returns:
1146      True iff `self` is compatible with `other`.
1147
1148    """
1149    other = as_shape(other)
1150    if self._dims is not None and other.dims is not None:
1151      if self.rank != other.rank:
1152        return False
1153      for x_dim, y_dim in zip(self._dims, other.dims):
1154        if not x_dim.is_compatible_with(y_dim):
1155          return False
1156    return True
1157
1158  def assert_is_compatible_with(self, other):
1159    """Raises exception if `self` and `other` do not represent the same shape.
1160
1161    This method can be used to assert that there exists a shape that both
1162    `self` and `other` represent.
1163
1164    Args:
1165      other: Another TensorShape.
1166
1167    Raises:
1168      ValueError: If `self` and `other` do not represent the same shape.
1169    """
1170    if not self.is_compatible_with(other):
1171      raise ValueError("Shapes %s and %s are incompatible" % (self, other))
1172
1173  def most_specific_compatible_shape(self, other):
1174    """Returns the most specific TensorShape compatible with `self` and `other`.
1175
1176    * TensorShape([None, 1]) is the most specific TensorShape compatible with
1177      both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
1178      TensorShape(None) is also compatible with above mentioned TensorShapes.
1179
1180    * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
1181      both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more
1182      less specific TensorShapes compatible with above mentioned TensorShapes,
1183      e.g. TensorShape([1, 2, None]), TensorShape(None).
1184
1185    Args:
1186      other: Another `TensorShape`.
1187
1188    Returns:
1189      A `TensorShape` which is the most specific compatible shape of `self`
1190      and `other`.
1191    """
1192
1193    other = as_shape(other)
1194    if self._dims is None or other.dims is None or self.rank != other.rank:
1195      return unknown_shape()
1196
1197    dims = [
1198        d1 if d1 is not None and d2 is not None and d1 == d2 else None
1199        for d1, d2 in zip(self._dims, other.dims)
1200    ]
1201    return TensorShape(dims)
1202
1203  def is_fully_defined(self):
1204    """Returns True iff `self` is fully defined in every dimension."""
1205    return (self._dims is not None and
1206            all(dim.value is not None for dim in self._dims))
1207
1208  def assert_is_fully_defined(self):
1209    """Raises an exception if `self` is not fully defined in every dimension.
1210
1211    Raises:
1212      ValueError: If `self` does not have a known value for every dimension.
1213    """
1214    if not self.is_fully_defined():
1215      raise ValueError("Shape %s is not fully defined" % self)
1216
  def as_list(self):
    """Returns a list of integers or `None` for each dimension.

    Returns:
      A list of integers or `None` for each dimension.

    Raises:
      ValueError: If `self` is an unknown shape with an unknown rank.
    """
    if self._dims is None:
      raise ValueError("as_list() is not defined on an unknown TensorShape.")
    # Expose plain ints/None rather than the internal Dimension objects.
    return [dim.value for dim in self._dims]
1229
1230  def as_proto(self):
1231    """Returns this shape as a `TensorShapeProto`."""
1232    if self._dims is None:
1233      return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
1234    else:
1235      return tensor_shape_pb2.TensorShapeProto(dim=[
1236          tensor_shape_pb2.TensorShapeProto.Dim(
1237              size=-1 if d.value is None else d.value) for d in self._dims
1238      ])
1239
  def __eq__(self, other):
    """Returns True if `self` is equivalent to `other`.

    It first tries to convert `other` to `TensorShape`. `TypeError` is thrown
    when the conversion fails. Otherwise, it compares each element in the
    TensorShape dimensions.

    * Two *Fully known* shapes, return True iff each element is equal.
    >>> t_a = tf.TensorShape([1,2])
    >>> a = [1, 2]
    >>> t_b = tf.TensorShape([1,2])
    >>> t_c = tf.TensorShape([1,2,3])
    >>> t_a.__eq__(a)
    True
    >>> t_a.__eq__(t_b)
    True
    >>> t_a.__eq__(t_c)
    False

    * Two *Partially-known* shapes, return False.
    >>> p_a = tf.TensorShape([1,None])
    >>> p_b = tf.TensorShape([2,None])
    >>> p_a.__eq__(p_b)
    False
    >>> t_a.__eq__(p_a)
    False

    * Two *Unknown shape*, return True.
    >>> unk_a = tf.TensorShape(None)
    >>> unk_b = tf.TensorShape(None)
    >>> unk_a.__eq__(unk_b)
    True
    >>> unk_a.__eq__(t_a)
    False

    Args:
      other: A `TensorShape` or type that can be converted to `TensorShape`.

    Returns:
      True if the dimensions are all equal.

    Raises:
      TypeError: If `other` cannot be converted to `TensorShape`.
    """

    try:
      other = as_shape(other)
    except TypeError:
      # Let Python try the reflected comparison on `other` instead.
      return NotImplemented
    return self._dims == other.dims
1290
1291  def __ne__(self, other):
1292    """Returns True if `self` is known to be different from `other`."""
1293    try:
1294      other = as_shape(other)
1295    except TypeError:
1296      return NotImplemented
1297    if self.rank is None or other.rank is None:
1298      raise ValueError("The inequality of unknown TensorShapes is undefined.")
1299    if self.rank != other.rank:
1300      return True
1301    return self._dims != other.dims
1302
  def __reduce__(self):
    # Pickle support: reconstruct the shape from its dimension list.
    return TensorShape, (self._dims,)
1305
  def __concat__(self, other):
    # Delegates to `concatenate`; same contract as `self + other`.
    return self.concatenate(other)
1308
1309
def as_shape(shape):
  """Converts the given object to a TensorShape.

  Args:
    shape: A `TensorShape`, or a value convertible to one (e.g. a list of
      dimension sizes).

  Returns:
    `shape` itself if it is already a `TensorShape`, otherwise a new
    `TensorShape` built from it.
  """
  if isinstance(shape, TensorShape):
    return shape
  return TensorShape(shape)
1316
1317
def unknown_shape(rank=None, **kwargs):
  """Returns an unknown TensorShape, optionally with a known rank.

  Args:
    rank: (Optional) If specified, the number of dimensions in the shape.
    **kwargs: For backwards compatibility; `ndims` is accepted as an alias
      for `rank`.

  Returns:
    An unknown TensorShape.

  Raises:
    TypeError: In case of invalid arguments.
  """
  # Accept the deprecated `ndims` keyword only when `rank` was not given.
  if rank is None and "ndims" in kwargs:
    rank = kwargs.pop("ndims")
  if kwargs:
    raise TypeError("Unknown argument: %s" % kwargs)
  # With no rank, the whole shape is unknown; otherwise each of the `rank`
  # dimensions is individually unknown.
  return TensorShape(None if rank is None else [Dimension(None)] * rank)
1339