1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helper classes for tensor shape inference.""" 16import functools 17import operator 18from typing import Optional, Sequence, Type 19 20from tensorflow.core.framework import tensor_shape_pb2 21from tensorflow.core.function import trace_type 22from tensorflow.python import tf2 23from tensorflow.python.eager import monitoring 24from tensorflow.python.platform import tf_logging as logging 25from tensorflow.python.types import trace 26from tensorflow.python.util.tf_export import tf_export 27 28_TENSORSHAPE_V2_OVERRIDE = None 29 30_api_usage_gauge = monitoring.BoolGauge( 31 "/tensorflow/api/v2_tensorshape", 32 "Whether tensor_shape.enable_v2_tensorshape() is called.") 33 34 35@tf_export(v1=["enable_v2_tensorshape"]) 36def enable_v2_tensorshape(): 37 """In TensorFlow 2.0, iterating over a TensorShape instance returns values. 38 39 This enables the new behavior. 40 41 Concretely, `tensor_shape[i]` returned a Dimension instance in V1, but 42 it V2 it returns either an integer, or None. 43 44 Examples: 45 46 ``` 47 ####################### 48 # If you had this in V1: 49 value = tensor_shape[i].value 50 51 # Do this in V2 instead: 52 value = tensor_shape[i] 53 54 ####################### 55 # If you had this in V1: 56 for dim in tensor_shape: 57 value = dim.value 58 print(value) 59 60 # Do this in V2 instead: 61 for value in tensor_shape: 62 print(value) 63 64 ####################### 65 # If you had this in V1: 66 dim = tensor_shape[i] 67 dim.assert_is_compatible_with(other_shape) # or using any other shape method 68 69 # Do this in V2 instead: 70 if tensor_shape.rank is None: 71 dim = Dimension(None) 72 else: 73 dim = tensor_shape.dims[i] 74 dim.assert_is_compatible_with(other_shape) # or using any other shape method 75 76 # The V2 suggestion above is more explicit, which will save you from 77 # the following trap (present in V1): 78 # you might do in-place modifications to `dim` and expect them to be reflected 79 # in `tensor_shape[i]`, but they would not be. 80 ``` 81 """ 82 global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name 83 _TENSORSHAPE_V2_OVERRIDE = True 84 logging.vlog(1, "Enabling v2 tensorshape") 85 _api_usage_gauge.get_cell().set(True) 86 87 88@tf_export(v1=["disable_v2_tensorshape"]) 89def disable_v2_tensorshape(): 90 """Disables the V2 TensorShape behavior and reverts to V1 behavior. 91 92 See docstring for `enable_v2_tensorshape` for details about the new behavior. 93 """ 94 global _TENSORSHAPE_V2_OVERRIDE # pylint: disable=invalid-name 95 _TENSORSHAPE_V2_OVERRIDE = False 96 logging.vlog(1, "Disabling v2 tensorshape") 97 _api_usage_gauge.get_cell().set(False) 98 99 100@tf_export( 101 "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"]) 102def dimension_value(dimension): 103 """Compatibility utility required to allow for both V1 and V2 behavior in TF. 104 105 Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to 106 coexist with the new behavior. This utility is a bridge between the two. 107 108 When accessing the value of a TensorShape dimension, 109 use this utility, like this: 110 111 ``` 112 # If you had this in your V1 code: 113 value = tensor_shape[i].value 114 115 # Use `dimension_value` as direct replacement compatible with both V1 & V2: 116 value = dimension_value(tensor_shape[i]) 117 118 # This would be the V2 equivalent: 119 value = tensor_shape[i] # Warning: this will return the dim value in V2! 120 ``` 121 122 Args: 123 dimension: Either a `Dimension` instance, an integer, or None. 124 125 Returns: 126 A plain value, i.e. an integer or None. 127 """ 128 if isinstance(dimension, Dimension): 129 return dimension.value 130 return dimension 131 132 133@tf_export( 134 "compat.dimension_at_index", 135 v1=["dimension_at_index", "compat.dimension_at_index"]) 136def dimension_at_index(shape, index): 137 """Compatibility utility required to allow for both V1 and V2 behavior in TF. 138 139 Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to 140 coexist with the new behavior. This utility is a bridge between the two. 141 142 If you want to retrieve the Dimension instance corresponding to a certain 143 index in a TensorShape instance, use this utility, like this: 144 145 ``` 146 # If you had this in your V1 code: 147 dim = tensor_shape[i] 148 149 # Use `dimension_at_index` as direct replacement compatible with both V1 & V2: 150 dim = dimension_at_index(tensor_shape, i) 151 152 # Another possibility would be this, but WARNING: it only works if the 153 # tensor_shape instance has a defined rank. 154 dim = tensor_shape.dims[i] # `dims` may be None if the rank is undefined! 155 156 # In native V2 code, we recommend instead being more explicit: 157 if tensor_shape.rank is None: 158 dim = Dimension(None) 159 else: 160 dim = tensor_shape.dims[i] 161 162 # Being more explicit will save you from the following trap (present in V1): 163 # you might do in-place modifications to `dim` and expect them to be reflected 164 # in `tensor_shape[i]`, but they would not be (as the Dimension object was 165 # instantiated on the fly. 166 ``` 167 168 Args: 169 shape: A TensorShape instance. 170 index: An integer index. 171 172 Returns: 173 A dimension object. 174 """ 175 assert isinstance(shape, TensorShape) 176 if shape.rank is None: 177 return Dimension(None) 178 else: 179 return shape.dims[index] 180 181 182@tf_export(v1=["Dimension"]) 183class Dimension(object): 184 """Represents the value of one dimension in a TensorShape. 185 186 @compatibility(TF2) 187 In TF2, members of a `TensorShape` object are integers. The `Dimension` class 188 is not part of TF2's data model. 189 190 Please refer to the [TensorShape section of the migration guide] 191 (https://www.tensorflow.org/guide/migrate/index#tensorshape) on common code 192 patterns adapting Dimension objects to a TF2 syntax. 193 @end_compatibility 194 """ 195 196 __slots__ = ["_value"] 197 198 def __init__(self, value): 199 """Creates a new Dimension with the given value.""" 200 if isinstance(value, int): # Most common case. 201 if value < 0: 202 raise ValueError("Dimension %d must be >= 0" % value) 203 self._value = value 204 elif value is None: 205 self._value = None 206 elif isinstance(value, Dimension): 207 self._value = value._value 208 else: 209 try: 210 # int(...) compensates for the int/long dichotomy on Python 2.X. 211 # TODO(b/143206389): Remove once we fully migrate to 3.X. 212 self._value = int(value.__index__()) 213 except AttributeError: 214 raise TypeError( 215 "Dimension value must be integer or None or have " 216 "an __index__ method, got value '{0!r}' with type '{1!r}'".format( 217 value, type(value))) from None 218 if self._value < 0: 219 raise ValueError("Dimension %d must be >= 0" % self._value) 220 221 def __repr__(self): 222 return "Dimension(%s)" % repr(self._value) 223 224 def __str__(self): 225 value = self._value 226 return "?" if value is None else str(value) 227 228 def __eq__(self, other): 229 """Returns true if `other` has the same known value as this Dimension.""" 230 try: 231 other = as_dimension(other) 232 except (TypeError, ValueError): 233 return NotImplemented 234 if self._value is None or other.value is None: 235 return None 236 return self._value == other.value 237 238 def __ne__(self, other): 239 """Returns true if `other` has a different known value from `self`.""" 240 try: 241 other = as_dimension(other) 242 except (TypeError, ValueError): 243 return NotImplemented 244 if self._value is None or other.value is None: 245 return None 246 return self._value != other.value 247 248 def __bool__(self): 249 """Equivalent to `bool(self.value)`.""" 250 return bool(self._value) 251 252 def __int__(self): 253 return self._value 254 255 # This is needed for Windows. 256 # See https://github.com/tensorflow/tensorflow/pull/9780 257 def __long__(self): 258 return self._value 259 260 def __index__(self): 261 # Allow use in Python 3 range 262 return self._value 263 264 @property 265 def value(self): 266 """The value of this dimension, or None if it is unknown.""" 267 return self._value 268 269 # TODO(b/225058047): Reconsider semantics. 270 def is_compatible_with(self, other): 271 """Returns true if `other` is compatible with this Dimension. 272 273 Two known Dimensions are compatible if they have the same value. 274 An unknown Dimension is compatible with all other Dimensions. 275 276 Args: 277 other: Another Dimension. 278 279 Returns: 280 True if this Dimension and `other` are compatible. 281 """ 282 other = as_dimension(other) 283 return (self._value is None or other.value is None or 284 self._value == other.value) 285 286 def assert_is_compatible_with(self, other): 287 """Raises an exception if `other` is not compatible with this Dimension. 288 289 Args: 290 other: Another Dimension. 291 292 Raises: 293 ValueError: If `self` and `other` are not compatible (see 294 is_compatible_with). 295 """ 296 if not self.is_compatible_with(other): 297 raise ValueError("Dimensions %s and %s are not compatible" % 298 (self, other)) 299 300 def merge_with(self, other): 301 """Returns a Dimension that combines the information in `self` and `other`. 302 303 Dimensions are combined as follows: 304 305 ```python 306 tf.compat.v1.Dimension(n) .merge_with(tf.compat.v1.Dimension(n)) == 307 tf.compat.v1.Dimension(n) 308 tf.compat.v1.Dimension(n) .merge_with(tf.compat.v1.Dimension(None)) == 309 tf.compat.v1.Dimension(n) 310 tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(n)) == 311 tf.compat.v1.Dimension(n) 312 # equivalent to tf.compat.v1.Dimension(None) 313 tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(None)) 314 315 # raises ValueError for n != m 316 tf.compat.v1.Dimension(n) .merge_with(tf.compat.v1.Dimension(m)) 317 ``` 318 319 Args: 320 other: Another Dimension. 321 322 Returns: 323 A Dimension containing the combined information of `self` and 324 `other`. 325 326 Raises: 327 ValueError: If `self` and `other` are not compatible (see 328 is_compatible_with). 329 """ 330 other = as_dimension(other) 331 self.assert_is_compatible_with(other) 332 if self._value is None: 333 return Dimension(other.value) 334 else: 335 return Dimension(self._value) 336 337 def __add__(self, other): 338 """Returns the sum of `self` and `other`. 339 340 Dimensions are summed as follows: 341 342 ```python 343 tf.compat.v1.Dimension(m) + tf.compat.v1.Dimension(n) == 344 tf.compat.v1.Dimension(m + n) 345 tf.compat.v1.Dimension(m) + tf.compat.v1.Dimension(None) # equiv. to 346 tf.compat.v1.Dimension(None) 347 tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(n) # equiv. to 348 tf.compat.v1.Dimension(None) 349 tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(None) # equiv. to 350 tf.compat.v1.Dimension(None) 351 ``` 352 353 Args: 354 other: Another Dimension, or a value accepted by `as_dimension`. 355 356 Returns: 357 A Dimension whose value is the sum of `self` and `other`. 358 """ 359 try: 360 other = as_dimension(other) 361 except (TypeError, ValueError): 362 return NotImplemented 363 if self._value is None or other.value is None: 364 return Dimension(None) 365 else: 366 return Dimension(self._value + other.value) 367 368 def __radd__(self, other): 369 """Returns the sum of `other` and `self`. 370 371 Args: 372 other: Another Dimension, or a value accepted by `as_dimension`. 373 374 Returns: 375 A Dimension whose value is the sum of `self` and `other`. 376 """ 377 return self + other 378 379 def __sub__(self, other): 380 """Returns the subtraction of `other` from `self`. 381 382 Dimensions are subtracted as follows: 383 384 ```python 385 tf.compat.v1.Dimension(m) - tf.compat.v1.Dimension(n) == 386 tf.compat.v1.Dimension(m - n) 387 tf.compat.v1.Dimension(m) - tf.compat.v1.Dimension(None) # equiv. to 388 tf.compat.v1.Dimension(None) 389 tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(n) # equiv. to 390 tf.compat.v1.Dimension(None) 391 tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(None) # equiv. to 392 tf.compat.v1.Dimension(None) 393 ``` 394 395 Args: 396 other: Another Dimension, or a value accepted by `as_dimension`. 397 398 Returns: 399 A Dimension whose value is the subtraction of `other` from `self`. 400 """ 401 try: 402 other = as_dimension(other) 403 except (TypeError, ValueError): 404 return NotImplemented 405 if self._value is None or other.value is None: 406 return Dimension(None) 407 else: 408 return Dimension(self._value - other.value) 409 410 def __rsub__(self, other): 411 """Returns the subtraction of `self` from `other`. 412 413 Args: 414 other: Another Dimension, or a value accepted by `as_dimension`. 415 416 Returns: 417 A Dimension whose value is the subtraction of `self` from `other`. 418 """ 419 other = as_dimension(other) 420 if self._value is None or other.value is None: 421 return Dimension(None) 422 else: 423 return Dimension(other.value - self._value) 424 425 def __mul__(self, other): 426 """Returns the product of `self` and `other`. 427 428 Dimensions are summed as follows: 429 430 ```python 431 tf.compat.v1.Dimension(m) * tf.compat.v1.Dimension(n) == 432 tf.compat.v1.Dimension(m * n) 433 tf.compat.v1.Dimension(m) * tf.compat.v1.Dimension(None) # equiv. to 434 tf.compat.v1.Dimension(None) 435 tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(n) # equiv. to 436 tf.compat.v1.Dimension(None) 437 tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(None) # equiv. to 438 tf.compat.v1.Dimension(None) 439 ``` 440 441 Args: 442 other: Another Dimension, or a value accepted by `as_dimension`. 443 444 Returns: 445 A Dimension whose value is the product of `self` and `other`. 446 """ 447 try: 448 other = as_dimension(other) 449 except (TypeError, ValueError): 450 return NotImplemented 451 452 if self._value is None or other.value is None: 453 return Dimension(None) 454 else: 455 return Dimension(self._value * other.value) 456 457 def __rmul__(self, other): 458 """Returns the product of `self` and `other`. 459 460 Args: 461 other: Another Dimension, or a value accepted by `as_dimension`. 462 463 Returns: 464 A Dimension whose value is the product of `self` and `other`. 465 """ 466 return self * other 467 468 def __floordiv__(self, other): 469 """Returns the quotient of `self` and `other` rounded down. 470 471 Dimensions are divided as follows: 472 473 ```python 474 tf.compat.v1.Dimension(m) // tf.compat.v1.Dimension(n) == 475 tf.compat.v1.Dimension(m // n) 476 tf.compat.v1.Dimension(m) // tf.compat.v1.Dimension(None) # equiv. to 477 tf.compat.v1.Dimension(None) 478 tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(n) # equiv. to 479 tf.compat.v1.Dimension(None) 480 tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(None) # equiv. to 481 tf.compat.v1.Dimension(None) 482 ``` 483 484 Args: 485 other: Another Dimension, or a value accepted by `as_dimension`. 486 487 Returns: 488 A `Dimension` whose value is the integer quotient of `self` and `other`. 489 """ 490 try: 491 other = as_dimension(other) 492 except (TypeError, ValueError): 493 return NotImplemented 494 if self._value is None or other.value is None: 495 return Dimension(None) 496 else: 497 return Dimension(self._value // other.value) 498 499 def __rfloordiv__(self, other): 500 """Returns the quotient of `other` and `self` rounded down. 501 502 Args: 503 other: Another Dimension, or a value accepted by `as_dimension`. 504 505 Returns: 506 A `Dimension` whose value is the integer quotient of `self` and `other`. 507 """ 508 other = as_dimension(other) 509 if self._value is None or other.value is None: 510 return Dimension(None) 511 else: 512 return Dimension(other.value // self._value) 513 514 def __div__(self, other): 515 """DEPRECATED: Use `__floordiv__` via `x // y` instead. 516 517 This function exists only for backwards compatibility purposes; new code 518 should use `__floordiv__` via the syntax `x // y`. Using `x // y` 519 communicates clearly that the result rounds down, and is forward compatible 520 to Python 3. 521 522 Args: 523 other: Another `Dimension`. 524 525 Returns: 526 A `Dimension` whose value is the integer quotient of `self` and `other`. 527 """ 528 return self // other 529 530 def __rdiv__(self, other): 531 """Use `__floordiv__` via `x // y` instead. 532 533 This function exists only to have a better error message. Instead of: 534 `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`, 535 this function will explicitly call for usage of `//` instead. 536 537 Args: 538 other: Another `Dimension`. 539 540 Raises: 541 TypeError. 542 """ 543 raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', " 544 "please use // instead".format(type(other).__name__)) 545 546 def __truediv__(self, other): 547 """Use `__floordiv__` via `x // y` instead. 548 549 This function exists only to have a better error message. Instead of: 550 `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`, 551 this function will explicitly call for usage of `//` instead. 552 553 Args: 554 other: Another `Dimension`. 555 556 Raises: 557 TypeError. 558 """ 559 raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', " 560 "please use // instead".format(type(other).__name__)) 561 562 def __rtruediv__(self, other): 563 """Use `__floordiv__` via `x // y` instead. 564 565 This function exists only to have a better error message. Instead of: 566 `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`, 567 this function will explicitly call for usage of `//` instead. 568 569 Args: 570 other: Another `Dimension`. 571 572 Raises: 573 TypeError. 574 """ 575 raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', " 576 "please use // instead".format(type(other).__name__)) 577 578 def __mod__(self, other): 579 """Returns `self` modulo `other`. 580 581 Dimension modulo are computed as follows: 582 583 ```python 584 tf.compat.v1.Dimension(m) % tf.compat.v1.Dimension(n) == 585 tf.compat.v1.Dimension(m % n) 586 tf.compat.v1.Dimension(m) % tf.compat.v1.Dimension(None) # equiv. to 587 tf.compat.v1.Dimension(None) 588 tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(n) # equiv. to 589 tf.compat.v1.Dimension(None) 590 tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(None) # equiv. to 591 tf.compat.v1.Dimension(None) 592 ``` 593 594 Args: 595 other: Another Dimension, or a value accepted by `as_dimension`. 596 597 Returns: 598 A Dimension whose value is `self` modulo `other`. 599 """ 600 other = as_dimension(other) 601 if self._value is None or other.value is None: 602 return Dimension(None) 603 else: 604 return Dimension(self._value % other.value) 605 606 def __rmod__(self, other): 607 """Returns `other` modulo `self`. 608 609 Args: 610 other: Another Dimension, or a value accepted by `as_dimension`. 611 612 Returns: 613 A Dimension whose value is `other` modulo `self`. 614 """ 615 other = as_dimension(other) 616 return other % self 617 618 def __lt__(self, other): 619 """Returns True if `self` is known to be less than `other`. 620 621 Dimensions are compared as follows: 622 623 ```python 624 (tf.compat.v1.Dimension(m) < tf.compat.v1.Dimension(n)) == (m < n) 625 (tf.compat.v1.Dimension(m) < tf.compat.v1.Dimension(None)) == None 626 (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(n)) == None 627 (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(None)) == None 628 ``` 629 630 Args: 631 other: Another Dimension. 632 633 Returns: 634 The value of `self.value < other.value` if both are known, otherwise 635 None. 636 """ 637 other = as_dimension(other) 638 if self._value is None or other.value is None: 639 return None 640 else: 641 return self._value < other.value 642 643 def __le__(self, other): 644 """Returns True if `self` is known to be less than or equal to `other`. 645 646 Dimensions are compared as follows: 647 648 ```python 649 (tf.compat.v1.Dimension(m) <= tf.compat.v1.Dimension(n)) == (m <= n) 650 (tf.compat.v1.Dimension(m) <= tf.compat.v1.Dimension(None)) == None 651 (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(n)) == None 652 (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(None)) == None 653 ``` 654 655 Args: 656 other: Another Dimension. 657 658 Returns: 659 The value of `self.value <= other.value` if both are known, otherwise 660 None. 661 """ 662 other = as_dimension(other) 663 if self._value is None or other.value is None: 664 return None 665 else: 666 return self._value <= other.value 667 668 def __gt__(self, other): 669 """Returns True if `self` is known to be greater than `other`. 670 671 Dimensions are compared as follows: 672 673 ```python 674 (tf.compat.v1.Dimension(m) > tf.compat.v1.Dimension(n)) == (m > n) 675 (tf.compat.v1.Dimension(m) > tf.compat.v1.Dimension(None)) == None 676 (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(n)) == None 677 (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(None)) == None 678 ``` 679 680 Args: 681 other: Another Dimension. 682 683 Returns: 684 The value of `self.value > other.value` if both are known, otherwise 685 None. 686 """ 687 other = as_dimension(other) 688 if self._value is None or other.value is None: 689 return None 690 else: 691 return self._value > other.value 692 693 def __ge__(self, other): 694 """Returns True if `self` is known to be greater than or equal to `other`. 695 696 Dimensions are compared as follows: 697 698 ```python 699 (tf.compat.v1.Dimension(m) >= tf.compat.v1.Dimension(n)) == (m >= n) 700 (tf.compat.v1.Dimension(m) >= tf.compat.v1.Dimension(None)) == None 701 (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(n)) == None 702 (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(None)) == None 703 ``` 704 705 Args: 706 other: Another Dimension. 707 708 Returns: 709 The value of `self.value >= other.value` if both are known, otherwise 710 None. 711 """ 712 other = as_dimension(other) 713 if self._value is None or other.value is None: 714 return None 715 else: 716 return self._value >= other.value 717 718 def __reduce__(self): 719 return Dimension, (self._value,) 720 721 722def as_dimension(value): 723 """Converts the given value to a Dimension. 724 725 A Dimension input will be returned unmodified. 726 An input of `None` will be converted to an unknown Dimension. 727 An integer input will be converted to a Dimension with that value. 728 729 Args: 730 value: The value to be converted. 731 732 Returns: 733 A Dimension corresponding to the given value. 734 """ 735 if isinstance(value, Dimension): 736 return value 737 else: 738 return Dimension(value) 739 740 741@tf_export("TensorShape") 742class TensorShape(trace.TraceType, trace_type.Serializable): 743 """Represents the shape of a `Tensor`. 744 745 A `TensorShape` represents a possibly-partial shape specification for a 746 `Tensor`. It may be one of the following: 747 748 * *Fully-known shape:* has a known number of dimensions and a known size 749 for each dimension. e.g. `TensorShape([16, 256])` 750 * *Partially-known shape:* has a known number of dimensions, and an unknown 751 size for one or more dimension. e.g. `TensorShape([None, 256])` 752 * *Unknown shape:* has an unknown number of dimensions, and an unknown 753 size in all dimensions. e.g. `TensorShape(None)` 754 755 If a tensor is produced by an operation of type `"Foo"`, its shape 756 may be inferred if there is a registered shape function for 757 `"Foo"`. See [Shape 758 functions](https://www.tensorflow.org/guide/create_op#shape_functions_in_c) 759 for details of shape functions and how to register them. Alternatively, 760 you may set the shape explicitly using `tf.Tensor.set_shape`. 761 """ 762 __slots__ = ["_dims"] 763 764 def __init__(self, dims): 765 """Creates a new TensorShape with the given dimensions. 766 767 Args: 768 dims: A list of Dimensions, or None if the shape is unspecified. 769 770 Raises: 771 TypeError: If dims cannot be converted to a list of dimensions. 772 """ 773 if isinstance(dims, (tuple, list)): # Most common case. 774 self._dims = tuple(as_dimension(d).value for d in dims) 775 elif dims is None: 776 self._dims = None 777 elif isinstance(dims, tensor_shape_pb2.TensorShapeProto): 778 if dims.unknown_rank: 779 self._dims = None 780 else: 781 self._dims = tuple( 782 # Protos store variable-size dimensions as -1 783 dim.size if dim.size != -1 else None 784 for dim in dims.dim 785 ) 786 elif isinstance(dims, TensorShape): 787 self._dims = dims._dims 788 else: 789 try: 790 dims_iter = iter(dims) 791 except TypeError: 792 # Treat as a singleton dimension 793 self._dims = (as_dimension(dims).value,) 794 else: 795 self._dims = [] 796 for d in dims_iter: 797 try: 798 self._dims.append(as_dimension(d).value) 799 except TypeError as e: 800 raise TypeError( 801 "Failed to convert '{0!r}' to a shape: '{1!r}'" 802 "could not be converted to a dimension. A shape should " 803 "either be single dimension (e.g. 10), or an iterable of " 804 "dimensions (e.g. [1, 10, None]).".format(dims, d)) from e 805 self._dims = tuple(self._dims) 806 807 @property 808 def _v2_behavior(self): 809 if _TENSORSHAPE_V2_OVERRIDE is None: 810 return tf2.enabled() 811 return _TENSORSHAPE_V2_OVERRIDE 812 813 def __repr__(self): 814 if self._v2_behavior: 815 if self._dims is not None: 816 return f"TensorShape({list(self._dims)})" 817 else: 818 return "TensorShape(None)" 819 else: 820 return f"TensorShape({self.dims})" 821 822 def __str__(self): 823 if self.rank is None: 824 return "<unknown>" 825 elif self.rank == 1: 826 if self._v2_behavior: 827 return "(%s,)" % self._dims[0] 828 else: 829 return "(%s,)" % self.dims[0] 830 else: 831 if self._v2_behavior: 832 return "(%s)" % ", ".join(str(d) for d in self._dims) 833 else: 834 return "(%s)" % ", ".join(str(d) for d in self.dims) 835 836 @property 837 def rank(self): 838 """Returns the rank of this shape, or None if it is unspecified.""" 839 if self._dims is not None: 840 return len(self._dims) 841 return None 842 843 @property 844 def dims(self): 845 """Deprecated. Returns list of dimensions for this shape. 846 847 Suggest `TensorShape.as_list` instead. 848 849 Returns: 850 A list containing `tf.compat.v1.Dimension`s, or None if the shape is 851 unspecified. 852 """ 853 if self._dims is None: 854 return None 855 return [as_dimension(d) for d in self._dims] 856 857 @property 858 def ndims(self): 859 """Deprecated accessor for `rank`.""" 860 return self.rank 861 862 def __len__(self): 863 """Returns the rank of this shape, or raises ValueError if unspecified.""" 864 if self._dims is None: 865 raise ValueError("Cannot take the length of shape with unknown rank.") 866 return len(self._dims) 867 868 def __bool__(self): 869 """Returns True if this shape contains non-zero information.""" 870 return self._dims is not None 871 872 # Python 3 wants __bool__, Python 2.7 wants __nonzero__ 873 __nonzero__ = __bool__ 874 875 def __iter__(self): 876 """Returns `self.dims` if the rank is known, otherwise raises ValueError.""" 877 if self._dims is None: 878 raise ValueError("Cannot iterate over a shape with unknown rank.") 879 else: 880 if self._v2_behavior: 881 return iter(d for d in self._dims) 882 else: 883 return iter(d for d in self.dims) 884 885 def __getitem__(self, key): 886 """Returns the value of a dimension or a shape, depending on the key. 887 888 Args: 889 key: If `key` is an integer, returns the dimension at that index; 890 otherwise if `key` is a slice, returns a TensorShape whose dimensions 891 are those selected by the slice from `self`. 892 893 Returns: 894 An integer if `key` is an integer, or a `TensorShape` if `key` is a 895 slice. 896 897 Raises: 898 ValueError: If `key` is a slice and `self` is completely unknown and 899 the step is set. 900 """ 901 if self._dims is not None: 902 if isinstance(key, slice): 903 return TensorShape(self._dims[key]) 904 else: 905 if self._v2_behavior: 906 return self._dims[key] 907 else: 908 return self.dims[key] 909 else: 910 if isinstance(key, slice): 911 start = key.start if key.start is not None else 0 912 stop = key.stop 913 914 if key.step is not None: 915 # TODO(mrry): Handle these maybe. 916 raise ValueError("Steps are not yet handled") 917 if stop is None: 918 # NOTE(mrry): This implies that TensorShape(None) is compatible with 919 # TensorShape(None)[1:], which is obviously not true. It would be 920 # possible to track the number of dimensions symbolically, 921 # and perhaps we should do that. 922 return unknown_shape() 923 elif start < 0 or stop < 0: 924 # TODO(mrry): Handle this better, as it will be useful for handling 925 # suffixes of otherwise unknown shapes. 926 return unknown_shape() 927 else: 928 return unknown_shape(rank=stop - start) 929 else: 930 if self._v2_behavior: 931 return None 932 else: 933 return Dimension(None) 934 935 def num_elements(self): 936 """Returns the total number of elements, or none for incomplete shapes.""" 937 if self.is_fully_defined(): 938 return functools.reduce(operator.mul, self.as_list(), 1) 939 else: 940 return None 941 942 def merge_with(self, other): 943 """Returns a `TensorShape` combining the information in `self` and `other`. 944 945 The dimensions in `self` and `other` are merged element-wise, 946 according to the rules below: 947 948 ```python 949 Dimension(n).merge_with(Dimension(None)) == Dimension(n) 950 Dimension(None).merge_with(Dimension(n)) == Dimension(n) 951 Dimension(None).merge_with(Dimension(None)) == Dimension(None) 952 # raises ValueError for n != m 953 Dimension(n).merge_with(Dimension(m)) 954 ``` 955 >> ts = tf.TensorShape([1,2]) 956 >> ot1 = tf.TensorShape([1,2]) 957 >> ts.merge_with(ot).as_list() 958 [1,2] 959 960 >> ot2 = tf.TensorShape([1,None]) 961 >> ts.merge_with(ot2).as_list() 962 [1,2] 963 964 >> ot3 = tf.TensorShape([None, None]) 965 >> ot3.merge_with(ot2).as_list() 966 [1, None] 967 968 Args: 969 other: Another `TensorShape`. 970 971 Returns: 972 A `TensorShape` containing the combined information of `self` and 973 `other`. 974 975 Raises: 976 ValueError: If `self` and `other` are not compatible. 977 """ 978 other = as_shape(other) 979 if self.dims is None: 980 return other 981 if other.dims is None: 982 return self 983 else: 984 try: 985 self.assert_same_rank(other) 986 new_dims = [ 987 dim.merge_with(other_dim) 988 for dim, other_dim in zip(self.dims, other.dims) 989 ] 990 return TensorShape(new_dims) 991 except ValueError: 992 raise ValueError("Shapes %s and %s are not compatible" % (self, other)) 993 994 def __add__(self, other): 995 return self.concatenate(other) 996 997 def __radd__(self, other): 998 if not isinstance(other, TensorShape): 999 other = TensorShape(other) 1000 return other.concatenate(self) 1001 1002 def concatenate(self, other): 1003 """Returns the concatenation of the dimension in `self` and `other`. 1004 1005 *N.B.* If either `self` or `other` is completely unknown, 1006 concatenation will discard information about the other shape. In 1007 future, we might support concatenation that preserves this 1008 information for use with slicing. 1009 1010 Args: 1011 other: Another `TensorShape`. 1012 1013 Returns: 1014 A `TensorShape` whose dimensions are the concatenation of the 1015 dimensions in `self` and `other`. 1016 """ 1017 # TODO(mrry): Handle the case where we concatenate a known shape with a 1018 # completely unknown shape, so that we can use the partial information. 1019 other = as_shape(other) 1020 if self.dims is None or other.dims is None: 1021 return unknown_shape() 1022 else: 1023 return TensorShape(self.dims + other.dims) 1024 1025 def assert_same_rank(self, other): 1026 """Raises an exception if `self` and `other` do not have compatible ranks. 1027 1028 Args: 1029 other: Another `TensorShape`. 1030 1031 Raises: 1032 ValueError: If `self` and `other` do not represent shapes with the 1033 same rank. 1034 """ 1035 other = as_shape(other) 1036 if self.rank is not None and other.rank is not None: 1037 if self.rank != other.rank: 1038 raise ValueError("Shapes %s and %s must have the same rank" % 1039 (self, other)) 1040 1041 def assert_has_rank(self, rank): 1042 """Raises an exception if `self` is not compatible with the given `rank`. 1043 1044 Args: 1045 rank: An integer. 1046 1047 Raises: 1048 ValueError: If `self` does not represent a shape with the given `rank`. 1049 """ 1050 if self.rank not in (None, rank): 1051 raise ValueError("Shape %s must have rank %d" % (self, rank)) 1052 1053 def with_rank(self, rank): 1054 """Returns a shape based on `self` with the given rank. 1055 1056 This method promotes a completely unknown shape to one with a 1057 known rank. 1058 1059 Args: 1060 rank: An integer. 1061 1062 Returns: 1063 A shape that is at least as specific as `self` with the given rank. 1064 1065 Raises: 1066 ValueError: If `self` does not represent a shape with the given `rank`. 1067 """ 1068 try: 1069 return self.merge_with(unknown_shape(rank=rank)) 1070 except ValueError: 1071 raise ValueError("Shape %s must have rank %d" % (self, rank)) 1072 1073 def with_rank_at_least(self, rank): 1074 """Returns a shape based on `self` with at least the given rank. 1075 1076 Args: 1077 rank: An integer. 1078 1079 Returns: 1080 A shape that is at least as specific as `self` with at least the given 1081 rank. 1082 1083 Raises: 1084 ValueError: If `self` does not represent a shape with at least the given 1085 `rank`. 1086 """ 1087 if self.rank is not None and self.rank < rank: 1088 raise ValueError("Shape %s must have rank at least %d" % (self, rank)) 1089 else: 1090 return self 1091 1092 def with_rank_at_most(self, rank): 1093 """Returns a shape based on `self` with at most the given rank. 1094 1095 Args: 1096 rank: An integer. 1097 1098 Returns: 1099 A shape that is at least as specific as `self` with at most the given 1100 rank. 1101 1102 Raises: 1103 ValueError: If `self` does not represent a shape with at most the given 1104 `rank`. 1105 """ 1106 if self.rank is not None and self.rank > rank: 1107 raise ValueError("Shape %s must have rank at most %d" % (self, rank)) 1108 else: 1109 return self 1110 1111 def is_subtype_of(self, other: trace.TraceType) -> bool: 1112 """Returns True iff `self` is subtype of `other`. 1113 1114 Shape A is a subtype of shape B if shape B can successfully represent it: 1115 1116 * A `TensorShape` of any rank is a subtype of `TensorShape(None)`. 1117 1118 * TensorShapes of equal ranks are covariant, i.e. 1119 `TensorShape([A1, A2, ..])` is a subtype of 1120 `TensorShape([B1, B2, ..])` iff An is a subtype of Bn. 1121 1122 An is subtype of Bn iff An == Bn or Bn is None. 1123 1124 * TensorShapes of different defined ranks have no subtyping relation. 1125 1126 The subtyping relation is reflexive and transitive, but not symmetric. 1127 1128 Some examples: 1129 * `TensorShape([32, 784])` is a subtype of `TensorShape(None)`, and 1130 `TensorShape([4, 4])` is also a subtype of `TensorShape(None)` but 1131 `TensorShape([32, 784])` and `TensorShape([4, 4])` are not subtypes of 1132 each other. 1133 1134 * All two-dimensional shapes are subtypes of `TensorShape([None, None])`, 1135 such as `TensorShape([32, 784])`. There is no subtype relationship with, 1136 for example, `TensorShape([None])` or `TensorShape([None, None, None])`. 1137 1138 * `TensorShape([32, None])` is also a subtype of `TensorShape([None, None])` 1139 and `TensorShape(None)`. It is not a subtype of, for example, 1140 `TensorShape([32])`, `TensorShape([32, None, 1])`, 1141 `TensorShape([64, None])` or `TensorShape([None, 32])`. 1142 1143 * `TensorShape([32, 784])` is a subtype of itself, and also 1144 `TensorShape([32, None])`, `TensorShape([None, 784])`, 1145 `TensorShape([None, None])` and `TensorShape(None)`. 1146 It has no subtype relation with, for example, `TensorShape([32, 1, 784])` 1147 or `TensorShape([None])`. 1148 1149 Args: 1150 other: Another `TensorShape`. 1151 1152 Returns: 1153 True iff `self` is subtype of `other`. 1154 1155 """ 1156 if not isinstance(other, TensorShape): 1157 return False 1158 1159 # All Tensors are subtypes of a Tensor with no shape. 1160 if other.rank is None: 1161 return True 1162 1163 # Tensor with a defined shape can only be subtype of another with a defined 1164 # shape if they have the same number of dimensions. 1165 if self.rank != other.rank: 1166 return False 1167 1168 # A Tensor is a subtype if each corresponding dimension is a subtype. 1169 return all(o is None or s == o for s, o in zip(self._dims, other._dims)) # pylint: disable=protected-access 1170 1171 def most_specific_common_supertype( 1172 self, others: Sequence[trace.TraceType]) -> Optional["TensorShape"]: 1173 """Returns the most specific supertype `TensorShape` of self and others. 1174 1175 * `TensorShape([None, 1])` is the most specific `TensorShape` supertyping 1176 both `TensorShape([2, 1])` and `TensorShape([5, 1])`. Note that 1177 `TensorShape(None)` is also a supertype but it is not "most specific". 1178 1179 * `TensorShape([1, 2, 3])` is the most specific `TensorShape` supertyping 1180 both `TensorShape([1, 2, 3])` and `TensorShape([1, 2, 3]`). There are 1181 other less specific TensorShapes that supertype above mentioned 1182 TensorShapes, e.g. `TensorShape([1, 2, None])`, `TensorShape(None)`. 1183 1184 * `TensorShape([None, None])` is the most specific `TensorShape` 1185 supertyping both `TensorShape([2, None])` and `TensorShape([None, 3])`. 1186 As always, `TensorShape(None)` is also a supertype but not the most 1187 specific one. 1188 1189 * `TensorShape(None`) is the only `TensorShape` supertyping both 1190 `TensorShape([1, 2, 3])` and `TensorShape([1, 2])`. In general, any two 1191 shapes that have different ranks will only have `TensorShape(None)` 1192 as a common supertype. 1193 1194 * `TensorShape(None)` is the only `TensorShape` supertyping both 1195 `TensorShape([1, 2, 3])` and `TensorShape(None)`. In general, the common 1196 supertype of any shape with `TensorShape(None)` is `TensorShape(None)`. 1197 1198 Args: 1199 others: Sequence of `TensorShape`. 1200 1201 Returns: 1202 A `TensorShape` which is the most specific supertype shape of `self` 1203 and `others`. None if it does not exist. 1204 """ 1205 if any(not isinstance(other, TensorShape) for other in others): 1206 return None 1207 1208 # A Rankless TensorShape is already a global supertype so we return another 1209 # instance of it. 1210 if self.rank is None: 1211 return unknown_shape() 1212 1213 # A Rankless TensorShape is the most specific supertype for shapes whose 1214 # ranks do not match. 1215 if any(other.dims is None or self.rank != other.rank for other in others): 1216 return unknown_shape() 1217 1218 # Retain the integer dimension if it is the same across all others, else 1219 # use an undefined dimension. 1220 dims = [ 1221 dim if all(dim == other._dims[i] 1222 for other in others) else None 1223 for i, dim in enumerate(self._dims) 1224 ] 1225 return TensorShape(dims) 1226 1227 @classmethod 1228 def experimental_type_proto(cls) -> Type[tensor_shape_pb2.TensorShapeProto]: 1229 """Returns the type of proto associated with TensorShape serialization.""" 1230 return tensor_shape_pb2.TensorShapeProto 1231 1232 @classmethod 1233 def experimental_from_proto( 1234 cls, proto: tensor_shape_pb2.TensorShapeProto) -> "TensorShape": 1235 """Returns a TensorShape instance based on the serialized proto.""" 1236 return TensorShape(proto) 1237 1238 def experimental_as_proto(self) -> tensor_shape_pb2.TensorShapeProto: 1239 """Returns a proto representation of the TensorShape instance.""" 1240 return self.as_proto() 1241 1242 # TODO(b/216206374): Consider deprecation at TraceType release. 1243 def is_compatible_with(self, other): 1244 """Returns True iff `self` is compatible with `other`. 1245 1246 Two possibly-partially-defined shapes are compatible if there 1247 exists a fully-defined shape that both shapes can represent. Thus, 1248 compatibility allows the shape inference code to reason about 1249 partially-defined shapes. For example: 1250 1251 * TensorShape(None) is compatible with all shapes. 1252 1253 * TensorShape([None, None]) is compatible with all two-dimensional 1254 shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is 1255 not compatible with, for example, TensorShape([None]) or 1256 TensorShape([None, None, None]). 1257 1258 * TensorShape([32, None]) is compatible with all two-dimensional shapes 1259 with size 32 in the 0th dimension, and also TensorShape([None, None]) 1260 and TensorShape(None). It is not compatible with, for example, 1261 TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]). 1262 1263 * TensorShape([32, 784]) is compatible with itself, and also 1264 TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None, 1265 None]) and TensorShape(None). It is not compatible with, for example, 1266 TensorShape([32, 1, 784]) or TensorShape([None]). 1267 1268 The compatibility relation is reflexive and symmetric, but not 1269 transitive. For example, TensorShape([32, 784]) is compatible with 1270 TensorShape(None), and TensorShape(None) is compatible with 1271 TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with 1272 TensorShape([4, 4]). 1273 1274 Args: 1275 other: Another TensorShape. 1276 1277 Returns: 1278 True iff `self` is compatible with `other`. 1279 1280 """ 1281 other = as_shape(other) 1282 if self.dims is not None and other.dims is not None: 1283 if self.rank != other.rank: 1284 return False 1285 for x_dim, y_dim in zip(self.dims, other.dims): 1286 if not x_dim.is_compatible_with(y_dim): 1287 return False 1288 return True 1289 1290 def assert_is_compatible_with(self, other): 1291 """Raises exception if `self` and `other` do not represent the same shape. 1292 1293 This method can be used to assert that there exists a shape that both 1294 `self` and `other` represent. 1295 1296 Args: 1297 other: Another TensorShape. 1298 1299 Raises: 1300 ValueError: If `self` and `other` do not represent the same shape. 1301 """ 1302 if not self.is_compatible_with(other): 1303 raise ValueError("Shapes %s and %s are incompatible" % (self, other)) 1304 1305 def most_specific_compatible_shape(self, other): 1306 """Returns the most specific TensorShape compatible with `self` and `other`. 1307 1308 * TensorShape([None, 1]) is the most specific TensorShape compatible with 1309 both TensorShape([2, 1]) and TensorShape([5, 1]). Note that 1310 TensorShape(None) is also compatible with above mentioned TensorShapes. 1311 1312 * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with 1313 both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more 1314 less specific TensorShapes compatible with above mentioned TensorShapes, 1315 e.g. TensorShape([1, 2, None]), TensorShape(None). 1316 1317 Args: 1318 other: Another `TensorShape`. 1319 1320 Returns: 1321 A `TensorShape` which is the most specific compatible shape of `self` 1322 and `other`. 1323 """ 1324 1325 other = as_shape(other) 1326 if self.dims is None or other.dims is None or self.rank != other.rank: 1327 return unknown_shape() 1328 1329 dims = [ 1330 d1 if d1 is not None and d2 is not None and d1 == d2 else None 1331 for d1, d2 in zip(self.dims, other.dims) 1332 ] 1333 return TensorShape(dims) 1334 1335 def is_fully_defined(self): 1336 """Returns True iff `self` is fully defined in every dimension.""" 1337 return (self._dims is not None and 1338 all(dim is not None for dim in self._dims)) 1339 1340 def assert_is_fully_defined(self): 1341 """Raises an exception if `self` is not fully defined in every dimension. 1342 1343 Raises: 1344 ValueError: If `self` does not have a known value for every dimension. 1345 """ 1346 if not self.is_fully_defined(): 1347 raise ValueError("Shape %s is not fully defined" % self) 1348 1349 def as_list(self): 1350 """Returns a list of integers or `None` for each dimension. 1351 1352 Returns: 1353 A list of integers or `None` for each dimension. 1354 1355 Raises: 1356 ValueError: If `self` is an unknown shape with an unknown rank. 1357 """ 1358 if self._dims is None: 1359 raise ValueError("as_list() is not defined on an unknown TensorShape.") 1360 return list(self._dims) 1361 1362 def as_proto(self): 1363 """Returns this shape as a `TensorShapeProto`.""" 1364 if self._dims is None: 1365 return tensor_shape_pb2.TensorShapeProto(unknown_rank=True) 1366 else: 1367 return tensor_shape_pb2.TensorShapeProto(dim=[ 1368 tensor_shape_pb2.TensorShapeProto.Dim( 1369 size=-1 if d is None else d) for d in self._dims 1370 ]) 1371 1372 def __eq__(self, other): 1373 """Returns True if `self` is equivalent to `other`. 1374 1375 It first tries to convert `other` to `TensorShape`. `TypeError` is thrown 1376 when the conversion fails. Otherwise, it compares each element in the 1377 TensorShape dimensions. 1378 1379 * Two *Fully known* shapes, return True iff each element is equal. 1380 >>> t_a = tf.TensorShape([1,2]) 1381 >>> a = [1, 2] 1382 >>> t_b = tf.TensorShape([1,2]) 1383 >>> t_c = tf.TensorShape([1,2,3]) 1384 >>> t_a.__eq__(a) 1385 True 1386 >>> t_a.__eq__(t_b) 1387 True 1388 >>> t_a.__eq__(t_c) 1389 False 1390 1391 * Two *Partially-known* shapes, return True iff each element is equal. 1392 >>> p_a = tf.TensorShape([1,None]) 1393 >>> p_b = tf.TensorShape([1,None]) 1394 >>> p_c = tf.TensorShape([2,None]) 1395 >>> p_a.__eq__(p_b) 1396 True 1397 >>> t_a.__eq__(p_a) 1398 False 1399 >>> p_a.__eq__(p_c) 1400 False 1401 1402 * Two *Unknown shape*, return True. 1403 >>> unk_a = tf.TensorShape(None) 1404 >>> unk_b = tf.TensorShape(None) 1405 >>> unk_a.__eq__(unk_b) 1406 True 1407 >>> unk_a.__eq__(t_a) 1408 False 1409 1410 Args: 1411 other: A `TensorShape` or type that can be converted to `TensorShape`. 1412 1413 Returns: 1414 True if the dimensions are all equal. 1415 1416 Raises: 1417 TypeError if `other` can not be converted to `TensorShape`. 1418 """ 1419 1420 try: 1421 other = as_shape(other) 1422 except TypeError: 1423 return NotImplemented 1424 1425 return self._dims == other._dims 1426 1427 def __hash__(self): 1428 return hash(self._dims) 1429 1430 def __reduce__(self): 1431 return TensorShape, (self.dims,) 1432 1433 def __concat__(self, other): 1434 return self.concatenate(other) 1435 1436trace_type.register_serializable(TensorShape) 1437 1438 1439def as_shape(shape): 1440 """Converts the given object to a TensorShape.""" 1441 if isinstance(shape, TensorShape): 1442 return shape 1443 else: 1444 return TensorShape(shape) 1445 1446 1447def unknown_shape(rank=None, **kwargs): 1448 """Returns an unknown TensorShape, optionally with a known rank. 1449 1450 Args: 1451 rank: (Optional) If specified, the number of dimensions in the shape. 1452 **kwargs: For backwards compatibility. 1453 1454 Returns: 1455 An unknown TensorShape. 1456 1457 Raises: 1458 TypeError: In case of invalid arguments. 1459 """ 1460 if rank is None and "ndims" in kwargs: 1461 rank = kwargs.pop("ndims") 1462 if kwargs: 1463 raise TypeError("Unknown argument: %s" % kwargs) 1464 if rank is None: 1465 return TensorShape(None) 1466 else: 1467 return TensorShape([Dimension(None)] * rank) 1468