1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helper classes for tensor shape inference.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import functools 21import operator 22import six 23 24from tensorflow.core.framework import tensor_shape_pb2 25from tensorflow.python import tf2 26from tensorflow.python.eager import monitoring 27from tensorflow.python.util.tf_export import tf_export 28 29_TENSORSHAPE_V2_OVERRIDE = None 30 31_api_usage_gauge = monitoring.BoolGauge( 32 "/tensorflow/api/v2_tensorshape", 33 "Whether tensor_shape.enable_v2_tensorshape() is called.") 34 35 36@tf_export(v1=["enable_v2_tensorshape"]) 37def enable_v2_tensorshape(): 38 """In TensorFlow 2.0, iterating over a TensorShape instance returns values. 39 40 This enables the new behavior. 41 42 Concretely, `tensor_shape[i]` returned a Dimension instance in V1, but 43 it V2 it returns either an integer, or None. 44 45 Examples: 46 47 ``` 48 ####################### 49 # If you had this in V1: 50 value = tensor_shape[i].value 51 52 # Do this in V2 instead: 53 value = tensor_shape[i] 54 55 ####################### 56 # If you had this in V1: 57 for dim in tensor_shape: 58 value = dim.value 59 print(value) 60 61 # Do this in V2 instead: 62 for value in tensor_shape: 63 print(value) 64 65 ####################### 66 # If you had this in V1: 67 dim = tensor_shape[i] 68 dim.assert_is_compatible_with(other_shape) # or using any other shape method 69 70 # Do this in V2 instead: 71 if tensor_shape.rank is None: 72 dim = Dimension(None) 73 else: 74 dim = tensor_shape.dims[i] 75 dim.assert_is_compatible_with(other_shape) # or using any other shape method 76 77 # The V2 suggestion above is more explicit, which will save you from 78 # the following trap (present in V1): 79 # you might do in-place modifications to `dim` and expect them to be reflected 80 # in `tensor_shape[i]`, but they would not be. 81 ``` 82 """ 83 global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name 84 _TENSORSHAPE_V2_OVERRIDE = True 85 _api_usage_gauge.get_cell().set(True) 86 87 88@tf_export(v1=["disable_v2_tensorshape"]) 89def disable_v2_tensorshape(): 90 """Disables the V2 TensorShape behavior and reverts to V1 behavior. 91 92 See docstring for `enable_v2_tensorshape` for details about the new behavior. 93 """ 94 global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name 95 _TENSORSHAPE_V2_OVERRIDE = False 96 _api_usage_gauge.get_cell().set(False) 97 98 99@tf_export( 100 "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"]) 101def dimension_value(dimension): 102 """Compatibility utility required to allow for both V1 and V2 behavior in TF. 103 104 Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to 105 coexist with the new behavior. This utility is a bridge between the two. 106 107 When accessing the value of a TensorShape dimension, 108 use this utility, like this: 109 110 ``` 111 # If you had this in your V1 code: 112 value = tensor_shape[i].value 113 114 # Use `dimension_value` as direct replacement compatible with both V1 & V2: 115 value = dimension_value(tensor_shape[i]) 116 117 # This would be the V2 equivalent: 118 value = tensor_shape[i] # Warning: this will return the dim value in V2! 119 ``` 120 121 Args: 122 dimension: Either a `Dimension` instance, an integer, or None. 123 124 Returns: 125 A plain value, i.e. an integer or None. 126 """ 127 if isinstance(dimension, Dimension): 128 return dimension.value 129 return dimension 130 131 132@tf_export( 133 "compat.dimension_at_index", 134 v1=["dimension_at_index", "compat.dimension_at_index"]) 135def dimension_at_index(shape, index): 136 """Compatibility utility required to allow for both V1 and V2 behavior in TF. 137 138 Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to 139 coexist with the new behavior. This utility is a bridge between the two. 140 141 If you want to retrieve the Dimension instance corresponding to a certain 142 index in a TensorShape instance, use this utility, like this: 143 144 ``` 145 # If you had this in your V1 code: 146 dim = tensor_shape[i] 147 148 # Use `dimension_at_index` as direct replacement compatible with both V1 & V2: 149 dim = dimension_at_index(tensor_shape, i) 150 151 # Another possibility would be this, but WARNING: it only works if the 152 # tensor_shape instance has a defined rank. 153 dim = tensor_shape.dims[i] # `dims` may be None if the rank is undefined! 154 155 # In native V2 code, we recommend instead being more explicit: 156 if tensor_shape.rank is None: 157 dim = Dimension(None) 158 else: 159 dim = tensor_shape.dims[i] 160 161 # Being more explicit will save you from the following trap (present in V1): 162 # you might do in-place modifications to `dim` and expect them to be reflected 163 # in `tensor_shape[i]`, but they would not be (as the Dimension object was 164 # instantiated on the fly. 165 ``` 166 167 Args: 168 shape: A TensorShape instance. 169 index: An integer index. 170 171 Returns: 172 A dimension object. 173 """ 174 assert isinstance(shape, TensorShape) 175 if shape.rank is None: 176 return Dimension(None) 177 else: 178 return shape.dims[index] 179 180 181@tf_export(v1=["Dimension"]) 182class Dimension(object): 183 """Represents the value of one dimension in a TensorShape.""" 184 185 __slots__ = ["_value"] 186 187 def __init__(self, value): 188 """Creates a new Dimension with the given value.""" 189 if isinstance(value, int): # Most common case. 190 if value < 0: 191 raise ValueError("Dimension %d must be >= 0" % value) 192 self._value = value 193 elif value is None: 194 self._value = None 195 elif isinstance(value, Dimension): 196 self._value = value._value 197 else: 198 try: 199 # int(...) compensates for the int/long dichotomy on Python 2.X. 200 # TODO(b/143206389): Remove once we fully migrate to 3.X. 201 self._value = int(value.__index__()) 202 except AttributeError: 203 six.raise_from( 204 TypeError("Dimension value must be integer or None or have " 205 "an __index__ method, got value '{0!r}' with type '{1!r}'" 206 .format(value, type(value))), None) 207 if self._value < 0: 208 raise ValueError("Dimension %d must be >= 0" % self._value) 209 210 def __repr__(self): 211 return "Dimension(%s)" % repr(self._value) 212 213 def __str__(self): 214 value = self._value 215 return "?" if value is None else str(value) 216 217 def __eq__(self, other): 218 """Returns true if `other` has the same known value as this Dimension.""" 219 try: 220 other = as_dimension(other) 221 except (TypeError, ValueError): 222 return NotImplemented 223 if self._value is None or other.value is None: 224 return None 225 return self._value == other.value 226 227 def __ne__(self, other): 228 """Returns true if `other` has a different known value from `self`.""" 229 try: 230 other = as_dimension(other) 231 except (TypeError, ValueError): 232 return NotImplemented 233 if self._value is None or other.value is None: 234 return None 235 return self._value != other.value 236 237 def __bool__(self): 238 """Equivalent to `bool(self.value)`.""" 239 return bool(self._value) 240 241 def __int__(self): 242 return self._value 243 244 # This is needed for Windows. 245 # See https://github.com/tensorflow/tensorflow/pull/9780 246 def __long__(self): 247 return self._value 248 249 def __index__(self): 250 # Allow use in Python 3 range 251 return self._value 252 253 @property 254 def value(self): 255 """The value of this dimension, or None if it is unknown.""" 256 return self._value 257 258 def is_compatible_with(self, other): 259 """Returns true if `other` is compatible with this Dimension. 260 261 Two known Dimensions are compatible if they have the same value. 262 An unknown Dimension is compatible with all other Dimensions. 263 264 Args: 265 other: Another Dimension. 266 267 Returns: 268 True if this Dimension and `other` are compatible. 269 """ 270 other = as_dimension(other) 271 return (self._value is None or other.value is None or 272 self._value == other.value) 273 274 def assert_is_compatible_with(self, other): 275 """Raises an exception if `other` is not compatible with this Dimension. 276 277 Args: 278 other: Another Dimension. 279 280 Raises: 281 ValueError: If `self` and `other` are not compatible (see 282 is_compatible_with). 283 """ 284 if not self.is_compatible_with(other): 285 raise ValueError("Dimensions %s and %s are not compatible" % 286 (self, other)) 287 288 def merge_with(self, other): 289 """Returns a Dimension that combines the information in `self` and `other`. 290 291 Dimensions are combined as follows: 292 293 ```python 294 tf.compat.v1.Dimension(n) .merge_with(tf.compat.v1.Dimension(n)) == 295 tf.compat.v1.Dimension(n) 296 tf.compat.v1.Dimension(n) .merge_with(tf.compat.v1.Dimension(None)) == 297 tf.compat.v1.Dimension(n) 298 tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(n)) == 299 tf.compat.v1.Dimension(n) 300 # equivalent to tf.compat.v1.Dimension(None) 301 tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(None)) 302 303 # raises ValueError for n != m 304 tf.compat.v1.Dimension(n) .merge_with(tf.compat.v1.Dimension(m)) 305 ``` 306 307 Args: 308 other: Another Dimension. 309 310 Returns: 311 A Dimension containing the combined information of `self` and 312 `other`. 313 314 Raises: 315 ValueError: If `self` and `other` are not compatible (see 316 is_compatible_with). 317 """ 318 other = as_dimension(other) 319 self.assert_is_compatible_with(other) 320 if self._value is None: 321 return Dimension(other.value) 322 else: 323 return Dimension(self._value) 324 325 def __add__(self, other): 326 """Returns the sum of `self` and `other`. 327 328 Dimensions are summed as follows: 329 330 ```python 331 tf.compat.v1.Dimension(m) + tf.compat.v1.Dimension(n) == 332 tf.compat.v1.Dimension(m + n) 333 tf.compat.v1.Dimension(m) + tf.compat.v1.Dimension(None) # equiv. to 334 tf.compat.v1.Dimension(None) 335 tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(n) # equiv. to 336 tf.compat.v1.Dimension(None) 337 tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(None) # equiv. to 338 tf.compat.v1.Dimension(None) 339 ``` 340 341 Args: 342 other: Another Dimension, or a value accepted by `as_dimension`. 343 344 Returns: 345 A Dimension whose value is the sum of `self` and `other`. 346 """ 347 try: 348 other = as_dimension(other) 349 except (TypeError, ValueError): 350 return NotImplemented 351 if self._value is None or other.value is None: 352 return Dimension(None) 353 else: 354 return Dimension(self._value + other.value) 355 356 def __radd__(self, other): 357 """Returns the sum of `other` and `self`. 358 359 Args: 360 other: Another Dimension, or a value accepted by `as_dimension`. 361 362 Returns: 363 A Dimension whose value is the sum of `self` and `other`. 364 """ 365 return self + other 366 367 def __sub__(self, other): 368 """Returns the subtraction of `other` from `self`. 369 370 Dimensions are subtracted as follows: 371 372 ```python 373 tf.compat.v1.Dimension(m) - tf.compat.v1.Dimension(n) == 374 tf.compat.v1.Dimension(m - n) 375 tf.compat.v1.Dimension(m) - tf.compat.v1.Dimension(None) # equiv. to 376 tf.compat.v1.Dimension(None) 377 tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(n) # equiv. to 378 tf.compat.v1.Dimension(None) 379 tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(None) # equiv. to 380 tf.compat.v1.Dimension(None) 381 ``` 382 383 Args: 384 other: Another Dimension, or a value accepted by `as_dimension`. 385 386 Returns: 387 A Dimension whose value is the subtraction of `other` from `self`. 388 """ 389 try: 390 other = as_dimension(other) 391 except (TypeError, ValueError): 392 return NotImplemented 393 if self._value is None or other.value is None: 394 return Dimension(None) 395 else: 396 return Dimension(self._value - other.value) 397 398 def __rsub__(self, other): 399 """Returns the subtraction of `self` from `other`. 400 401 Args: 402 other: Another Dimension, or a value accepted by `as_dimension`. 403 404 Returns: 405 A Dimension whose value is the subtraction of `self` from `other`. 406 """ 407 other = as_dimension(other) 408 if self._value is None or other.value is None: 409 return Dimension(None) 410 else: 411 return Dimension(other.value - self._value) 412 413 def __mul__(self, other): 414 """Returns the product of `self` and `other`. 415 416 Dimensions are summed as follows: 417 418 ```python 419 tf.compat.v1.Dimension(m) * tf.compat.v1.Dimension(n) == 420 tf.compat.v1.Dimension(m * n) 421 tf.compat.v1.Dimension(m) * tf.compat.v1.Dimension(None) # equiv. to 422 tf.compat.v1.Dimension(None) 423 tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(n) # equiv. to 424 tf.compat.v1.Dimension(None) 425 tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(None) # equiv. to 426 tf.compat.v1.Dimension(None) 427 ``` 428 429 Args: 430 other: Another Dimension, or a value accepted by `as_dimension`. 431 432 Returns: 433 A Dimension whose value is the product of `self` and `other`. 434 """ 435 try: 436 other = as_dimension(other) 437 except (TypeError, ValueError): 438 return NotImplemented 439 440 if self._value is None or other.value is None: 441 return Dimension(None) 442 else: 443 return Dimension(self._value * other.value) 444 445 def __rmul__(self, other): 446 """Returns the product of `self` and `other`. 447 448 Args: 449 other: Another Dimension, or a value accepted by `as_dimension`. 450 451 Returns: 452 A Dimension whose value is the product of `self` and `other`. 453 """ 454 return self * other 455 456 def __floordiv__(self, other): 457 """Returns the quotient of `self` and `other` rounded down. 458 459 Dimensions are divided as follows: 460 461 ```python 462 tf.compat.v1.Dimension(m) // tf.compat.v1.Dimension(n) == 463 tf.compat.v1.Dimension(m // n) 464 tf.compat.v1.Dimension(m) // tf.compat.v1.Dimension(None) # equiv. to 465 tf.compat.v1.Dimension(None) 466 tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(n) # equiv. to 467 tf.compat.v1.Dimension(None) 468 tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(None) # equiv. to 469 tf.compat.v1.Dimension(None) 470 ``` 471 472 Args: 473 other: Another Dimension, or a value accepted by `as_dimension`. 474 475 Returns: 476 A `Dimension` whose value is the integer quotient of `self` and `other`. 477 """ 478 try: 479 other = as_dimension(other) 480 except (TypeError, ValueError): 481 return NotImplemented 482 if self._value is None or other.value is None: 483 return Dimension(None) 484 else: 485 return Dimension(self._value // other.value) 486 487 def __rfloordiv__(self, other): 488 """Returns the quotient of `other` and `self` rounded down. 489 490 Args: 491 other: Another Dimension, or a value accepted by `as_dimension`. 492 493 Returns: 494 A `Dimension` whose value is the integer quotient of `self` and `other`. 495 """ 496 other = as_dimension(other) 497 if self._value is None or other.value is None: 498 return Dimension(None) 499 else: 500 return Dimension(other.value // self._value) 501 502 def __div__(self, other): 503 """DEPRECATED: Use `__floordiv__` via `x // y` instead. 504 505 This function exists only for backwards compatibility purposes; new code 506 should use `__floordiv__` via the syntax `x // y`. Using `x // y` 507 communicates clearly that the result rounds down, and is forward compatible 508 to Python 3. 509 510 Args: 511 other: Another `Dimension`. 512 513 Returns: 514 A `Dimension` whose value is the integer quotient of `self` and `other`. 515 """ 516 return self // other 517 518 def __rdiv__(self, other): 519 """Use `__floordiv__` via `x // y` instead. 520 521 This function exists only to have a better error message. Instead of: 522 `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`, 523 this function will explicitly call for usage of `//` instead. 524 525 Args: 526 other: Another `Dimension`. 527 528 Raises: 529 TypeError. 530 """ 531 raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', " 532 "please use // instead".format(type(other).__name__)) 533 534 def __truediv__(self, other): 535 """Use `__floordiv__` via `x // y` instead. 536 537 This function exists only to have a better error message. Instead of: 538 `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`, 539 this function will explicitly call for usage of `//` instead. 540 541 Args: 542 other: Another `Dimension`. 543 544 Raises: 545 TypeError. 546 """ 547 raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', " 548 "please use // instead".format(type(other).__name__)) 549 550 def __rtruediv__(self, other): 551 """Use `__floordiv__` via `x // y` instead. 552 553 This function exists only to have a better error message. Instead of: 554 `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`, 555 this function will explicitly call for usage of `//` instead. 556 557 Args: 558 other: Another `Dimension`. 559 560 Raises: 561 TypeError. 562 """ 563 raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', " 564 "please use // instead".format(type(other).__name__)) 565 566 def __mod__(self, other): 567 """Returns `self` modulo `other`. 568 569 Dimension modulo are computed as follows: 570 571 ```python 572 tf.compat.v1.Dimension(m) % tf.compat.v1.Dimension(n) == 573 tf.compat.v1.Dimension(m % n) 574 tf.compat.v1.Dimension(m) % tf.compat.v1.Dimension(None) # equiv. to 575 tf.compat.v1.Dimension(None) 576 tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(n) # equiv. to 577 tf.compat.v1.Dimension(None) 578 tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(None) # equiv. to 579 tf.compat.v1.Dimension(None) 580 ``` 581 582 Args: 583 other: Another Dimension, or a value accepted by `as_dimension`. 584 585 Returns: 586 A Dimension whose value is `self` modulo `other`. 587 """ 588 other = as_dimension(other) 589 if self._value is None or other.value is None: 590 return Dimension(None) 591 else: 592 return Dimension(self._value % other.value) 593 594 def __rmod__(self, other): 595 """Returns `other` modulo `self`. 596 597 Args: 598 other: Another Dimension, or a value accepted by `as_dimension`. 599 600 Returns: 601 A Dimension whose value is `other` modulo `self`. 602 """ 603 other = as_dimension(other) 604 return other % self 605 606 def __lt__(self, other): 607 """Returns True if `self` is known to be less than `other`. 608 609 Dimensions are compared as follows: 610 611 ```python 612 (tf.compat.v1.Dimension(m) < tf.compat.v1.Dimension(n)) == (m < n) 613 (tf.compat.v1.Dimension(m) < tf.compat.v1.Dimension(None)) == None 614 (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(n)) == None 615 (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(None)) == None 616 ``` 617 618 Args: 619 other: Another Dimension. 620 621 Returns: 622 The value of `self.value < other.value` if both are known, otherwise 623 None. 624 """ 625 other = as_dimension(other) 626 if self._value is None or other.value is None: 627 return None 628 else: 629 return self._value < other.value 630 631 def __le__(self, other): 632 """Returns True if `self` is known to be less than or equal to `other`. 633 634 Dimensions are compared as follows: 635 636 ```python 637 (tf.compat.v1.Dimension(m) <= tf.compat.v1.Dimension(n)) == (m <= n) 638 (tf.compat.v1.Dimension(m) <= tf.compat.v1.Dimension(None)) == None 639 (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(n)) == None 640 (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(None)) == None 641 ``` 642 643 Args: 644 other: Another Dimension. 645 646 Returns: 647 The value of `self.value <= other.value` if both are known, otherwise 648 None. 649 """ 650 other = as_dimension(other) 651 if self._value is None or other.value is None: 652 return None 653 else: 654 return self._value <= other.value 655 656 def __gt__(self, other): 657 """Returns True if `self` is known to be greater than `other`. 658 659 Dimensions are compared as follows: 660 661 ```python 662 (tf.compat.v1.Dimension(m) > tf.compat.v1.Dimension(n)) == (m > n) 663 (tf.compat.v1.Dimension(m) > tf.compat.v1.Dimension(None)) == None 664 (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(n)) == None 665 (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(None)) == None 666 ``` 667 668 Args: 669 other: Another Dimension. 670 671 Returns: 672 The value of `self.value > other.value` if both are known, otherwise 673 None. 674 """ 675 other = as_dimension(other) 676 if self._value is None or other.value is None: 677 return None 678 else: 679 return self._value > other.value 680 681 def __ge__(self, other): 682 """Returns True if `self` is known to be greater than or equal to `other`. 683 684 Dimensions are compared as follows: 685 686 ```python 687 (tf.compat.v1.Dimension(m) >= tf.compat.v1.Dimension(n)) == (m >= n) 688 (tf.compat.v1.Dimension(m) >= tf.compat.v1.Dimension(None)) == None 689 (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(n)) == None 690 (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(None)) == None 691 ``` 692 693 Args: 694 other: Another Dimension. 695 696 Returns: 697 The value of `self.value >= other.value` if both are known, otherwise 698 None. 699 """ 700 other = as_dimension(other) 701 if self._value is None or other.value is None: 702 return None 703 else: 704 return self._value >= other.value 705 706 def __reduce__(self): 707 return Dimension, (self._value,) 708 709 710def as_dimension(value): 711 """Converts the given value to a Dimension. 712 713 A Dimension input will be returned unmodified. 714 An input of `None` will be converted to an unknown Dimension. 715 An integer input will be converted to a Dimension with that value. 716 717 Args: 718 value: The value to be converted. 719 720 Returns: 721 A Dimension corresponding to the given value. 722 """ 723 if isinstance(value, Dimension): 724 return value 725 else: 726 return Dimension(value) 727 728 729@tf_export("TensorShape") 730class TensorShape(object): 731 """Represents the shape of a `Tensor`. 732 733 A `TensorShape` represents a possibly-partial shape specification for a 734 `Tensor`. It may be one of the following: 735 736 * *Fully-known shape:* has a known number of dimensions and a known size 737 for each dimension. e.g. `TensorShape([16, 256])` 738 * *Partially-known shape:* has a known number of dimensions, and an unknown 739 size for one or more dimension. e.g. `TensorShape([None, 256])` 740 * *Unknown shape:* has an unknown number of dimensions, and an unknown 741 size in all dimensions. e.g. `TensorShape(None)` 742 743 If a tensor is produced by an operation of type `"Foo"`, its shape 744 may be inferred if there is a registered shape function for 745 `"Foo"`. See [Shape 746 functions](https://tensorflow.org/extend/adding_an_op#shape_functions_in_c) 747 for details of shape functions and how to register them. Alternatively, 748 the shape may be set explicitly using `tf.Tensor.set_shape`. 749 """ 750 __slots__ = ["_dims"] 751 752 def __init__(self, dims): 753 """Creates a new TensorShape with the given dimensions. 754 755 Args: 756 dims: A list of Dimensions, or None if the shape is unspecified. 757 758 Raises: 759 TypeError: If dims cannot be converted to a list of dimensions. 760 """ 761 if isinstance(dims, (tuple, list)): # Most common case. 762 self._dims = [Dimension(d) for d in dims] 763 elif dims is None: 764 self._dims = None 765 elif isinstance(dims, tensor_shape_pb2.TensorShapeProto): 766 if dims.unknown_rank: 767 self._dims = None 768 else: 769 self._dims = [ 770 # Protos store variable-size dimensions as -1 771 as_dimension(dim.size if dim.size != -1 else None) 772 for dim in dims.dim 773 ] 774 elif isinstance(dims, TensorShape): 775 self._dims = dims.dims 776 else: 777 try: 778 dims_iter = iter(dims) 779 except TypeError: 780 # Treat as a singleton dimension 781 self._dims = [as_dimension(dims)] 782 else: 783 self._dims = [] 784 for d in dims_iter: 785 try: 786 self._dims.append(as_dimension(d)) 787 except TypeError as e: 788 six.raise_from( 789 TypeError( 790 "Failed to convert '{0!r}' to a shape: '{1!r}'" 791 "could not be converted to a dimension. A shape should " 792 "either be single dimension (e.g. 10), or an iterable of " 793 "dimensions (e.g. [1, 10, None])." 794 .format(dims, d)), e) 795 796 @property 797 def _v2_behavior(self): 798 if _TENSORSHAPE_V2_OVERRIDE is None: 799 return tf2.enabled() 800 return _TENSORSHAPE_V2_OVERRIDE 801 802 def __repr__(self): 803 if self._v2_behavior: 804 if self._dims is not None: 805 return "TensorShape(%r)" % [dim.value for dim in self._dims] 806 else: 807 return "TensorShape(None)" 808 else: 809 return "TensorShape(%r)" % self._dims 810 811 def __str__(self): 812 if self.rank is None: 813 return "<unknown>" 814 elif self.rank == 1: 815 if self._v2_behavior: 816 return "(%s,)" % self._dims[0].value 817 else: 818 return "(%s,)" % self._dims[0] 819 else: 820 if self._v2_behavior: 821 return "(%s)" % ", ".join(str(d.value) for d in self._dims) 822 else: 823 return "(%s)" % ", ".join(str(d) for d in self._dims) 824 825 @property 826 def rank(self): 827 """Returns the rank of this shape, or None if it is unspecified.""" 828 if self._dims is not None: 829 return len(self._dims) 830 return None 831 832 @property 833 def dims(self): 834 """Deprecated. Returns list of dimensions for this shape. 835 836 Suggest `TensorShape.as_list` instead. 837 838 Returns: 839 A list containing `tf.compat.v1.Dimension`s, or None if the shape is 840 unspecified. 841 """ 842 return self._dims 843 844 @property 845 def ndims(self): 846 """Deprecated accessor for `rank`.""" 847 return self.rank 848 849 def __len__(self): 850 """Returns the rank of this shape, or raises ValueError if unspecified.""" 851 if self._dims is None: 852 raise ValueError("Cannot take the length of shape with unknown rank.") 853 return len(self._dims) 854 855 def __bool__(self): 856 """Returns True if this shape contains non-zero information.""" 857 return self._dims is not None 858 859 # Python 3 wants __bool__, Python 2.7 wants __nonzero__ 860 __nonzero__ = __bool__ 861 862 def __iter__(self): 863 """Returns `self.dims` if the rank is known, otherwise raises ValueError.""" 864 if self._dims is None: 865 raise ValueError("Cannot iterate over a shape with unknown rank.") 866 else: 867 if self._v2_behavior: 868 return iter(d.value for d in self._dims) 869 else: 870 return iter(d for d in self._dims) 871 872 def __getitem__(self, key): 873 """Returns the value of a dimension or a shape, depending on the key. 874 875 Args: 876 key: If `key` is an integer, returns the dimension at that index; 877 otherwise if `key` is a slice, returns a TensorShape whose dimensions 878 are those selected by the slice from `self`. 879 880 Returns: 881 An integer if `key` is an integer, or a `TensorShape` if `key` is a 882 slice. 883 884 Raises: 885 ValueError: If `key` is a slice and `self` is completely unknown and 886 the step is set. 887 """ 888 if self._dims is not None: 889 if isinstance(key, slice): 890 return TensorShape(self._dims[key]) 891 else: 892 if self._v2_behavior: 893 return self._dims[key].value 894 else: 895 return self._dims[key] 896 else: 897 if isinstance(key, slice): 898 start = key.start if key.start is not None else 0 899 stop = key.stop 900 901 if key.step is not None: 902 # TODO(mrry): Handle these maybe. 903 raise ValueError("Steps are not yet handled") 904 if stop is None: 905 # NOTE(mrry): This implies that TensorShape(None) is compatible with 906 # TensorShape(None)[1:], which is obviously not true. It would be 907 # possible to track the number of dimensions symbolically, 908 # and perhaps we should do that. 909 return unknown_shape() 910 elif start < 0 or stop < 0: 911 # TODO(mrry): Handle this better, as it will be useful for handling 912 # suffixes of otherwise unknown shapes. 913 return unknown_shape() 914 else: 915 return unknown_shape(rank=stop - start) 916 else: 917 if self._v2_behavior: 918 return None 919 else: 920 return Dimension(None) 921 922 def num_elements(self): 923 """Returns the total number of elements, or none for incomplete shapes.""" 924 if self.is_fully_defined(): 925 return functools.reduce(operator.mul, self.as_list(), 1) 926 else: 927 return None 928 929 def merge_with(self, other): 930 """Returns a `TensorShape` combining the information in `self` and `other`. 931 932 The dimensions in `self` and `other` are merged element-wise, 933 according to the rules below: 934 935 ```python 936 Dimension(n).merge_with(Dimension(None)) == Dimension(n) 937 Dimension(None).merge_with(Dimension(n)) == Dimension(n) 938 Dimension(None).merge_with(Dimension(None)) == Dimension(None) 939 # raises ValueError for n != m 940 Dimension(n).merge_with(Dimension(m)) 941 ``` 942 >> ts = tf.TensorShape([1,2]) 943 >> ot1 = tf.TensorShape([1,2]) 944 >> ts.merge_with(ot).as_list() 945 [1,2] 946 947 >> ot2 = tf.TensorShape([1,None]) 948 >> ts.merge_with(ot2).as_list() 949 [1,2] 950 951 >> ot3 = tf.TensorShape([None, None]) 952 >> ot3.merge_with(ot2).as_list() 953 [1, None] 954 955 Args: 956 other: Another `TensorShape`. 957 958 Returns: 959 A `TensorShape` containing the combined information of `self` and 960 `other`. 961 962 Raises: 963 ValueError: If `self` and `other` are not compatible. 964 """ 965 other = as_shape(other) 966 if self._dims is None: 967 return other 968 if other.dims is None: 969 return self 970 else: 971 try: 972 self.assert_same_rank(other) 973 new_dims = [ 974 dim.merge_with(other_dim) 975 for dim, other_dim in zip(self._dims, other.dims) 976 ] 977 return TensorShape(new_dims) 978 except ValueError: 979 raise ValueError("Shapes %s and %s are not compatible" % (self, other)) 980 981 def __add__(self, other): 982 return self.concatenate(other) 983 984 def __radd__(self, other): 985 if not isinstance(other, TensorShape): 986 other = TensorShape(other) 987 return other.concatenate(self) 988 989 def concatenate(self, other): 990 """Returns the concatenation of the dimension in `self` and `other`. 991 992 *N.B.* If either `self` or `other` is completely unknown, 993 concatenation will discard information about the other shape. In 994 future, we might support concatenation that preserves this 995 information for use with slicing. 996 997 Args: 998 other: Another `TensorShape`. 999 1000 Returns: 1001 A `TensorShape` whose dimensions are the concatenation of the 1002 dimensions in `self` and `other`. 1003 """ 1004 # TODO(mrry): Handle the case where we concatenate a known shape with a 1005 # completely unknown shape, so that we can use the partial information. 1006 other = as_shape(other) 1007 if self._dims is None or other.dims is None: 1008 return unknown_shape() 1009 else: 1010 return TensorShape(self._dims + other.dims) 1011 1012 def assert_same_rank(self, other): 1013 """Raises an exception if `self` and `other` do not have compatible ranks. 1014 1015 Args: 1016 other: Another `TensorShape`. 1017 1018 Raises: 1019 ValueError: If `self` and `other` do not represent shapes with the 1020 same rank. 1021 """ 1022 other = as_shape(other) 1023 if self.rank is not None and other.rank is not None: 1024 if self.rank != other.rank: 1025 raise ValueError("Shapes %s and %s must have the same rank" % 1026 (self, other)) 1027 1028 def assert_has_rank(self, rank): 1029 """Raises an exception if `self` is not compatible with the given `rank`. 1030 1031 Args: 1032 rank: An integer. 1033 1034 Raises: 1035 ValueError: If `self` does not represent a shape with the given `rank`. 1036 """ 1037 if self.rank not in (None, rank): 1038 raise ValueError("Shape %s must have rank %d" % (self, rank)) 1039 1040 def with_rank(self, rank): 1041 """Returns a shape based on `self` with the given rank. 1042 1043 This method promotes a completely unknown shape to one with a 1044 known rank. 1045 1046 Args: 1047 rank: An integer. 1048 1049 Returns: 1050 A shape that is at least as specific as `self` with the given rank. 1051 1052 Raises: 1053 ValueError: If `self` does not represent a shape with the given `rank`. 1054 """ 1055 try: 1056 return self.merge_with(unknown_shape(rank=rank)) 1057 except ValueError: 1058 raise ValueError("Shape %s must have rank %d" % (self, rank)) 1059 1060 def with_rank_at_least(self, rank): 1061 """Returns a shape based on `self` with at least the given rank. 1062 1063 Args: 1064 rank: An integer. 1065 1066 Returns: 1067 A shape that is at least as specific as `self` with at least the given 1068 rank. 1069 1070 Raises: 1071 ValueError: If `self` does not represent a shape with at least the given 1072 `rank`. 1073 """ 1074 if self.rank is not None and self.rank < rank: 1075 raise ValueError("Shape %s must have rank at least %d" % (self, rank)) 1076 else: 1077 return self 1078 1079 def with_rank_at_most(self, rank): 1080 """Returns a shape based on `self` with at most the given rank. 1081 1082 Args: 1083 rank: An integer. 1084 1085 Returns: 1086 A shape that is at least as specific as `self` with at most the given 1087 rank. 1088 1089 Raises: 1090 ValueError: If `self` does not represent a shape with at most the given 1091 `rank`. 1092 """ 1093 if self.rank is not None and self.rank > rank: 1094 raise ValueError("Shape %s must have rank at most %d" % (self, rank)) 1095 else: 1096 return self 1097 1098 def is_compatible_with(self, other): 1099 """Returns True iff `self` is compatible with `other`. 1100 1101 Two possibly-partially-defined shapes are compatible if there 1102 exists a fully-defined shape that both shapes can represent. Thus, 1103 compatibility allows the shape inference code to reason about 1104 partially-defined shapes. For example: 1105 1106 * TensorShape(None) is compatible with all shapes. 1107 1108 * TensorShape([None, None]) is compatible with all two-dimensional 1109 shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is 1110 not compatible with, for example, TensorShape([None]) or 1111 TensorShape([None, None, None]). 1112 1113 * TensorShape([32, None]) is compatible with all two-dimensional shapes 1114 with size 32 in the 0th dimension, and also TensorShape([None, None]) 1115 and TensorShape(None). It is not compatible with, for example, 1116 TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]). 1117 1118 * TensorShape([32, 784]) is compatible with itself, and also 1119 TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None, 1120 None]) and TensorShape(None). It is not compatible with, for example, 1121 TensorShape([32, 1, 784]) or TensorShape([None]). 1122 1123 The compatibility relation is reflexive and symmetric, but not 1124 transitive. For example, TensorShape([32, 784]) is compatible with 1125 TensorShape(None), and TensorShape(None) is compatible with 1126 TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with 1127 TensorShape([4, 4]). 1128 1129 Args: 1130 other: Another TensorShape. 1131 1132 Returns: 1133 True iff `self` is compatible with `other`. 1134 1135 """ 1136 other = as_shape(other) 1137 if self._dims is not None and other.dims is not None: 1138 if self.rank != other.rank: 1139 return False 1140 for x_dim, y_dim in zip(self._dims, other.dims): 1141 if not x_dim.is_compatible_with(y_dim): 1142 return False 1143 return True 1144 1145 def assert_is_compatible_with(self, other): 1146 """Raises exception if `self` and `other` do not represent the same shape. 1147 1148 This method can be used to assert that there exists a shape that both 1149 `self` and `other` represent. 1150 1151 Args: 1152 other: Another TensorShape. 1153 1154 Raises: 1155 ValueError: If `self` and `other` do not represent the same shape. 1156 """ 1157 if not self.is_compatible_with(other): 1158 raise ValueError("Shapes %s and %s are incompatible" % (self, other)) 1159 1160 def most_specific_compatible_shape(self, other): 1161 """Returns the most specific TensorShape compatible with `self` and `other`. 1162 1163 * TensorShape([None, 1]) is the most specific TensorShape compatible with 1164 both TensorShape([2, 1]) and TensorShape([5, 1]). Note that 1165 TensorShape(None) is also compatible with above mentioned TensorShapes. 1166 1167 * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with 1168 both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more 1169 less specific TensorShapes compatible with above mentioned TensorShapes, 1170 e.g. TensorShape([1, 2, None]), TensorShape(None). 1171 1172 Args: 1173 other: Another `TensorShape`. 1174 1175 Returns: 1176 A `TensorShape` which is the most specific compatible shape of `self` 1177 and `other`. 1178 """ 1179 1180 other = as_shape(other) 1181 if self._dims is None or other.dims is None or self.rank != other.rank: 1182 return unknown_shape() 1183 1184 dims = [ 1185 d1 if d1 is not None and d2 is not None and d1 == d2 else None 1186 for d1, d2 in zip(self._dims, other.dims) 1187 ] 1188 return TensorShape(dims) 1189 1190 def is_fully_defined(self): 1191 """Returns True iff `self` is fully defined in every dimension.""" 1192 return (self._dims is not None and 1193 all(dim.value is not None for dim in self._dims)) 1194 1195 def assert_is_fully_defined(self): 1196 """Raises an exception if `self` is not fully defined in every dimension. 1197 1198 Raises: 1199 ValueError: If `self` does not have a known value for every dimension. 1200 """ 1201 if not self.is_fully_defined(): 1202 raise ValueError("Shape %s is not fully defined" % self) 1203 1204 def as_list(self): 1205 """Returns a list of integers or `None` for each dimension. 1206 1207 Returns: 1208 A list of integers or `None` for each dimension. 1209 1210 Raises: 1211 ValueError: If `self` is an unknown shape with an unknown rank. 1212 """ 1213 if self._dims is None: 1214 raise ValueError("as_list() is not defined on an unknown TensorShape.") 1215 return [dim.value for dim in self._dims] 1216 1217 def as_proto(self): 1218 """Returns this shape as a `TensorShapeProto`.""" 1219 if self._dims is None: 1220 return tensor_shape_pb2.TensorShapeProto(unknown_rank=True) 1221 else: 1222 return tensor_shape_pb2.TensorShapeProto(dim=[ 1223 tensor_shape_pb2.TensorShapeProto.Dim( 1224 size=-1 if d.value is None else d.value) for d in self._dims 1225 ]) 1226 1227 def __eq__(self, other): 1228 """Returns True if `self` is equivalent to `other`. 1229 1230 It first tries to convert `other` to `TensorShape`. `TypeError` is thrown 1231 when the conversion fails. Otherwise, it compares each element in the 1232 TensorShape dimensions. 1233 1234 * Two *Fully known* shapes, return True iff each element is equal. 1235 >>> t_a = tf.TensorShape([1,2]) 1236 >>> a = [1, 2] 1237 >>> t_b = tf.TensorShape([1,2]) 1238 >>> t_c = tf.TensorShape([1,2,3]) 1239 >>> t_a.__eq__(a) 1240 True 1241 >>> t_a.__eq__(t_b) 1242 True 1243 >>> t_a.__eq__(t_c) 1244 False 1245 1246 * Two *Partially-known* shapes, return False. 1247 >>> p_a = tf.TensorShape([1,None]) 1248 >>> p_b = tf.TensorShape([2,None]) 1249 >>> p_a.__eq__(p_b) 1250 False 1251 >>> t_a.__eq__(p_a) 1252 False 1253 1254 * Two *Unknown shape*, return True. 1255 >>> unk_a = tf.TensorShape(None) 1256 >>> unk_b = tf.TensorShape(None) 1257 >>> unk_a.__eq__(unk_b) 1258 True 1259 >>> unk_a.__eq__(t_a) 1260 False 1261 1262 Args: 1263 other: A `TensorShape` or type that can be converted to `TensorShape`. 1264 1265 Returns: 1266 True if the dimensions are all equal. 1267 1268 Raises: 1269 TypeError if `other` can not be converted to `TensorShape`. 1270 """ 1271 1272 try: 1273 other = as_shape(other) 1274 except TypeError: 1275 return NotImplemented 1276 return self._dims == other.dims 1277 1278 def __ne__(self, other): 1279 """Returns True if `self` is known to be different from `other`.""" 1280 try: 1281 other = as_shape(other) 1282 except TypeError: 1283 return NotImplemented 1284 if self.rank is None or other.rank is None: 1285 raise ValueError("The inequality of unknown TensorShapes is undefined.") 1286 if self.rank != other.rank: 1287 return True 1288 return self._dims != other.dims 1289 1290 def __reduce__(self): 1291 return TensorShape, (self._dims,) 1292 1293 def __concat__(self, other): 1294 return self.concatenate(other) 1295 1296 1297def as_shape(shape): 1298 """Converts the given object to a TensorShape.""" 1299 if isinstance(shape, TensorShape): 1300 return shape 1301 else: 1302 return TensorShape(shape) 1303 1304 1305def unknown_shape(rank=None, **kwargs): 1306 """Returns an unknown TensorShape, optionally with a known rank. 1307 1308 Args: 1309 rank: (Optional) If specified, the number of dimensions in the shape. 1310 **kwargs: For backwards compatibility. 1311 1312 Returns: 1313 An unknown TensorShape. 1314 1315 Raises: 1316 TypeError: In case of invalid arguments. 1317 """ 1318 if rank is None and "ndims" in kwargs: 1319 rank = kwargs.pop("ndims") 1320 if kwargs: 1321 raise TypeError("Unknown argument: %s" % kwargs) 1322 if rank is None: 1323 return TensorShape(None) 1324 else: 1325 return TensorShape([Dimension(None)] * rank) 1326