# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes for tensor shape inference."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


@tf_export("Dimension")
class Dimension(object):
  """Represents the value of one dimension in a TensorShape.

  A Dimension is either known (a non-negative integer) or unknown (`None`).
  Arithmetic on Dimensions propagates unknown-ness: if either operand is
  unknown, the result is an unknown Dimension. Equality and ordering
  comparisons involving an unknown Dimension return `None` rather than a
  boolean.
  """

  def __init__(self, value):
    """Creates a new Dimension with the given value.

    Args:
      value: `None` for an unknown dimension, or a value convertible to a
        non-negative integer with `int()` (an integer, an integral float, or
        a numeric string).

    Raises:
      ValueError: If `value` is negative, if a non-string `value` changes
        when converted to an integer (e.g. `1.5`), or if `value` is a string
        that `int()` cannot parse.
      TypeError: If `int()` cannot convert `value` at all.
    """
    if value is None:
      self._value = None
    else:
      self._value = int(value)
      # Reject lossy conversions such as int(1.5) == 1. Strings are exempt
      # because a string never compares equal to the int parsed from it.
      if (not isinstance(value, compat.bytes_or_text_types) and
          self._value != value):
        raise ValueError("Ambiguous dimension: %s" % value)
      if self._value < 0:
        raise ValueError("Dimension %d must be >= 0" % self._value)

  def __repr__(self):
    return "Dimension(%s)" % repr(self._value)

  def __str__(self):
    value = self._value
    return "?" if value is None else str(value)

  def __eq__(self, other):
    """Returns true if `other` has the same known value as this Dimension.

    Returns `None` if either Dimension is unknown, and `NotImplemented` if
    `other` cannot be converted to a Dimension.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value == other.value

  def __ne__(self, other):
    """Returns true if `other` has a different known value from `self`.

    Returns `None` if either Dimension is unknown, and `NotImplemented` if
    `other` cannot be converted to a Dimension.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value != other.value

  def __int__(self):
    return self._value

  # This is needed for Windows.
  # See https://github.com/tensorflow/tensorflow/pull/9780
  def __long__(self):
    return self._value

  def __index__(self):
    # Allow use in Python 3 range
    return self._value

  @property
  def value(self):
    """The value of this dimension, or None if it is unknown."""
    return self._value

  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this Dimension.

    Two known Dimensions are compatible if they have the same value.
    An unknown Dimension is compatible with all other Dimensions.

    Args:
      other: Another Dimension.

    Returns:
      True if this Dimension and `other` are compatible.
    """
    other = as_dimension(other)
    return (self._value is None or other.value is None or
            self._value == other.value)

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.

    Args:
      other: Another Dimension.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    if not self.is_compatible_with(other):
      raise ValueError("Dimensions %s and %s are not compatible" % (self,
                                                                    other))

  def merge_with(self, other):
    """Returns a Dimension that combines the information in `self` and `other`.

    Dimensions are combined as follows:

    ```python
    tf.Dimension(n)   .merge_with(tf.Dimension(n))    == tf.Dimension(n)
    tf.Dimension(n)   .merge_with(tf.Dimension(None)) == tf.Dimension(n)
    tf.Dimension(None).merge_with(tf.Dimension(n))    == tf.Dimension(n)
    tf.Dimension(None).merge_with(tf.Dimension(None)) == tf.Dimension(None)
    tf.Dimension(n)   .merge_with(tf.Dimension(m))  # raises ValueError for n != m
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    other = as_dimension(other)
    self.assert_is_compatible_with(other)
    if self._value is None:
      return Dimension(other.value)
    else:
      return Dimension(self._value)

  def __add__(self, other):
    """Returns the sum of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.Dimension(m)    + tf.Dimension(n)    == tf.Dimension(m + n)
    tf.Dimension(m)    + tf.Dimension(None) == tf.Dimension(None)
    tf.Dimension(None) + tf.Dimension(n)    == tf.Dimension(None)
    tf.Dimension(None) + tf.Dimension(None) == tf.Dimension(None)
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value + other.value)

  def __radd__(self, other):
    """Returns the sum of `other` and `self` (supports `int + Dimension`).

    Args:
      other: A value convertible to a Dimension.

    Returns:
      A Dimension whose value is the sum of `other` and `self`, or
      `NotImplemented` if `other` cannot be converted to a Dimension.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      # Let Python raise its usual TypeError for unsupported operand types.
      return NotImplemented
    return other + self

  def __sub__(self, other):
    """Returns the subtraction of `other` from `self`.

    Dimensions are subtracted as follows:

    ```python
    tf.Dimension(m)    - tf.Dimension(n)    == tf.Dimension(m - n)
    tf.Dimension(m)    - tf.Dimension(None) == tf.Dimension(None)
    tf.Dimension(None) - tf.Dimension(n)    == tf.Dimension(None)
    tf.Dimension(None) - tf.Dimension(None) == tf.Dimension(None)
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension whose value is the result of subtracting `other` from
      `self`.

    Raises:
      ValueError: If both values are known and the difference is negative.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value - other.value)

  def __rsub__(self, other):
    """Returns the subtraction of `self` from `other` (supports `int - Dimension`).

    Args:
      other: A value convertible to a Dimension.

    Returns:
      A Dimension whose value is the result of subtracting `self` from
      `other`, or `NotImplemented` if `other` cannot be converted to a
      Dimension.

    Raises:
      ValueError: If both values are known and the difference is negative.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    return other - self

  def __mul__(self, other):
    """Returns the product of `self` and `other`.

    Dimensions are multiplied as follows:

    ```python
    tf.Dimension(m)    * tf.Dimension(n)    == tf.Dimension(m * n)
    tf.Dimension(m)    * tf.Dimension(None) == tf.Dimension(None)
    tf.Dimension(None) * tf.Dimension(n)    == tf.Dimension(None)
    tf.Dimension(None) * tf.Dimension(None) == tf.Dimension(None)
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value * other.value)

  def __rmul__(self, other):
    """Returns the product of `other` and `self` (supports `int * Dimension`).

    Args:
      other: A value convertible to a Dimension.

    Returns:
      A Dimension whose value is the product of `other` and `self`, or
      `NotImplemented` if `other` cannot be converted to a Dimension.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    return other * self

  def __floordiv__(self, other):
    """Returns the quotient of `self` and `other` rounded down.

    Dimensions are divided as follows:

    ```python
    tf.Dimension(m)    // tf.Dimension(n)    == tf.Dimension(m // n)
    tf.Dimension(m)    // tf.Dimension(None) == tf.Dimension(None)
    tf.Dimension(None) // tf.Dimension(n)    == tf.Dimension(None)
    tf.Dimension(None) // tf.Dimension(None) == tf.Dimension(None)
    ```

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.

    Raises:
      ZeroDivisionError: If both values are known and `other` is zero.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value // other.value)

  def __rfloordiv__(self, other):
    """Returns the quotient of `other` and `self` rounded down.

    Supports `int // Dimension`.

    Args:
      other: A value convertible to a Dimension.

    Returns:
      A `Dimension` whose value is the integer quotient of `other` and
      `self`, or `NotImplemented` if `other` cannot be converted to a
      Dimension.

    Raises:
      ZeroDivisionError: If both values are known and `self` is zero.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    return other // self

  def __div__(self, other):
    """DEPRECATED: Use `__floordiv__` via `x // y` instead.

    This function exists only for backwards compatibility purposes; new code
    should use `__floordiv__` via the syntax `x // y`.  Using `x // y`
    communicates clearly that the result rounds down, and is forward compatible
    to Python 3.

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    return self // other

  def __rdiv__(self, other):
    """DEPRECATED: Use `__rfloordiv__` via `x // y` instead.

    Reflected counterpart of `__div__`, kept for the same backwards
    compatibility reasons.

    Args:
      other: A value convertible to a Dimension.

    Returns:
      A `Dimension` whose value is the integer quotient of `other` and
      `self`, or `NotImplemented` if `other` cannot be converted.
    """
    return self.__rfloordiv__(other)

  def __mod__(self, other):
    """Returns `self` modulo `other`.

    Dimension moduli are computed as follows:

    ```python
    tf.Dimension(m)    % tf.Dimension(n)    == tf.Dimension(m % n)
    tf.Dimension(m)    % tf.Dimension(None) == tf.Dimension(None)
    tf.Dimension(None) % tf.Dimension(n)    == tf.Dimension(None)
    tf.Dimension(None) % tf.Dimension(None) == tf.Dimension(None)
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension whose value is `self` modulo `other`.

    Raises:
      ZeroDivisionError: If both values are known and `other` is zero.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value % other.value)

  def __rmod__(self, other):
    """Returns `other` modulo `self` (supports `int % Dimension`).

    Args:
      other: A value convertible to a Dimension.

    Returns:
      A Dimension whose value is `other` modulo `self`, or `NotImplemented`
      if `other` cannot be converted to a Dimension.

    Raises:
      ZeroDivisionError: If both values are known and `self` is zero.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    return other % self

  def __lt__(self, other):
    """Returns True if `self` is known to be less than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    < tf.Dimension(n))    == (m < n)
    (tf.Dimension(m)    < tf.Dimension(None)) == None
    (tf.Dimension(None) < tf.Dimension(n))    == None
    (tf.Dimension(None) < tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value < other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value < other.value

  def __le__(self, other):
    """Returns True if `self` is known to be less than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    <= tf.Dimension(n))    == (m <= n)
    (tf.Dimension(m)    <= tf.Dimension(None)) == None
    (tf.Dimension(None) <= tf.Dimension(n))    == None
    (tf.Dimension(None) <= tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value <= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value <= other.value

  def __gt__(self, other):
    """Returns True if `self` is known to be greater than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    > tf.Dimension(n))    == (m > n)
    (tf.Dimension(m)    > tf.Dimension(None)) == None
    (tf.Dimension(None) > tf.Dimension(n))    == None
    (tf.Dimension(None) > tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value > other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value > other.value

  def __ge__(self, other):
    """Returns True if `self` is known to be greater than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    >= tf.Dimension(n))    == (m >= n)
    (tf.Dimension(m)    >= tf.Dimension(None)) == None
    (tf.Dimension(None) >= tf.Dimension(n))    == None
    (tf.Dimension(None) >= tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value >= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value >= other.value


def as_dimension(value):
  """Converts the given value to a Dimension.

  A Dimension input will be returned unmodified.
  An input of `None` will be converted to an unknown Dimension.
  An integer input will be converted to a Dimension with that value.

  Args:
    value: The value to be converted.

  Returns:
    A Dimension corresponding to the given value.
  """
  if isinstance(value, Dimension):
    return value
  else:
    return Dimension(value)


@tf_export("TensorShape")
class TensorShape(object):
  """Represents the shape of a `Tensor`.

  A `TensorShape` represents a possibly-partial shape specification for a
  `Tensor`. It may be one of the following:

  * *Fully-known shape:* has a known number of dimensions and a known size
    for each dimension. e.g. `TensorShape([16, 256])`
  * *Partially-known shape:* has a known number of dimensions, and an unknown
    size for one or more dimension. e.g. `TensorShape([None, 256])`
  * *Unknown shape:* has an unknown number of dimensions, and an unknown
    size in all dimensions. e.g. `TensorShape(None)`

  If a tensor is produced by an operation of type `"Foo"`, its shape
  may be inferred if there is a registered shape function for
  `"Foo"`. See @{$adding_an_op#shape-functions-in-c$`Shape functions in C++`}
  for details of shape functions and how to register them. Alternatively,
  the shape may be set explicitly using @{tf.Tensor.set_shape}.
  """

  def __init__(self, dims):
    """Creates a new TensorShape with the given dimensions.

    Args:
      dims: A list of Dimensions, or None if the shape is unspecified.
        A `TensorShapeProto` or another `TensorShape` is also accepted.
        DEPRECATED: A single integer is treated as a singleton list.

    Raises:
      TypeError: If dims cannot be converted to a list of dimensions.
    """
    # TODO(irving): Eliminate the single integer special case.
    if dims is None:
      self._dims = None
    elif isinstance(dims, compat.bytes_or_text_types):
      raise TypeError("A string has ambiguous TensorShape, please wrap in a "
                      "list or convert to an int: %s" % dims)
    elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
      if dims.unknown_rank:
        self._dims = None
      else:
        self._dims = [
            # Protos store variable-size dimensions as -1
            as_dimension(dim.size if dim.size != -1 else None)
            for dim in dims.dim
        ]
    elif isinstance(dims, TensorShape):
      # NOTE: shares the other shape's dimension list rather than copying it.
      self._dims = dims.dims
    else:
      try:
        dims_iter = iter(dims)
      except TypeError:
        # Treat as a singleton dimension
        self._dims = [as_dimension(dims)]
      else:
        # Got a list of dimensions
        self._dims = [as_dimension(d) for d in dims_iter]

  def __repr__(self):
    return "TensorShape(%r)" % self._dims

  def __str__(self):
    if self.ndims is None:
      return "<unknown>"
    elif self.ndims == 1:
      return "(%s,)" % self._dims[0]
    else:
      return "(%s)" % ", ".join(str(d) for d in self._dims)

  @property
  def dims(self):
    """Returns a list of Dimensions, or None if the shape is unspecified."""
    return self._dims

  @property
  def ndims(self):
    """Returns the rank of this shape, or None if it is unspecified."""
    if self._dims is None:
      return None
    else:
      return len(self._dims)

  def __len__(self):
    """Returns the rank of this shape, or raises ValueError if unspecified."""
    if self._dims is None:
      raise ValueError("Cannot take the length of Shape with unknown rank.")
    return len(self._dims)

  def __bool__(self):
    """Returns True if this shape contains non-zero information."""
    return self._dims is not None

  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
  __nonzero__ = __bool__

  def __iter__(self):
    """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
    if self._dims is None:
      raise ValueError("Cannot iterate over a shape with unknown rank.")
    else:
      return iter(self._dims)

  def __getitem__(self, key):
    """Returns the value of a dimension or a shape, depending on the key.

    If `self` has unknown rank, an integer key yields an unknown Dimension,
    and a slice yields an unknown shape (with known rank `stop - start` only
    when both bounds are non-negative and `stop` is given).

    Args:
      key: If `key` is an integer, returns the dimension at that index;
        otherwise if `key` is a slice, returns a TensorShape whose
        dimensions are those selected by the slice from `self`.

    Returns:
      A dimension if `key` is an integer, or a `TensorShape` if `key` is a
      slice.

    Raises:
      ValueError: If `self` is completely unknown and `key` is a slice with
        its step set.
      IndexError: If `self` has known rank and `key` is an out-of-range
        integer.
    """
    if self._dims is not None:
      if isinstance(key, slice):
        return TensorShape(self._dims[key])
      else:
        return self._dims[key]
    else:
      if isinstance(key, slice):
        start = key.start if key.start is not None else 0
        stop = key.stop

        if key.step is not None:
          # TODO(mrry): Handle these maybe.
          raise ValueError("Steps are not yet handled")
        if stop is None:
          # NOTE(mrry): This implies that TensorShape(None) is compatible with
          # TensorShape(None)[1:], which is obviously not true. It would be
          # possible to track the number of dimensions symbolically,
          # and perhaps we should do that.
          return unknown_shape()
        elif start < 0 or stop < 0:
          # TODO(mrry): Handle this better, as it will be useful for handling
          # suffixes of otherwise unknown shapes.
          return unknown_shape()
        else:
          return unknown_shape(ndims=stop - start)
      else:
        return Dimension(None)

  def num_elements(self):
    """Returns the total number of elements, or `None` for incomplete shapes."""
    if self.is_fully_defined():
      size = 1
      for dim in self._dims:
        size *= dim.value
      return size
    else:
      return None

  def merge_with(self, other):
    """Returns a `TensorShape` combining the information in `self` and `other`.

    The dimensions in `self` and `other` are merged elementwise,
    according to the rules defined for `Dimension.merge_with()`.

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible.
    """
    other = as_shape(other)
    if self._dims is None:
      return other
    else:
      try:
        self.assert_same_rank(other)
        # Index through `other` (rather than zipping `other.dims`) so that an
        # unknown-rank `other` yields Dimension(None) for every position.
        new_dims = [dim.merge_with(other[i]) for i, dim in enumerate(self._dims)]
        return TensorShape(new_dims)
      except ValueError:
        raise ValueError("Shapes %s and %s are not compatible" % (self, other))

  def concatenate(self, other):
    """Returns the concatenation of the dimensions in `self` and `other`.

    *N.B.* If either `self` or `other` is completely unknown,
    concatenation will discard information about the other shape. In
    future, we might support concatenation that preserves this
    information for use with slicing.

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` whose dimensions are the concatenation of the
      dimensions in `self` and `other`.
    """
    # TODO(mrry): Handle the case where we concatenate a known shape with a
    # completely unknown shape, so that we can use the partial information.
    other = as_shape(other)
    if self._dims is None or other.dims is None:
      return unknown_shape()
    else:
      return TensorShape(self._dims + other.dims)

  def assert_same_rank(self, other):
    """Raises an exception if `self` and `other` do not have compatible ranks.

    Args:
      other: Another `TensorShape`.

    Raises:
      ValueError: If `self` and `other` do not represent shapes with the
        same rank.
    """
    other = as_shape(other)
    if self.ndims is not None and other.ndims is not None:
      if self.ndims != other.ndims:
        raise ValueError("Shapes %s and %s must have the same rank" % (self,
                                                                       other))

  def assert_has_rank(self, rank):
    """Raises an exception if `self` is not compatible with the given `rank`.

    Args:
      rank: An integer.

    Raises:
      ValueError: If `self` does not represent a shape with the given `rank`.
    """
    if self.ndims not in (None, rank):
      raise ValueError("Shape %s must have rank %d" % (self, rank))

  def with_rank(self, rank):
    """Returns a shape based on `self` with the given rank.

    This method promotes a completely unknown shape to one with a
    known rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with the given rank.

    Raises:
      ValueError: If `self` does not represent a shape with the given `rank`.
    """
    try:
      return self.merge_with(unknown_shape(ndims=rank))
    except ValueError:
      raise ValueError("Shape %s must have rank %d" % (self, rank))

  def with_rank_at_least(self, rank):
    """Returns a shape based on `self` with at least the given rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with at least the given
      rank.

    Raises:
      ValueError: If `self` does not represent a shape with at least the given
        `rank`.
    """
    if self.ndims is not None and self.ndims < rank:
      raise ValueError("Shape %s must have rank at least %d" % (self, rank))
    else:
      return self

  def with_rank_at_most(self, rank):
    """Returns a shape based on `self` with at most the given rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with at most the given
      rank.

    Raises:
      ValueError: If `self` does not represent a shape with at most the given
        `rank`.
    """
    if self.ndims is not None and self.ndims > rank:
      raise ValueError("Shape %s must have rank at most %d" % (self, rank))
    else:
      return self

  def is_compatible_with(self, other):
    """Returns True iff `self` is compatible with `other`.

    Two possibly-partially-defined shapes are compatible if there
    exists a fully-defined shape that both shapes can represent. Thus,
    compatibility allows the shape inference code to reason about
    partially-defined shapes. For example:

    * TensorShape(None) is compatible with all shapes.

    * TensorShape([None, None]) is compatible with all two-dimensional
      shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
      not compatible with, for example, TensorShape([None]) or
      TensorShape([None, None, None]).

    * TensorShape([32, None]) is compatible with all two-dimensional shapes
      with size 32 in the 0th dimension, and also TensorShape([None, None])
      and TensorShape(None). It is not compatible with, for example,
      TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).

    * TensorShape([32, 784]) is compatible with itself, and also
      TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
      None]) and TensorShape(None). It is not compatible with, for example,
      TensorShape([32, 1, 784]) or TensorShape([None]).

    The compatibility relation is reflexive and symmetric, but not
    transitive. For example, TensorShape([32, 784]) is compatible with
    TensorShape(None), and TensorShape(None) is compatible with
    TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
    TensorShape([4, 4]).

    Args:
      other: Another TensorShape.

    Returns:
      True iff `self` is compatible with `other`.

    """
    other = as_shape(other)
    if self._dims is not None and other.dims is not None:
      if self.ndims != other.ndims:
        return False
      for x_dim, y_dim in zip(self._dims, other.dims):
        if not x_dim.is_compatible_with(y_dim):
          return False
    return True

  def assert_is_compatible_with(self, other):
    """Raises exception if `self` and `other` do not represent the same shape.

    This method can be used to assert that there exists a shape that both
    `self` and `other` represent.

    Args:
      other: Another TensorShape.

    Raises:
      ValueError: If `self` and `other` do not represent the same shape.
    """
    if not self.is_compatible_with(other):
      raise ValueError("Shapes %s and %s are incompatible" % (self, other))

  def most_specific_compatible_shape(self, other):
    """Returns the most specific TensorShape compatible with `self` and `other`.

    * TensorShape([None, 1]) is the most specific TensorShape compatible with
      both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
      TensorShape(None) is also compatible with above mentioned TensorShapes.

    * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
      both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are less
      specific TensorShapes compatible with above mentioned TensorShapes,
      e.g. TensorShape([1, 2, None]), TensorShape(None).

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` which is the most specific compatible shape of `self`
      and `other`.
    """

    other = as_shape(other)
    if self._dims is None or other.dims is None or self.ndims != other.ndims:
      return unknown_shape()

    # `d1 == d2` is None (falsy) when either dimension is unknown, so only
    # dimensions that are known and equal in both shapes are kept.
    dims = [
        d1 if d1 == d2 else Dimension(None)
        for d1, d2 in zip(self._dims, other.dims)
    ]
    return TensorShape(dims)

  def is_fully_defined(self):
    """Returns True iff `self` is fully defined in every dimension."""
    return (self._dims is not None and all(dim.value is not None
                                           for dim in self._dims))

  def assert_is_fully_defined(self):
    """Raises an exception if `self` is not fully defined in every dimension.

    Raises:
      ValueError: If `self` does not have a known value for every dimension.
    """
    if not self.is_fully_defined():
      raise ValueError("Shape %s is not fully defined" % self)

  def as_list(self):
    """Returns a list of integers or `None` for each dimension.

    Returns:
      A list of integers or `None` for each dimension.

    Raises:
      ValueError: If `self` is an unknown shape with an unknown rank.
    """
    if self._dims is None:
      raise ValueError("as_list() is not defined on an unknown TensorShape.")
    return [dim.value for dim in self._dims]

  def as_proto(self):
    """Returns this shape as a `TensorShapeProto`.

    Unknown dimensions are encoded with size -1, and an unknown rank is
    encoded with `unknown_rank=True`.
    """
    if self._dims is None:
      return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
    else:
      return tensor_shape_pb2.TensorShapeProto(dim=[
          tensor_shape_pb2.TensorShapeProto.Dim(size=-1
                                                if d.value is None else d.value)
          for d in self._dims
      ])

  def __eq__(self, other):
    """Returns True if `self` is equivalent to `other`."""
    try:
      other = as_shape(other)
    except TypeError:
      return NotImplemented
    return self._dims == other.dims

  def __ne__(self, other):
    """Returns True if `self` is known to be different from `other`.

    Raises:
      ValueError: If either shape has unknown rank.
    """
    try:
      other = as_shape(other)
    except TypeError:
      return NotImplemented
    if self.ndims is None or other.ndims is None:
      raise ValueError("The inequality of unknown TensorShapes is undefined.")
    if self.ndims != other.ndims:
      return True
    return self._dims != other.dims


def as_shape(shape):
  """Converts the given object to a TensorShape."""
  if isinstance(shape, TensorShape):
    return shape
  else:
    return TensorShape(shape)


def unknown_shape(ndims=None):
  """Returns an unknown TensorShape, optionally with a known rank.

  Args:
    ndims: (Optional) If specified, the number of dimensions in the shape.

  Returns:
    An unknown TensorShape.
  """
  if ndims is None:
    return TensorShape(None)
  else:
    return TensorShape([Dimension(None)] * ndims)


def scalar():
  """Returns a shape representing a scalar."""
  return TensorShape([])


def vector(length):
  """Returns a shape representing a vector.

  Args:
    length: The length of the vector, which may be None if unknown.

  Returns:
    A TensorShape representing a vector of the given length.
  """
  return TensorShape([length])


def matrix(rows, cols):
  """Returns a shape representing a matrix.

  Args:
    rows: The number of rows in the matrix, which may be None if unknown.
    cols: The number of columns in the matrix, which may be None if unknown.

  Returns:
    A TensorShape representing a matrix of the given size.
  """
  return TensorShape([rows, cols])