1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helper classes for tensor shape inference.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import functools 21import operator 22import six 23 24from tensorflow.core.framework import tensor_shape_pb2 25from tensorflow.python import tf2 26from tensorflow.python.eager import monitoring 27from tensorflow.python.platform import tf_logging as logging 28from tensorflow.python.util.tf_export import tf_export 29 30_TENSORSHAPE_V2_OVERRIDE = None 31 32_api_usage_gauge = monitoring.BoolGauge( 33 "/tensorflow/api/v2_tensorshape", 34 "Whether tensor_shape.enable_v2_tensorshape() is called.") 35 36 37@tf_export(v1=["enable_v2_tensorshape"]) 38def enable_v2_tensorshape(): 39 """In TensorFlow 2.0, iterating over a TensorShape instance returns values. 40 41 This enables the new behavior. 42 43 Concretely, `tensor_shape[i]` returned a Dimension instance in V1, but 44 it V2 it returns either an integer, or None. 45 46 Examples: 47 48 ``` 49 ####################### 50 # If you had this in V1: 51 value = tensor_shape[i].value 52 53 # Do this in V2 instead: 54 value = tensor_shape[i] 55 56 ####################### 57 # If you had this in V1: 58 for dim in tensor_shape: 59 value = dim.value 60 print(value) 61 62 # Do this in V2 instead: 63 for value in tensor_shape: 64 print(value) 65 66 ####################### 67 # If you had this in V1: 68 dim = tensor_shape[i] 69 dim.assert_is_compatible_with(other_shape) # or using any other shape method 70 71 # Do this in V2 instead: 72 if tensor_shape.rank is None: 73 dim = Dimension(None) 74 else: 75 dim = tensor_shape.dims[i] 76 dim.assert_is_compatible_with(other_shape) # or using any other shape method 77 78 # The V2 suggestion above is more explicit, which will save you from 79 # the following trap (present in V1): 80 # you might do in-place modifications to `dim` and expect them to be reflected 81 # in `tensor_shape[i]`, but they would not be. 82 ``` 83 """ 84 global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name 85 _TENSORSHAPE_V2_OVERRIDE = True 86 logging.vlog(1, "Enabling v2 tensorshape") 87 _api_usage_gauge.get_cell().set(True) 88 89 90@tf_export(v1=["disable_v2_tensorshape"]) 91def disable_v2_tensorshape(): 92 """Disables the V2 TensorShape behavior and reverts to V1 behavior. 93 94 See docstring for `enable_v2_tensorshape` for details about the new behavior. 95 """ 96 global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name 97 _TENSORSHAPE_V2_OVERRIDE = False 98 logging.vlog(1, "Disabling v2 tensorshape") 99 _api_usage_gauge.get_cell().set(False) 100 101 102@tf_export( 103 "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"]) 104def dimension_value(dimension): 105 """Compatibility utility required to allow for both V1 and V2 behavior in TF. 106 107 Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to 108 coexist with the new behavior. This utility is a bridge between the two. 109 110 When accessing the value of a TensorShape dimension, 111 use this utility, like this: 112 113 ``` 114 # If you had this in your V1 code: 115 value = tensor_shape[i].value 116 117 # Use `dimension_value` as direct replacement compatible with both V1 & V2: 118 value = dimension_value(tensor_shape[i]) 119 120 # This would be the V2 equivalent: 121 value = tensor_shape[i] # Warning: this will return the dim value in V2! 122 ``` 123 124 Args: 125 dimension: Either a `Dimension` instance, an integer, or None. 126 127 Returns: 128 A plain value, i.e. an integer or None. 129 """ 130 if isinstance(dimension, Dimension): 131 return dimension.value 132 return dimension 133 134 135@tf_export( 136 "compat.dimension_at_index", 137 v1=["dimension_at_index", "compat.dimension_at_index"]) 138def dimension_at_index(shape, index): 139 """Compatibility utility required to allow for both V1 and V2 behavior in TF. 140 141 Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to 142 coexist with the new behavior. This utility is a bridge between the two. 143 144 If you want to retrieve the Dimension instance corresponding to a certain 145 index in a TensorShape instance, use this utility, like this: 146 147 ``` 148 # If you had this in your V1 code: 149 dim = tensor_shape[i] 150 151 # Use `dimension_at_index` as direct replacement compatible with both V1 & V2: 152 dim = dimension_at_index(tensor_shape, i) 153 154 # Another possibility would be this, but WARNING: it only works if the 155 # tensor_shape instance has a defined rank. 156 dim = tensor_shape.dims[i] # `dims` may be None if the rank is undefined! 157 158 # In native V2 code, we recommend instead being more explicit: 159 if tensor_shape.rank is None: 160 dim = Dimension(None) 161 else: 162 dim = tensor_shape.dims[i] 163 164 # Being more explicit will save you from the following trap (present in V1): 165 # you might do in-place modifications to `dim` and expect them to be reflected 166 # in `tensor_shape[i]`, but they would not be (as the Dimension object was 167 # instantiated on the fly. 168 ``` 169 170 Args: 171 shape: A TensorShape instance. 172 index: An integer index. 173 174 Returns: 175 A dimension object. 176 """ 177 assert isinstance(shape, TensorShape) 178 if shape.rank is None: 179 return Dimension(None) 180 else: 181 return shape.dims[index] 182 183 184@tf_export(v1=["Dimension"]) 185class Dimension(object): 186 """Represents the value of one dimension in a TensorShape. 187 188 @compatibility(TF2) 189 In TF2, members of a `TensorShape` object are integers. The `Dimension` class 190 is not part of TF2's data model. 191 192 Please refer to the [TensorShape section of the migration guide] 193 (https://www.tensorflow.org/guide/migrate/index#tensorshape) on common code 194 patterns adapting Dimension objects to a TF2 syntax. 195 @end_compatibility 196 """ 197 198 __slots__ = ["_value"] 199 200 def __init__(self, value): 201 """Creates a new Dimension with the given value.""" 202 if isinstance(value, int): # Most common case. 203 if value < 0: 204 raise ValueError("Dimension %d must be >= 0" % value) 205 self._value = value 206 elif value is None: 207 self._value = None 208 elif isinstance(value, Dimension): 209 self._value = value._value 210 else: 211 try: 212 # int(...) compensates for the int/long dichotomy on Python 2.X. 213 # TODO(b/143206389): Remove once we fully migrate to 3.X. 214 self._value = int(value.__index__()) 215 except AttributeError: 216 six.raise_from( 217 TypeError("Dimension value must be integer or None or have " 218 "an __index__ method, got value '{0!r}' with type '{1!r}'" 219 .format(value, type(value))), None) 220 if self._value < 0: 221 raise ValueError("Dimension %d must be >= 0" % self._value) 222 223 def __repr__(self): 224 return "Dimension(%s)" % repr(self._value) 225 226 def __str__(self): 227 value = self._value 228 return "?" if value is None else str(value) 229 230 def __eq__(self, other): 231 """Returns true if `other` has the same known value as this Dimension.""" 232 try: 233 other = as_dimension(other) 234 except (TypeError, ValueError): 235 return NotImplemented 236 if self._value is None or other.value is None: 237 return None 238 return self._value == other.value 239 240 def __ne__(self, other): 241 """Returns true if `other` has a different known value from `self`.""" 242 try: 243 other = as_dimension(other) 244 except (TypeError, ValueError): 245 return NotImplemented 246 if self._value is None or other.value is None: 247 return None 248 return self._value != other.value 249 250 def __bool__(self): 251 """Equivalent to `bool(self.value)`.""" 252 return bool(self._value) 253 254 def __int__(self): 255 return self._value 256 257 # This is needed for Windows. 258 # See https://github.com/tensorflow/tensorflow/pull/9780 259 def __long__(self): 260 return self._value 261 262 def __index__(self): 263 # Allow use in Python 3 range 264 return self._value 265 266 @property 267 def value(self): 268 """The value of this dimension, or None if it is unknown.""" 269 return self._value 270 271 def is_compatible_with(self, other): 272 """Returns true if `other` is compatible with this Dimension. 273 274 Two known Dimensions are compatible if they have the same value. 275 An unknown Dimension is compatible with all other Dimensions. 276 277 Args: 278 other: Another Dimension. 279 280 Returns: 281 True if this Dimension and `other` are compatible. 282 """ 283 other = as_dimension(other) 284 return (self._value is None or other.value is None or 285 self._value == other.value) 286 287 def assert_is_compatible_with(self, other): 288 """Raises an exception if `other` is not compatible with this Dimension. 289 290 Args: 291 other: Another Dimension. 292 293 Raises: 294 ValueError: If `self` and `other` are not compatible (see 295 is_compatible_with). 296 """ 297 if not self.is_compatible_with(other): 298 raise ValueError("Dimensions %s and %s are not compatible" % 299 (self, other)) 300 301 def merge_with(self, other): 302 """Returns a Dimension that combines the information in `self` and `other`. 303 304 Dimensions are combined as follows: 305 306 ```python 307 tf.compat.v1.Dimension(n) .merge_with(tf.compat.v1.Dimension(n)) == 308 tf.compat.v1.Dimension(n) 309 tf.compat.v1.Dimension(n) .merge_with(tf.compat.v1.Dimension(None)) == 310 tf.compat.v1.Dimension(n) 311 tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(n)) == 312 tf.compat.v1.Dimension(n) 313 # equivalent to tf.compat.v1.Dimension(None) 314 tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(None)) 315 316 # raises ValueError for n != m 317 tf.compat.v1.Dimension(n) .merge_with(tf.compat.v1.Dimension(m)) 318 ``` 319 320 Args: 321 other: Another Dimension. 322 323 Returns: 324 A Dimension containing the combined information of `self` and 325 `other`. 326 327 Raises: 328 ValueError: If `self` and `other` are not compatible (see 329 is_compatible_with). 330 """ 331 other = as_dimension(other) 332 self.assert_is_compatible_with(other) 333 if self._value is None: 334 return Dimension(other.value) 335 else: 336 return Dimension(self._value) 337 338 def __add__(self, other): 339 """Returns the sum of `self` and `other`. 340 341 Dimensions are summed as follows: 342 343 ```python 344 tf.compat.v1.Dimension(m) + tf.compat.v1.Dimension(n) == 345 tf.compat.v1.Dimension(m + n) 346 tf.compat.v1.Dimension(m) + tf.compat.v1.Dimension(None) # equiv. to 347 tf.compat.v1.Dimension(None) 348 tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(n) # equiv. to 349 tf.compat.v1.Dimension(None) 350 tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(None) # equiv. to 351 tf.compat.v1.Dimension(None) 352 ``` 353 354 Args: 355 other: Another Dimension, or a value accepted by `as_dimension`. 356 357 Returns: 358 A Dimension whose value is the sum of `self` and `other`. 359 """ 360 try: 361 other = as_dimension(other) 362 except (TypeError, ValueError): 363 return NotImplemented 364 if self._value is None or other.value is None: 365 return Dimension(None) 366 else: 367 return Dimension(self._value + other.value) 368 369 def __radd__(self, other): 370 """Returns the sum of `other` and `self`. 371 372 Args: 373 other: Another Dimension, or a value accepted by `as_dimension`. 374 375 Returns: 376 A Dimension whose value is the sum of `self` and `other`. 377 """ 378 return self + other 379 380 def __sub__(self, other): 381 """Returns the subtraction of `other` from `self`. 382 383 Dimensions are subtracted as follows: 384 385 ```python 386 tf.compat.v1.Dimension(m) - tf.compat.v1.Dimension(n) == 387 tf.compat.v1.Dimension(m - n) 388 tf.compat.v1.Dimension(m) - tf.compat.v1.Dimension(None) # equiv. to 389 tf.compat.v1.Dimension(None) 390 tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(n) # equiv. to 391 tf.compat.v1.Dimension(None) 392 tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(None) # equiv. to 393 tf.compat.v1.Dimension(None) 394 ``` 395 396 Args: 397 other: Another Dimension, or a value accepted by `as_dimension`. 398 399 Returns: 400 A Dimension whose value is the subtraction of `other` from `self`. 401 """ 402 try: 403 other = as_dimension(other) 404 except (TypeError, ValueError): 405 return NotImplemented 406 if self._value is None or other.value is None: 407 return Dimension(None) 408 else: 409 return Dimension(self._value - other.value) 410 411 def __rsub__(self, other): 412 """Returns the subtraction of `self` from `other`. 413 414 Args: 415 other: Another Dimension, or a value accepted by `as_dimension`. 416 417 Returns: 418 A Dimension whose value is the subtraction of `self` from `other`. 419 """ 420 other = as_dimension(other) 421 if self._value is None or other.value is None: 422 return Dimension(None) 423 else: 424 return Dimension(other.value - self._value) 425 426 def __mul__(self, other): 427 """Returns the product of `self` and `other`. 428 429 Dimensions are summed as follows: 430 431 ```python 432 tf.compat.v1.Dimension(m) * tf.compat.v1.Dimension(n) == 433 tf.compat.v1.Dimension(m * n) 434 tf.compat.v1.Dimension(m) * tf.compat.v1.Dimension(None) # equiv. to 435 tf.compat.v1.Dimension(None) 436 tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(n) # equiv. to 437 tf.compat.v1.Dimension(None) 438 tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(None) # equiv. to 439 tf.compat.v1.Dimension(None) 440 ``` 441 442 Args: 443 other: Another Dimension, or a value accepted by `as_dimension`. 444 445 Returns: 446 A Dimension whose value is the product of `self` and `other`. 447 """ 448 try: 449 other = as_dimension(other) 450 except (TypeError, ValueError): 451 return NotImplemented 452 453 if self._value is None or other.value is None: 454 return Dimension(None) 455 else: 456 return Dimension(self._value * other.value) 457 458 def __rmul__(self, other): 459 """Returns the product of `self` and `other`. 460 461 Args: 462 other: Another Dimension, or a value accepted by `as_dimension`. 463 464 Returns: 465 A Dimension whose value is the product of `self` and `other`. 466 """ 467 return self * other 468 469 def __floordiv__(self, other): 470 """Returns the quotient of `self` and `other` rounded down. 471 472 Dimensions are divided as follows: 473 474 ```python 475 tf.compat.v1.Dimension(m) // tf.compat.v1.Dimension(n) == 476 tf.compat.v1.Dimension(m // n) 477 tf.compat.v1.Dimension(m) // tf.compat.v1.Dimension(None) # equiv. to 478 tf.compat.v1.Dimension(None) 479 tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(n) # equiv. to 480 tf.compat.v1.Dimension(None) 481 tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(None) # equiv. to 482 tf.compat.v1.Dimension(None) 483 ``` 484 485 Args: 486 other: Another Dimension, or a value accepted by `as_dimension`. 487 488 Returns: 489 A `Dimension` whose value is the integer quotient of `self` and `other`. 490 """ 491 try: 492 other = as_dimension(other) 493 except (TypeError, ValueError): 494 return NotImplemented 495 if self._value is None or other.value is None: 496 return Dimension(None) 497 else: 498 return Dimension(self._value // other.value) 499 500 def __rfloordiv__(self, other): 501 """Returns the quotient of `other` and `self` rounded down. 502 503 Args: 504 other: Another Dimension, or a value accepted by `as_dimension`. 505 506 Returns: 507 A `Dimension` whose value is the integer quotient of `self` and `other`. 508 """ 509 other = as_dimension(other) 510 if self._value is None or other.value is None: 511 return Dimension(None) 512 else: 513 return Dimension(other.value // self._value) 514 515 def __div__(self, other): 516 """DEPRECATED: Use `__floordiv__` via `x // y` instead. 517 518 This function exists only for backwards compatibility purposes; new code 519 should use `__floordiv__` via the syntax `x // y`. Using `x // y` 520 communicates clearly that the result rounds down, and is forward compatible 521 to Python 3. 522 523 Args: 524 other: Another `Dimension`. 525 526 Returns: 527 A `Dimension` whose value is the integer quotient of `self` and `other`. 528 """ 529 return self // other 530 531 def __rdiv__(self, other): 532 """Use `__floordiv__` via `x // y` instead. 533 534 This function exists only to have a better error message. Instead of: 535 `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`, 536 this function will explicitly call for usage of `//` instead. 537 538 Args: 539 other: Another `Dimension`. 540 541 Raises: 542 TypeError. 543 """ 544 raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', " 545 "please use // instead".format(type(other).__name__)) 546 547 def __truediv__(self, other): 548 """Use `__floordiv__` via `x // y` instead. 549 550 This function exists only to have a better error message. Instead of: 551 `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`, 552 this function will explicitly call for usage of `//` instead. 553 554 Args: 555 other: Another `Dimension`. 556 557 Raises: 558 TypeError. 559 """ 560 raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', " 561 "please use // instead".format(type(other).__name__)) 562 563 def __rtruediv__(self, other): 564 """Use `__floordiv__` via `x // y` instead. 565 566 This function exists only to have a better error message. Instead of: 567 `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`, 568 this function will explicitly call for usage of `//` instead. 569 570 Args: 571 other: Another `Dimension`. 572 573 Raises: 574 TypeError. 575 """ 576 raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', " 577 "please use // instead".format(type(other).__name__)) 578 579 def __mod__(self, other): 580 """Returns `self` modulo `other`. 581 582 Dimension modulo are computed as follows: 583 584 ```python 585 tf.compat.v1.Dimension(m) % tf.compat.v1.Dimension(n) == 586 tf.compat.v1.Dimension(m % n) 587 tf.compat.v1.Dimension(m) % tf.compat.v1.Dimension(None) # equiv. to 588 tf.compat.v1.Dimension(None) 589 tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(n) # equiv. to 590 tf.compat.v1.Dimension(None) 591 tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(None) # equiv. to 592 tf.compat.v1.Dimension(None) 593 ``` 594 595 Args: 596 other: Another Dimension, or a value accepted by `as_dimension`. 597 598 Returns: 599 A Dimension whose value is `self` modulo `other`. 600 """ 601 other = as_dimension(other) 602 if self._value is None or other.value is None: 603 return Dimension(None) 604 else: 605 return Dimension(self._value % other.value) 606 607 def __rmod__(self, other): 608 """Returns `other` modulo `self`. 609 610 Args: 611 other: Another Dimension, or a value accepted by `as_dimension`. 612 613 Returns: 614 A Dimension whose value is `other` modulo `self`. 615 """ 616 other = as_dimension(other) 617 return other % self 618 619 def __lt__(self, other): 620 """Returns True if `self` is known to be less than `other`. 621 622 Dimensions are compared as follows: 623 624 ```python 625 (tf.compat.v1.Dimension(m) < tf.compat.v1.Dimension(n)) == (m < n) 626 (tf.compat.v1.Dimension(m) < tf.compat.v1.Dimension(None)) == None 627 (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(n)) == None 628 (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(None)) == None 629 ``` 630 631 Args: 632 other: Another Dimension. 633 634 Returns: 635 The value of `self.value < other.value` if both are known, otherwise 636 None. 637 """ 638 other = as_dimension(other) 639 if self._value is None or other.value is None: 640 return None 641 else: 642 return self._value < other.value 643 644 def __le__(self, other): 645 """Returns True if `self` is known to be less than or equal to `other`. 646 647 Dimensions are compared as follows: 648 649 ```python 650 (tf.compat.v1.Dimension(m) <= tf.compat.v1.Dimension(n)) == (m <= n) 651 (tf.compat.v1.Dimension(m) <= tf.compat.v1.Dimension(None)) == None 652 (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(n)) == None 653 (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(None)) == None 654 ``` 655 656 Args: 657 other: Another Dimension. 658 659 Returns: 660 The value of `self.value <= other.value` if both are known, otherwise 661 None. 662 """ 663 other = as_dimension(other) 664 if self._value is None or other.value is None: 665 return None 666 else: 667 return self._value <= other.value 668 669 def __gt__(self, other): 670 """Returns True if `self` is known to be greater than `other`. 671 672 Dimensions are compared as follows: 673 674 ```python 675 (tf.compat.v1.Dimension(m) > tf.compat.v1.Dimension(n)) == (m > n) 676 (tf.compat.v1.Dimension(m) > tf.compat.v1.Dimension(None)) == None 677 (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(n)) == None 678 (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(None)) == None 679 ``` 680 681 Args: 682 other: Another Dimension. 683 684 Returns: 685 The value of `self.value > other.value` if both are known, otherwise 686 None. 687 """ 688 other = as_dimension(other) 689 if self._value is None or other.value is None: 690 return None 691 else: 692 return self._value > other.value 693 694 def __ge__(self, other): 695 """Returns True if `self` is known to be greater than or equal to `other`. 696 697 Dimensions are compared as follows: 698 699 ```python 700 (tf.compat.v1.Dimension(m) >= tf.compat.v1.Dimension(n)) == (m >= n) 701 (tf.compat.v1.Dimension(m) >= tf.compat.v1.Dimension(None)) == None 702 (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(n)) == None 703 (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(None)) == None 704 ``` 705 706 Args: 707 other: Another Dimension. 708 709 Returns: 710 The value of `self.value >= other.value` if both are known, otherwise 711 None. 712 """ 713 other = as_dimension(other) 714 if self._value is None or other.value is None: 715 return None 716 else: 717 return self._value >= other.value 718 719 def __reduce__(self): 720 return Dimension, (self._value,) 721 722 723def as_dimension(value): 724 """Converts the given value to a Dimension. 725 726 A Dimension input will be returned unmodified. 727 An input of `None` will be converted to an unknown Dimension. 728 An integer input will be converted to a Dimension with that value. 729 730 Args: 731 value: The value to be converted. 732 733 Returns: 734 A Dimension corresponding to the given value. 735 """ 736 if isinstance(value, Dimension): 737 return value 738 else: 739 return Dimension(value) 740 741 742@tf_export("TensorShape") 743class TensorShape(object): 744 """Represents the shape of a `Tensor`. 745 746 A `TensorShape` represents a possibly-partial shape specification for a 747 `Tensor`. It may be one of the following: 748 749 * *Fully-known shape:* has a known number of dimensions and a known size 750 for each dimension. e.g. `TensorShape([16, 256])` 751 * *Partially-known shape:* has a known number of dimensions, and an unknown 752 size for one or more dimension. e.g. `TensorShape([None, 256])` 753 * *Unknown shape:* has an unknown number of dimensions, and an unknown 754 size in all dimensions. e.g. `TensorShape(None)` 755 756 If a tensor is produced by an operation of type `"Foo"`, its shape 757 may be inferred if there is a registered shape function for 758 `"Foo"`. See [Shape 759 functions](https://tensorflow.org/extend/adding_an_op#shape_functions_in_c) 760 for details of shape functions and how to register them. Alternatively, 761 the shape may be set explicitly using `tf.Tensor.set_shape`. 762 """ 763 __slots__ = ["_dims"] 764 765 def __init__(self, dims): 766 """Creates a new TensorShape with the given dimensions. 767 768 Args: 769 dims: A list of Dimensions, or None if the shape is unspecified. 770 771 Raises: 772 TypeError: If dims cannot be converted to a list of dimensions. 773 """ 774 if isinstance(dims, (tuple, list)): # Most common case. 775 self._dims = [Dimension(d) for d in dims] 776 elif dims is None: 777 self._dims = None 778 elif isinstance(dims, tensor_shape_pb2.TensorShapeProto): 779 if dims.unknown_rank: 780 self._dims = None 781 else: 782 self._dims = [ 783 # Protos store variable-size dimensions as -1 784 as_dimension(dim.size if dim.size != -1 else None) 785 for dim in dims.dim 786 ] 787 elif isinstance(dims, TensorShape): 788 self._dims = dims.dims 789 else: 790 try: 791 dims_iter = iter(dims) 792 except TypeError: 793 # Treat as a singleton dimension 794 self._dims = [as_dimension(dims)] 795 else: 796 self._dims = [] 797 for d in dims_iter: 798 try: 799 self._dims.append(as_dimension(d)) 800 except TypeError as e: 801 six.raise_from( 802 TypeError( 803 "Failed to convert '{0!r}' to a shape: '{1!r}'" 804 "could not be converted to a dimension. A shape should " 805 "either be single dimension (e.g. 10), or an iterable of " 806 "dimensions (e.g. [1, 10, None])." 807 .format(dims, d)), e) 808 809 @property 810 def _v2_behavior(self): 811 if _TENSORSHAPE_V2_OVERRIDE is None: 812 return tf2.enabled() 813 return _TENSORSHAPE_V2_OVERRIDE 814 815 def __repr__(self): 816 if self._v2_behavior: 817 if self._dims is not None: 818 return "TensorShape(%r)" % [dim.value for dim in self._dims] 819 else: 820 return "TensorShape(None)" 821 else: 822 return "TensorShape(%r)" % self._dims 823 824 def __str__(self): 825 if self.rank is None: 826 return "<unknown>" 827 elif self.rank == 1: 828 if self._v2_behavior: 829 return "(%s,)" % self._dims[0].value 830 else: 831 return "(%s,)" % self._dims[0] 832 else: 833 if self._v2_behavior: 834 return "(%s)" % ", ".join(str(d.value) for d in self._dims) 835 else: 836 return "(%s)" % ", ".join(str(d) for d in self._dims) 837 838 @property 839 def rank(self): 840 """Returns the rank of this shape, or None if it is unspecified.""" 841 if self._dims is not None: 842 return len(self._dims) 843 return None 844 845 @property 846 def dims(self): 847 """Deprecated. Returns list of dimensions for this shape. 848 849 Suggest `TensorShape.as_list` instead. 850 851 Returns: 852 A list containing `tf.compat.v1.Dimension`s, or None if the shape is 853 unspecified. 854 """ 855 return self._dims 856 857 @property 858 def ndims(self): 859 """Deprecated accessor for `rank`.""" 860 return self.rank 861 862 def __len__(self): 863 """Returns the rank of this shape, or raises ValueError if unspecified.""" 864 if self._dims is None: 865 raise ValueError("Cannot take the length of shape with unknown rank.") 866 return len(self._dims) 867 868 def __bool__(self): 869 """Returns True if this shape contains non-zero information.""" 870 return self._dims is not None 871 872 # Python 3 wants __bool__, Python 2.7 wants __nonzero__ 873 __nonzero__ = __bool__ 874 875 def __iter__(self): 876 """Returns `self.dims` if the rank is known, otherwise raises ValueError.""" 877 if self._dims is None: 878 raise ValueError("Cannot iterate over a shape with unknown rank.") 879 else: 880 if self._v2_behavior: 881 return iter(d.value for d in self._dims) 882 else: 883 return iter(d for d in self._dims) 884 885 def __getitem__(self, key): 886 """Returns the value of a dimension or a shape, depending on the key. 887 888 Args: 889 key: If `key` is an integer, returns the dimension at that index; 890 otherwise if `key` is a slice, returns a TensorShape whose dimensions 891 are those selected by the slice from `self`. 892 893 Returns: 894 An integer if `key` is an integer, or a `TensorShape` if `key` is a 895 slice. 896 897 Raises: 898 ValueError: If `key` is a slice and `self` is completely unknown and 899 the step is set. 900 """ 901 if self._dims is not None: 902 if isinstance(key, slice): 903 return TensorShape(self._dims[key]) 904 else: 905 if self._v2_behavior: 906 return self._dims[key].value 907 else: 908 return self._dims[key] 909 else: 910 if isinstance(key, slice): 911 start = key.start if key.start is not None else 0 912 stop = key.stop 913 914 if key.step is not None: 915 # TODO(mrry): Handle these maybe. 916 raise ValueError("Steps are not yet handled") 917 if stop is None: 918 # NOTE(mrry): This implies that TensorShape(None) is compatible with 919 # TensorShape(None)[1:], which is obviously not true. It would be 920 # possible to track the number of dimensions symbolically, 921 # and perhaps we should do that. 922 return unknown_shape() 923 elif start < 0 or stop < 0: 924 # TODO(mrry): Handle this better, as it will be useful for handling 925 # suffixes of otherwise unknown shapes. 926 return unknown_shape() 927 else: 928 return unknown_shape(rank=stop - start) 929 else: 930 if self._v2_behavior: 931 return None 932 else: 933 return Dimension(None) 934 935 def num_elements(self): 936 """Returns the total number of elements, or none for incomplete shapes.""" 937 if self.is_fully_defined(): 938 return functools.reduce(operator.mul, self.as_list(), 1) 939 else: 940 return None 941 942 def merge_with(self, other): 943 """Returns a `TensorShape` combining the information in `self` and `other`. 944 945 The dimensions in `self` and `other` are merged element-wise, 946 according to the rules below: 947 948 ```python 949 Dimension(n).merge_with(Dimension(None)) == Dimension(n) 950 Dimension(None).merge_with(Dimension(n)) == Dimension(n) 951 Dimension(None).merge_with(Dimension(None)) == Dimension(None) 952 # raises ValueError for n != m 953 Dimension(n).merge_with(Dimension(m)) 954 ``` 955 >> ts = tf.TensorShape([1,2]) 956 >> ot1 = tf.TensorShape([1,2]) 957 >> ts.merge_with(ot).as_list() 958 [1,2] 959 960 >> ot2 = tf.TensorShape([1,None]) 961 >> ts.merge_with(ot2).as_list() 962 [1,2] 963 964 >> ot3 = tf.TensorShape([None, None]) 965 >> ot3.merge_with(ot2).as_list() 966 [1, None] 967 968 Args: 969 other: Another `TensorShape`. 970 971 Returns: 972 A `TensorShape` containing the combined information of `self` and 973 `other`. 974 975 Raises: 976 ValueError: If `self` and `other` are not compatible. 977 """ 978 other = as_shape(other) 979 if self._dims is None: 980 return other 981 if other.dims is None: 982 return self 983 else: 984 try: 985 self.assert_same_rank(other) 986 new_dims = [ 987 dim.merge_with(other_dim) 988 for dim, other_dim in zip(self._dims, other.dims) 989 ] 990 return TensorShape(new_dims) 991 except ValueError: 992 raise ValueError("Shapes %s and %s are not compatible" % (self, other)) 993 994 def __add__(self, other): 995 return self.concatenate(other) 996 997 def __radd__(self, other): 998 if not isinstance(other, TensorShape): 999 other = TensorShape(other) 1000 return other.concatenate(self) 1001 1002 def concatenate(self, other): 1003 """Returns the concatenation of the dimension in `self` and `other`. 1004 1005 *N.B.* If either `self` or `other` is completely unknown, 1006 concatenation will discard information about the other shape. In 1007 future, we might support concatenation that preserves this 1008 information for use with slicing. 1009 1010 Args: 1011 other: Another `TensorShape`. 1012 1013 Returns: 1014 A `TensorShape` whose dimensions are the concatenation of the 1015 dimensions in `self` and `other`. 1016 """ 1017 # TODO(mrry): Handle the case where we concatenate a known shape with a 1018 # completely unknown shape, so that we can use the partial information. 1019 other = as_shape(other) 1020 if self._dims is None or other.dims is None: 1021 return unknown_shape() 1022 else: 1023 return TensorShape(self._dims + other.dims) 1024 1025 def assert_same_rank(self, other): 1026 """Raises an exception if `self` and `other` do not have compatible ranks. 1027 1028 Args: 1029 other: Another `TensorShape`. 1030 1031 Raises: 1032 ValueError: If `self` and `other` do not represent shapes with the 1033 same rank. 1034 """ 1035 other = as_shape(other) 1036 if self.rank is not None and other.rank is not None: 1037 if self.rank != other.rank: 1038 raise ValueError("Shapes %s and %s must have the same rank" % 1039 (self, other)) 1040 1041 def assert_has_rank(self, rank): 1042 """Raises an exception if `self` is not compatible with the given `rank`. 1043 1044 Args: 1045 rank: An integer. 1046 1047 Raises: 1048 ValueError: If `self` does not represent a shape with the given `rank`. 1049 """ 1050 if self.rank not in (None, rank): 1051 raise ValueError("Shape %s must have rank %d" % (self, rank)) 1052 1053 def with_rank(self, rank): 1054 """Returns a shape based on `self` with the given rank. 1055 1056 This method promotes a completely unknown shape to one with a 1057 known rank. 1058 1059 Args: 1060 rank: An integer. 1061 1062 Returns: 1063 A shape that is at least as specific as `self` with the given rank. 1064 1065 Raises: 1066 ValueError: If `self` does not represent a shape with the given `rank`. 1067 """ 1068 try: 1069 return self.merge_with(unknown_shape(rank=rank)) 1070 except ValueError: 1071 raise ValueError("Shape %s must have rank %d" % (self, rank)) 1072 1073 def with_rank_at_least(self, rank): 1074 """Returns a shape based on `self` with at least the given rank. 1075 1076 Args: 1077 rank: An integer. 1078 1079 Returns: 1080 A shape that is at least as specific as `self` with at least the given 1081 rank. 1082 1083 Raises: 1084 ValueError: If `self` does not represent a shape with at least the given 1085 `rank`. 1086 """ 1087 if self.rank is not None and self.rank < rank: 1088 raise ValueError("Shape %s must have rank at least %d" % (self, rank)) 1089 else: 1090 return self 1091 1092 def with_rank_at_most(self, rank): 1093 """Returns a shape based on `self` with at most the given rank. 1094 1095 Args: 1096 rank: An integer. 1097 1098 Returns: 1099 A shape that is at least as specific as `self` with at most the given 1100 rank. 1101 1102 Raises: 1103 ValueError: If `self` does not represent a shape with at most the given 1104 `rank`. 1105 """ 1106 if self.rank is not None and self.rank > rank: 1107 raise ValueError("Shape %s must have rank at most %d" % (self, rank)) 1108 else: 1109 return self 1110 1111 def is_compatible_with(self, other): 1112 """Returns True iff `self` is compatible with `other`. 1113 1114 Two possibly-partially-defined shapes are compatible if there 1115 exists a fully-defined shape that both shapes can represent. Thus, 1116 compatibility allows the shape inference code to reason about 1117 partially-defined shapes. For example: 1118 1119 * TensorShape(None) is compatible with all shapes. 1120 1121 * TensorShape([None, None]) is compatible with all two-dimensional 1122 shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is 1123 not compatible with, for example, TensorShape([None]) or 1124 TensorShape([None, None, None]). 1125 1126 * TensorShape([32, None]) is compatible with all two-dimensional shapes 1127 with size 32 in the 0th dimension, and also TensorShape([None, None]) 1128 and TensorShape(None). It is not compatible with, for example, 1129 TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]). 1130 1131 * TensorShape([32, 784]) is compatible with itself, and also 1132 TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None, 1133 None]) and TensorShape(None). It is not compatible with, for example, 1134 TensorShape([32, 1, 784]) or TensorShape([None]). 1135 1136 The compatibility relation is reflexive and symmetric, but not 1137 transitive. For example, TensorShape([32, 784]) is compatible with 1138 TensorShape(None), and TensorShape(None) is compatible with 1139 TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with 1140 TensorShape([4, 4]). 1141 1142 Args: 1143 other: Another TensorShape. 1144 1145 Returns: 1146 True iff `self` is compatible with `other`. 1147 1148 """ 1149 other = as_shape(other) 1150 if self._dims is not None and other.dims is not None: 1151 if self.rank != other.rank: 1152 return False 1153 for x_dim, y_dim in zip(self._dims, other.dims): 1154 if not x_dim.is_compatible_with(y_dim): 1155 return False 1156 return True 1157 1158 def assert_is_compatible_with(self, other): 1159 """Raises exception if `self` and `other` do not represent the same shape. 1160 1161 This method can be used to assert that there exists a shape that both 1162 `self` and `other` represent. 1163 1164 Args: 1165 other: Another TensorShape. 1166 1167 Raises: 1168 ValueError: If `self` and `other` do not represent the same shape. 1169 """ 1170 if not self.is_compatible_with(other): 1171 raise ValueError("Shapes %s and %s are incompatible" % (self, other)) 1172 1173 def most_specific_compatible_shape(self, other): 1174 """Returns the most specific TensorShape compatible with `self` and `other`. 1175 1176 * TensorShape([None, 1]) is the most specific TensorShape compatible with 1177 both TensorShape([2, 1]) and TensorShape([5, 1]). Note that 1178 TensorShape(None) is also compatible with above mentioned TensorShapes. 1179 1180 * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with 1181 both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more 1182 less specific TensorShapes compatible with above mentioned TensorShapes, 1183 e.g. TensorShape([1, 2, None]), TensorShape(None). 1184 1185 Args: 1186 other: Another `TensorShape`. 1187 1188 Returns: 1189 A `TensorShape` which is the most specific compatible shape of `self` 1190 and `other`. 1191 """ 1192 1193 other = as_shape(other) 1194 if self._dims is None or other.dims is None or self.rank != other.rank: 1195 return unknown_shape() 1196 1197 dims = [ 1198 d1 if d1 is not None and d2 is not None and d1 == d2 else None 1199 for d1, d2 in zip(self._dims, other.dims) 1200 ] 1201 return TensorShape(dims) 1202 1203 def is_fully_defined(self): 1204 """Returns True iff `self` is fully defined in every dimension.""" 1205 return (self._dims is not None and 1206 all(dim.value is not None for dim in self._dims)) 1207 1208 def assert_is_fully_defined(self): 1209 """Raises an exception if `self` is not fully defined in every dimension. 1210 1211 Raises: 1212 ValueError: If `self` does not have a known value for every dimension. 1213 """ 1214 if not self.is_fully_defined(): 1215 raise ValueError("Shape %s is not fully defined" % self) 1216 1217 def as_list(self): 1218 """Returns a list of integers or `None` for each dimension. 1219 1220 Returns: 1221 A list of integers or `None` for each dimension. 1222 1223 Raises: 1224 ValueError: If `self` is an unknown shape with an unknown rank. 1225 """ 1226 if self._dims is None: 1227 raise ValueError("as_list() is not defined on an unknown TensorShape.") 1228 return [dim.value for dim in self._dims] 1229 1230 def as_proto(self): 1231 """Returns this shape as a `TensorShapeProto`.""" 1232 if self._dims is None: 1233 return tensor_shape_pb2.TensorShapeProto(unknown_rank=True) 1234 else: 1235 return tensor_shape_pb2.TensorShapeProto(dim=[ 1236 tensor_shape_pb2.TensorShapeProto.Dim( 1237 size=-1 if d.value is None else d.value) for d in self._dims 1238 ]) 1239 1240 def __eq__(self, other): 1241 """Returns True if `self` is equivalent to `other`. 1242 1243 It first tries to convert `other` to `TensorShape`. `TypeError` is thrown 1244 when the conversion fails. Otherwise, it compares each element in the 1245 TensorShape dimensions. 1246 1247 * Two *Fully known* shapes, return True iff each element is equal. 1248 >>> t_a = tf.TensorShape([1,2]) 1249 >>> a = [1, 2] 1250 >>> t_b = tf.TensorShape([1,2]) 1251 >>> t_c = tf.TensorShape([1,2,3]) 1252 >>> t_a.__eq__(a) 1253 True 1254 >>> t_a.__eq__(t_b) 1255 True 1256 >>> t_a.__eq__(t_c) 1257 False 1258 1259 * Two *Partially-known* shapes, return False. 1260 >>> p_a = tf.TensorShape([1,None]) 1261 >>> p_b = tf.TensorShape([2,None]) 1262 >>> p_a.__eq__(p_b) 1263 False 1264 >>> t_a.__eq__(p_a) 1265 False 1266 1267 * Two *Unknown shape*, return True. 1268 >>> unk_a = tf.TensorShape(None) 1269 >>> unk_b = tf.TensorShape(None) 1270 >>> unk_a.__eq__(unk_b) 1271 True 1272 >>> unk_a.__eq__(t_a) 1273 False 1274 1275 Args: 1276 other: A `TensorShape` or type that can be converted to `TensorShape`. 1277 1278 Returns: 1279 True if the dimensions are all equal. 1280 1281 Raises: 1282 TypeError if `other` can not be converted to `TensorShape`. 1283 """ 1284 1285 try: 1286 other = as_shape(other) 1287 except TypeError: 1288 return NotImplemented 1289 return self._dims == other.dims 1290 1291 def __ne__(self, other): 1292 """Returns True if `self` is known to be different from `other`.""" 1293 try: 1294 other = as_shape(other) 1295 except TypeError: 1296 return NotImplemented 1297 if self.rank is None or other.rank is None: 1298 raise ValueError("The inequality of unknown TensorShapes is undefined.") 1299 if self.rank != other.rank: 1300 return True 1301 return self._dims != other.dims 1302 1303 def __reduce__(self): 1304 return TensorShape, (self._dims,) 1305 1306 def __concat__(self, other): 1307 return self.concatenate(other) 1308 1309 1310def as_shape(shape): 1311 """Converts the given object to a TensorShape.""" 1312 if isinstance(shape, TensorShape): 1313 return shape 1314 else: 1315 return TensorShape(shape) 1316 1317 1318def unknown_shape(rank=None, **kwargs): 1319 """Returns an unknown TensorShape, optionally with a known rank. 1320 1321 Args: 1322 rank: (Optional) If specified, the number of dimensions in the shape. 1323 **kwargs: For backwards compatibility. 1324 1325 Returns: 1326 An unknown TensorShape. 1327 1328 Raises: 1329 TypeError: In case of invalid arguments. 1330 """ 1331 if rank is None and "ndims" in kwargs: 1332 rank = kwargs.pop("ndims") 1333 if kwargs: 1334 raise TypeError("Unknown argument: %s" % kwargs) 1335 if rank is None: 1336 return TensorShape(None) 1337 else: 1338 return TensorShape([Dimension(None)] * rank) 1339