# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes for tensor shape inference."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python import tf2
from tensorflow.python.framework import dtypes
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


_TENSORSHAPE_V2_OVERRIDE = None


@tf_export(v1=["enable_v2_tensorshape"])
def enable_v2_tensorshape():
  """In TensorFlow 2.0, iterating over a TensorShape instance returns values.

  This enables the new behavior.

  Concretely, `tensor_shape[i]` returned a Dimension instance in V1, but
  in V2 it returns either an integer, or None.

  Examples:

  ```
  #######################
  # If you had this in V1:
  value = tensor_shape[i].value

  # Do this in V2 instead:
  value = tensor_shape[i]

  #######################
  # If you had this in V1:
  for dim in tensor_shape:
    value = dim.value
    print(value)

  # Do this in V2 instead:
  for value in tensor_shape:
    print(value)

  #######################
  # If you had this in V1:
  dim = tensor_shape[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # Do this in V2 instead:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # The V2 suggestion above is more explicit, which will save you from
  # the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be.
  ```
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = True


@tf_export(v1=["disable_v2_tensorshape"])
def disable_v2_tensorshape():
  """Disables the V2 TensorShape behavior and reverts to V1 behavior.

  See docstring for `enable_v2_tensorshape` for details about the new behavior.
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = False


@tf_export("compat.dimension_value",
           v1=["dimension_value", "compat.dimension_value"])
def dimension_value(dimension):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  When accessing the value of a TensorShape dimension,
  use this utility, like this:

  ```
  # If you had this in your V1 code:
  value = tensor_shape[i].value

  # Use `dimension_value` as direct replacement compatible with both V1 & V2:
  value = dimension_value(tensor_shape[i])

  # This would be the V2 equivalent:
  value = tensor_shape[i]  # Warning: this will return the dim value in V2!
  ```

  Arguments:
    dimension: Either a `Dimension` instance, an integer, or None.

  Returns:
    A plain value, i.e. an integer or None.
  """
  if isinstance(dimension, Dimension):
    return dimension.value
  return dimension


@tf_export("compat.dimension_at_index",
           v1=["dimension_at_index", "compat.dimension_at_index"])
def dimension_at_index(shape, index):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  If you want to retrieve the Dimension instance corresponding to a certain
  index in a TensorShape instance, use this utility, like this:

  ```
  # If you had this in your V1 code:
  dim = tensor_shape[i]

  # Use `dimension_at_index` as direct replacement compatible with both V1 & V2:
  dim = dimension_at_index(tensor_shape, i)

  # Another possibility would be this, but WARNING: it only works if the
  # tensor_shape instance has a defined rank.
  dim = tensor_shape.dims[i]  # `dims` may be None if the rank is undefined!

  # In native V2 code, we recommend instead being more explicit:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]

  # Being more explicit will save you from the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be (as the Dimension object was
  # instantiated on the fly).
  ```

  Arguments:
    shape: A TensorShape instance.
    index: An integer index.

  Returns:
    A dimension object.
  """
  assert isinstance(shape, TensorShape)
  if shape.rank is None:
    return Dimension(None)
  else:
    return shape.dims[index]


@tf_export(v1=["Dimension"])
class Dimension(object):
  """Represents the value of one dimension in a TensorShape."""

  def __init__(self, value):
    """Creates a new Dimension with the given value."""
    if value is None:
      self._value = None
    elif isinstance(value, Dimension):
      self._value = value.value
    elif isinstance(value, dtypes.DType):
      raise TypeError("Cannot convert %s to Dimension" % value)
    else:
      self._value = int(value)
      if (not isinstance(value, compat.bytes_or_text_types) and
          self._value != value):
        raise ValueError("Ambiguous dimension: %s" % value)
      if self._value < 0:
        raise ValueError("Dimension %d must be >= 0" % self._value)

  def __repr__(self):
    return "Dimension(%s)" % repr(self._value)

  def __str__(self):
    value = self._value
    return "?" if value is None else str(value)

  def __eq__(self, other):
    """Returns true if `other` has the same known value as this Dimension."""
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value == other.value

  def __ne__(self, other):
    """Returns true if `other` has a different known value from `self`."""
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value != other.value

  def __int__(self):
    return self._value

  # This is needed for Windows.
  # See https://github.com/tensorflow/tensorflow/pull/9780
  def __long__(self):
    return self._value

  def __index__(self):
    # Allow use in Python 3 range
    return self._value

  @property
  def value(self):
    """The value of this dimension, or None if it is unknown."""
    return self._value

  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this Dimension.

    Two known Dimensions are compatible if they have the same value.
    An unknown Dimension is compatible with all other Dimensions.

    Args:
      other: Another Dimension.

    Returns:
      True if this Dimension and `other` are compatible.
    """
    other = as_dimension(other)
    return (self._value is None or other.value is None or
            self._value == other.value)

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.

    Args:
      other: Another Dimension.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    if not self.is_compatible_with(other):
      raise ValueError("Dimensions %s and %s are not compatible" % (self,
                                                                    other))

  def merge_with(self, other):
    """Returns a Dimension that combines the information in `self` and `other`.

    Dimensions are combined as follows:

    ```python
    tf.Dimension(n)   .merge_with(tf.Dimension(n))     == tf.Dimension(n)
    tf.Dimension(n)   .merge_with(tf.Dimension(None))  == tf.Dimension(n)
    tf.Dimension(None).merge_with(tf.Dimension(n))     == tf.Dimension(n)
    # equivalent to tf.Dimension(None)
    tf.Dimension(None).merge_with(tf.Dimension(None))

    # raises ValueError for n != m
    tf.Dimension(n)   .merge_with(tf.Dimension(m))
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    other = as_dimension(other)
    self.assert_is_compatible_with(other)
    if self._value is None:
      return Dimension(other.value)
    else:
      return Dimension(self._value)

  def __add__(self, other):
    """Returns the sum of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.Dimension(m)    + tf.Dimension(n)     == tf.Dimension(m + n)
    tf.Dimension(m)    + tf.Dimension(None)  # equiv. to tf.Dimension(None)
    tf.Dimension(None) + tf.Dimension(n)     # equiv. to tf.Dimension(None)
    tf.Dimension(None) + tf.Dimension(None)  # equiv. to tf.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value + other.value)

  def __radd__(self, other):
    """Returns the sum of `other` and `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    return self + other

  def __sub__(self, other):
    """Returns the subtraction of `other` from `self`.

    Dimensions are subtracted as follows:

    ```python
    tf.Dimension(m)    - tf.Dimension(n)     == tf.Dimension(m - n)
    tf.Dimension(m)    - tf.Dimension(None)  # equiv. to tf.Dimension(None)
    tf.Dimension(None) - tf.Dimension(n)     # equiv. to tf.Dimension(None)
    tf.Dimension(None) - tf.Dimension(None)  # equiv. to tf.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `other` from `self`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value - other.value)

  def __rsub__(self, other):
    """Returns the subtraction of `self` from `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `self` from `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value - self._value)

  def __mul__(self, other):
    """Returns the product of `self` and `other`.

    Dimensions are multiplied as follows:

    ```python
    tf.Dimension(m)    * tf.Dimension(n)     == tf.Dimension(m * n)
    tf.Dimension(m)    * tf.Dimension(None)  # equiv. to tf.Dimension(None)
    tf.Dimension(None) * tf.Dimension(n)     # equiv. to tf.Dimension(None)
    tf.Dimension(None) * tf.Dimension(None)  # equiv. to tf.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented

    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value * other.value)

  def __rmul__(self, other):
    """Returns the product of `self` and `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    return self * other

  def __floordiv__(self, other):
    """Returns the quotient of `self` and `other` rounded down.

    Dimensions are divided as follows:

    ```python
    tf.Dimension(m)    // tf.Dimension(n)     == tf.Dimension(m // n)
    tf.Dimension(m)    // tf.Dimension(None)  # equiv. to tf.Dimension(None)
    tf.Dimension(None) // tf.Dimension(n)     # equiv. to tf.Dimension(None)
    tf.Dimension(None) // tf.Dimension(None)  # equiv. to tf.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value // other.value)

  def __rfloordiv__(self, other):
    """Returns the quotient of `other` and `self` rounded down.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value // self._value)

  def __div__(self, other):
    """DEPRECATED: Use `__floordiv__` via `x // y` instead.

    This function exists only for backwards compatibility purposes; new code
    should use `__floordiv__` via the syntax `x // y`. Using `x // y`
    communicates clearly that the result rounds down, and is forward compatible
    to Python 3.

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    return self // other

  def __rdiv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __truediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', "
                    "please use // instead".format(type(other).__name__))

  def __rtruediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __mod__(self, other):
    """Returns `self` modulo `other`.

    Dimension moduli are computed as follows:

    ```python
    tf.Dimension(m)    % tf.Dimension(n)     == tf.Dimension(m % n)
    tf.Dimension(m)    % tf.Dimension(None)  # equiv. to tf.Dimension(None)
    tf.Dimension(None) % tf.Dimension(n)     # equiv. to tf.Dimension(None)
    tf.Dimension(None) % tf.Dimension(None)  # equiv. to tf.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `self` modulo `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value % other.value)

  def __rmod__(self, other):
    """Returns `other` modulo `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `other` modulo `self`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    return other % self

  def __lt__(self, other):
    """Returns True if `self` is known to be less than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    < tf.Dimension(n))     == (m < n)
    (tf.Dimension(m)    < tf.Dimension(None))  == None
    (tf.Dimension(None) < tf.Dimension(n))     == None
    (tf.Dimension(None) < tf.Dimension(None))  == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value < other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value < other.value

  def __le__(self, other):
    """Returns True if `self` is known to be less than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    <= tf.Dimension(n))     == (m <= n)
    (tf.Dimension(m)    <= tf.Dimension(None))  == None
    (tf.Dimension(None) <= tf.Dimension(n))     == None
    (tf.Dimension(None) <= tf.Dimension(None))  == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value <= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value <= other.value

  def __gt__(self, other):
    """Returns True if `self` is known to be greater than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    > tf.Dimension(n))     == (m > n)
    (tf.Dimension(m)    > tf.Dimension(None))  == None
    (tf.Dimension(None) > tf.Dimension(n))     == None
    (tf.Dimension(None) > tf.Dimension(None))  == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value > other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value > other.value

  def __ge__(self, other):
    """Returns True if `self` is known to be greater than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    >= tf.Dimension(n))     == (m >= n)
    (tf.Dimension(m)    >= tf.Dimension(None))  == None
    (tf.Dimension(None) >= tf.Dimension(n))     == None
    (tf.Dimension(None) >= tf.Dimension(None))  == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value >= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value >= other.value

  def __reduce__(self):
    return Dimension, (self._value,)


def as_dimension(value):
  """Converts the given value to a Dimension.

  A Dimension input will be returned unmodified.
  An input of `None` will be converted to an unknown Dimension.
  An integer input will be converted to a Dimension with that value.

  Args:
    value: The value to be converted.

  Returns:
    A Dimension corresponding to the given value.
  """
  if isinstance(value, Dimension):
    return value
  else:
    return Dimension(value)


@tf_export("TensorShape")
class TensorShape(object):
  """Represents the shape of a `Tensor`.

  A `TensorShape` represents a possibly-partial shape specification for a
  `Tensor`. It may be one of the following:

  * *Fully-known shape:* has a known number of dimensions and a known size
    for each dimension. e.g. `TensorShape([16, 256])`
  * *Partially-known shape:* has a known number of dimensions, and an unknown
    size for one or more dimension. e.g. `TensorShape([None, 256])`
  * *Unknown shape:* has an unknown number of dimensions, and an unknown
    size in all dimensions. e.g. `TensorShape(None)`

  If a tensor is produced by an operation of type `"Foo"`, its shape
  may be inferred if there is a registered shape function for
  `"Foo"`. See [Shape
  functions](https://tensorflow.org/extend/adding_an_op#shape_functions_in_c)
  for details of shape functions and how to register them. Alternatively,
  the shape may be set explicitly using `tf.Tensor.set_shape`.
  """

  def __init__(self, dims):
    """Creates a new TensorShape with the given dimensions.

    Args:
      dims: A list of Dimensions, or None if the shape is unspecified.

    Raises:
      TypeError: If dims cannot be converted to a list of dimensions.
    """
    if dims is None:
      self._dims = None
    elif isinstance(dims, compat.bytes_or_text_types):
      raise TypeError("A string has ambiguous TensorShape, please wrap in a "
                      "list or convert to an int: %s" % dims)
    elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
      if dims.unknown_rank:
        self._dims = None
      else:
        self._dims = [
            # Protos store variable-size dimensions as -1
            as_dimension(dim.size if dim.size != -1 else None)
            for dim in dims.dim
        ]
    elif isinstance(dims, TensorShape):
      self._dims = dims.dims
    else:
      try:
        dims_iter = iter(dims)
      except TypeError:
        # Treat as a singleton dimension
        self._dims = [as_dimension(dims)]
      else:
        # Got a list of dimensions
        self._dims = [as_dimension(d) for d in dims_iter]

  @property
  def _v2_behavior(self):
    if _TENSORSHAPE_V2_OVERRIDE is None:
      return tf2.enabled()
    return _TENSORSHAPE_V2_OVERRIDE

  def __repr__(self):
    if self._v2_behavior:
      if self._dims is not None:
        return "TensorShape(%r)" % [dim.value for dim in self._dims]
      else:
        return "TensorShape(None)"
    else:
      return "TensorShape(%r)" % self._dims

  def __str__(self):
    if self.rank is None:
      return "<unknown>"
    elif self.rank == 1:
      if self._v2_behavior:
        return "(%s,)" % self._dims[0].value
      else:
        return "(%s,)" % self._dims[0]
    else:
      if self._v2_behavior:
        return "(%s)" % ", ".join(str(d.value) for d in self._dims)
      else:
        return "(%s)" % ", ".join(str(d) for d in self._dims)

  @property
  def rank(self):
    """Returns the rank of this shape, or None if it is unspecified."""
    if self._dims is not None:
      return len(self._dims)
    return None

  @property
  def dims(self):
    """Returns a list of Dimensions, or None if the shape is unspecified."""
    return self._dims

  @dims.setter
  def dims(self, dims):
    self._dims = dims

  @property
  def ndims(self):
    """Deprecated accessor for `rank`."""
    return self.rank

  def __len__(self):
    """Returns the rank of this shape, or raises ValueError if unspecified."""
    if self._dims is None:
      raise ValueError("Cannot take the length of shape with unknown rank.")
    return len(self._dims)

  def __bool__(self):
    """Returns True if this shape contains non-zero information."""
    return self._dims is not None

  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
  __nonzero__ = __bool__

  def __iter__(self):
    """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
    if self._dims is None:
      raise ValueError("Cannot iterate over a shape with unknown rank.")
    else:
      if self._v2_behavior:
        return iter(d.value for d in self._dims)
      else:
        return iter(d for d in self._dims)

  def __getitem__(self, key):
    """Returns the value of a dimension or a shape, depending on the key.

    Args:
      key: If `key` is an integer, returns the dimension at that index;
        otherwise if `key` is a slice, returns a TensorShape whose
        dimensions are those selected by the slice from `self`.

    Returns:
      An integer if `key` is an integer, or a `TensorShape` if `key` is a
      slice.

    Raises:
      ValueError: If `key` is a slice and `self` is completely unknown and
        the step is set.
    """
    if self._dims is not None:
      if isinstance(key, slice):
        return TensorShape(self._dims[key])
      else:
        if self._v2_behavior:
          return self._dims[key].value
        else:
          return self._dims[key]
    else:
      if isinstance(key, slice):
        start = key.start if key.start is not None else 0
        stop = key.stop

        if key.step is not None:
          # TODO(mrry): Handle these maybe.
          raise ValueError("Steps are not yet handled")
        if stop is None:
          # NOTE(mrry): This implies that TensorShape(None) is compatible with
          # TensorShape(None)[1:], which is obviously not true. It would be
          # possible to track the number of dimensions symbolically,
          # and perhaps we should do that.
          return unknown_shape()
        elif start < 0 or stop < 0:
          # TODO(mrry): Handle this better, as it will be useful for handling
          # suffixes of otherwise unknown shapes.
          return unknown_shape()
        else:
          return unknown_shape(rank=stop - start)
      else:
        if self._v2_behavior:
          return None
        else:
          return Dimension(None)

  def num_elements(self):
    """Returns the total number of elements, or None for incomplete shapes."""
    if self.is_fully_defined():
      size = 1
      for dim in self._dims:
        size *= dim.value
      return size
    else:
      return None

  def merge_with(self, other):
    """Returns a `TensorShape` combining the information in `self` and `other`.

    The dimensions in `self` and `other` are merged elementwise,
    according to the rules defined for `Dimension.merge_with()`.

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible.
    """
    other = as_shape(other)
    if self._dims is None:
      return other
    else:
      try:
        self.assert_same_rank(other)
        new_dims = []
        for i, dim in enumerate(self._dims):
          new_dims.append(dim.merge_with(other[i]))
        return TensorShape(new_dims)
      except ValueError:
        raise ValueError("Shapes %s and %s are not compatible" % (self, other))

  def concatenate(self, other):
    """Returns the concatenation of the dimensions in `self` and `other`.

    *N.B.* If either `self` or `other` is completely unknown,
    concatenation will discard information about the other shape. In
    future, we might support concatenation that preserves this
    information for use with slicing.

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` whose dimensions are the concatenation of the
      dimensions in `self` and `other`.
    """
    # TODO(mrry): Handle the case where we concatenate a known shape with a
    # completely unknown shape, so that we can use the partial information.
    other = as_shape(other)
    if self._dims is None or other.dims is None:
      return unknown_shape()
    else:
      return TensorShape(self._dims + other.dims)

  def assert_same_rank(self, other):
    """Raises an exception if `self` and `other` do not have compatible ranks.

    Args:
      other: Another `TensorShape`.

    Raises:
      ValueError: If `self` and `other` do not represent shapes with the
        same rank.
    """
    other = as_shape(other)
    if self.rank is not None and other.rank is not None:
      if self.rank != other.rank:
        raise ValueError("Shapes %s and %s must have the same rank" % (self,
                                                                       other))

  def assert_has_rank(self, rank):
    """Raises an exception if `self` is not compatible with the given `rank`.

    Args:
      rank: An integer.

    Raises:
      ValueError: If `self` does not represent a shape with the given `rank`.
    """
    if self.rank not in (None, rank):
      raise ValueError("Shape %s must have rank %d" % (self, rank))

  def with_rank(self, rank):
    """Returns a shape based on `self` with the given rank.

    This method promotes a completely unknown shape to one with a
    known rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with the given rank.

    Raises:
      ValueError: If `self` does not represent a shape with the given `rank`.
    """
    try:
      return self.merge_with(unknown_shape(rank=rank))
    except ValueError:
      raise ValueError("Shape %s must have rank %d" % (self, rank))

  def with_rank_at_least(self, rank):
    """Returns a shape based on `self` with at least the given rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with at least the given
      rank.

    Raises:
      ValueError: If `self` does not represent a shape with at least the given
        `rank`.
    """
    if self.rank is not None and self.rank < rank:
      raise ValueError("Shape %s must have rank at least %d" % (self, rank))
    else:
      return self

  def with_rank_at_most(self, rank):
    """Returns a shape based on `self` with at most the given rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with at most the given
      rank.

    Raises:
      ValueError: If `self` does not represent a shape with at most the given
        `rank`.
    """
    if self.rank is not None and self.rank > rank:
      raise ValueError("Shape %s must have rank at most %d" % (self, rank))
    else:
      return self

  def is_compatible_with(self, other):
    """Returns True iff `self` is compatible with `other`.

    Two possibly-partially-defined shapes are compatible if there
    exists a fully-defined shape that both shapes can represent. Thus,
    compatibility allows the shape inference code to reason about
    partially-defined shapes. For example:

    * TensorShape(None) is compatible with all shapes.

    * TensorShape([None, None]) is compatible with all two-dimensional
      shapes, such as TensorShape([32, 784]), and also TensorShape(None). It
      is not compatible with, for example, TensorShape([None]) or
      TensorShape([None, None, None]).

    * TensorShape([32, None]) is compatible with all two-dimensional shapes
      with size 32 in the 0th dimension, and also TensorShape([None, None])
      and TensorShape(None). It is not compatible with, for example,
      TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).

    * TensorShape([32, 784]) is compatible with itself, and also
      TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
      None]) and TensorShape(None). It is not compatible with, for example,
      TensorShape([32, 1, 784]) or TensorShape([None]).

    The compatibility relation is reflexive and symmetric, but not
    transitive. For example, TensorShape([32, 784]) is compatible with
    TensorShape(None), and TensorShape(None) is compatible with
    TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
    TensorShape([4, 4]).

    Args:
      other: Another TensorShape.

    Returns:
      True iff `self` is compatible with `other`.
    """
    other = as_shape(other)
    if self._dims is not None and other.dims is not None:
      if self.rank != other.rank:
        return False
      for x_dim, y_dim in zip(self._dims, other.dims):
        if not x_dim.is_compatible_with(y_dim):
          return False
    return True

  def assert_is_compatible_with(self, other):
    """Raises exception if `self` and `other` do not represent the same shape.

    This method can be used to assert that there exists a shape that both
    `self` and `other` represent.

    Args:
      other: Another TensorShape.

    Raises:
      ValueError: If `self` and `other` do not represent the same shape.
    """
    if not self.is_compatible_with(other):
      raise ValueError("Shapes %s and %s are incompatible" % (self, other))

  def most_specific_compatible_shape(self, other):
    """Returns the most specific TensorShape compatible with `self` and `other`.

    * TensorShape([None, 1]) is the most specific TensorShape compatible with
      both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
      TensorShape(None) is also compatible with above mentioned TensorShapes.

    * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
      both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are other,
      less specific TensorShapes compatible with above mentioned TensorShapes,
      e.g. TensorShape([1, 2, None]), TensorShape(None).

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` which is the most specific compatible shape of `self`
      and `other`.
    """

    other = as_shape(other)
    if self._dims is None or other.dims is None or self.rank != other.rank:
      return unknown_shape()

    dims = [(Dimension(None))] * self.rank
    for i, (d1, d2) in enumerate(zip(self._dims, other.dims)):
      if d1 is not None and d2 is not None and d1 == d2:
        dims[i] = d1
    return TensorShape(dims)

  def is_fully_defined(self):
    """Returns True iff `self` is fully defined in every dimension."""
    return (self._dims is not None and all(dim.value is not None
                                           for dim in self._dims))

  def assert_is_fully_defined(self):
    """Raises an exception if `self` is not fully defined in every dimension.

    Raises:
      ValueError: If `self` does not have a known value for every dimension.
    """
    if not self.is_fully_defined():
      raise ValueError("Shape %s is not fully defined" % self)

  def as_list(self):
    """Returns a list of integers or `None` for each dimension.

    Returns:
      A list of integers or `None` for each dimension.

    Raises:
      ValueError: If `self` is an unknown shape with an unknown rank.
    """
    if self._dims is None:
      raise ValueError("as_list() is not defined on an unknown TensorShape.")
    return [dim.value for dim in self._dims]

  def as_proto(self):
    """Returns this shape as a `TensorShapeProto`."""
    if self._dims is None:
      return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
    else:
      return tensor_shape_pb2.TensorShapeProto(dim=[
          tensor_shape_pb2.TensorShapeProto.Dim(size=-1
                                                if d.value is None else d.value)
          for d in self._dims
      ])

  def __eq__(self, other):
    """Returns True if `self` is equivalent to `other`."""
    try:
      other = as_shape(other)
    except TypeError:
      return NotImplemented
    return self._dims == other.dims

  def __ne__(self, other):
    """Returns True if `self` is known to be different from `other`."""
    try:
      other = as_shape(other)
    except TypeError:
      return NotImplemented
    if self.rank is None or other.rank is None:
      raise ValueError("The inequality of unknown TensorShapes is undefined.")
    if self.rank != other.rank:
      return True
    return self._dims != other.dims

  def __reduce__(self):
    return TensorShape, (self._dims,)

  def __concat__(self, other):
    return self.concatenate(other)


def as_shape(shape):
  """Converts the given object to a TensorShape."""
  if isinstance(shape, TensorShape):
    return shape
  else:
    return TensorShape(shape)


def unknown_shape(rank=None, **kwargs):
  """Returns an unknown TensorShape, optionally with a known rank.

  Args:
    rank: (Optional) If specified, the number of dimensions in the shape.
    **kwargs: For backwards compatibility.

  Returns:
    An unknown TensorShape.

  Raises:
    TypeError: In case of invalid arguments.
  """
  if rank is None and "ndims" in kwargs:
    rank = kwargs.pop("ndims")
  if kwargs:
    raise TypeError("Unknown argument: %s" % kwargs)
  if rank is None:
    return TensorShape(None)
  else:
    return TensorShape([Dimension(None)] * rank)


def scalar():
  """Returns a shape representing a scalar."""
  return TensorShape([])


def vector(length):
  """Returns a shape representing a vector.

  Args:
    length: The length of the vector, which may be None if unknown.

  Returns:
    A TensorShape representing a vector of the given length.
  """
  return TensorShape([length])


def matrix(rows, cols):
  """Returns a shape representing a matrix.

  Args:
    rows: The number of rows in the matrix, which may be None if unknown.
    cols: The number of columns in the matrix, which may be None if unknown.

  Returns:
    A TensorShape representing a matrix of the given size.
  """
  return TensorShape([rows, cols])
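

# The sketch below is illustrative only and is not part of this module's API:
# a minimal, hedged demo of the classes defined above, showing how unknown
# dimensions (None) propagate through `Dimension` arithmetic and how partially
# known `TensorShape`s merge. It runs only when this file is executed directly.
if __name__ == "__main__":
  # Dimension arithmetic: any operation involving an unknown dimension
  # yields another unknown dimension.
  known = Dimension(32)
  unknown = Dimension(None)
  print(known * 2)                 # "64"
  print((known + unknown).value)   # None -- unknown propagates through "+"

  # Merging combines partial information; incompatible known values raise.
  print(unknown.merge_with(known))  # "32"

  # TensorShape compatibility and merging for partially known shapes.
  lhs = TensorShape([None, 256])
  rhs = TensorShape([32, 256])
  print(lhs.is_compatible_with(rhs))    # True
  print(lhs.merge_with(rhs).as_list())  # [32, 256]
  print(unknown_shape(rank=3))          # a rank-3 shape, all sizes unknown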