1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TensorArray: a dynamically sized array of Tensors.""" 16# Mixture of pep8 and non-pep8 names, so disable pylint bad-name 17# pylint: disable=g-bad-name 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import contextlib 23 24import traceback 25import weakref 26 27import numpy as np 28 29from tensorflow.python.eager import context 30from tensorflow.python.framework import constant_op 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import errors_impl 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import tensor_shape 35from tensorflow.python.framework import tensor_spec 36from tensorflow.python.framework import tensor_util 37from tensorflow.python.framework import type_spec 38from tensorflow.python.ops import array_ops 39from tensorflow.python.ops import control_flow_util 40from tensorflow.python.ops import gen_control_flow_ops 41from tensorflow.python.ops import gen_data_flow_ops 42from tensorflow.python.ops import list_ops 43from tensorflow.python.ops import math_ops 44from tensorflow.python.platform import tf_logging as logging 45from tensorflow.python.util import tf_should_use 46from tensorflow.python.util.tf_export import 
tf_export 47 48 49# _GraphTensorArray accesses many of the hidden generated ops, but is in 50# fact built to wrap these methods. 51# pylint: disable=protected-access 52class _GraphTensorArray(object): 53 """Graph-mode implementation of TensorArray.""" 54 55 def __init__(self, 56 dtype, 57 size=None, 58 dynamic_size=None, 59 clear_after_read=None, 60 tensor_array_name=None, 61 handle=None, 62 flow=None, 63 infer_shape=True, 64 element_shape=None, 65 colocate_with_first_write_call=True, 66 name=None): 67 """Constructs a graph mode TensorArray. 68 69 Args: 70 dtype: (required) data type of the TensorArray. 71 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 72 Required if handle is not provided. 73 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 74 can grow the TensorArray past its initial size. Default: False. 75 clear_after_read: Boolean (optional, default: True). If True, clear 76 TensorArray values after reading them. This disables read-many 77 semantics, but allows early release of memory. 78 tensor_array_name: (optional) Python string: the name of the TensorArray. 79 This is used when creating the TensorArray handle. If this value is 80 set, handle should be None. 81 handle: (optional) A `Tensor` handle to an existing TensorArray. If this 82 is set, tensor_array_name should be None. Only supported in graph mode. 83 flow: (optional) A float `Tensor` scalar coming from an existing 84 `TensorArray.flow`. Only supported in graph mode. 85 infer_shape: (optional, default: True) If True, shape inference is 86 enabled. In this case, all elements must have the same shape. 87 element_shape: (optional, default: None) A `TensorShape` object specifying 88 the shape constraints of each of the elements of the TensorArray. Need 89 not be fully defined. 
90 colocate_with_first_write_call: If `True`, the TensorArray will be 91 colocated on the same device as the Tensor used on its first write 92 (write operations include `write`, `unstack`, and `split`). If `False`, 93 the TensorArray will be placed on the device determined by the device 94 context available during its initialization. 95 name: A name for the operation (optional). 96 97 Raises: 98 ValueError: if both handle and tensor_array_name are provided. 99 TypeError: if handle is provided but is not a Tensor. 100 """ 101 if handle is not None and tensor_array_name: 102 raise ValueError( 103 "Cannot provide both `handle` and `tensor_array_name` arguments at " 104 "the same time.") 105 if handle is not None and not isinstance(handle, ops.Tensor): 106 raise TypeError( 107 f"Expected `handle` to be a Tensor, but got `{handle}` of type " 108 f"`{type(handle)}` instead.") 109 if handle is None and size is None: 110 raise ValueError( 111 "Argument `size` must be provided if handle is not provided.") 112 if handle is not None and size is not None: 113 raise ValueError("Cannot provide both a `handle` and `size` arguments " 114 "at the same time.") 115 if handle is not None and element_shape is not None: 116 raise ValueError( 117 "Cannot provide both `handle` and `element_shape` arguments " 118 "at the same time.") 119 if handle is not None and dynamic_size is not None: 120 raise ValueError( 121 "Cannot provide both `handle` and `dynamic_size` arguments " 122 "at the same time.") 123 if handle is not None and clear_after_read is not None: 124 raise ValueError( 125 "Cannot provide both `handle` and `clear_after_read` arguments " 126 "at the same time.") 127 128 if clear_after_read is None: 129 clear_after_read = True 130 self._dynamic_size = dynamic_size or False 131 self._dtype = dtypes.as_dtype(dtype).base_dtype 132 133 # Used to keep track of what tensors the TensorArray should be 134 # colocated with. 
We choose to colocate the TensorArray with the 135 # first tensor written to it. 136 self._colocate_with_first_write_call = colocate_with_first_write_call 137 if colocate_with_first_write_call: 138 self._colocate_with = [] 139 else: 140 self._colocate_with = None 141 142 # Record the current static shape for the array elements. The element 143 # shape is defined either by `element_shape` or the shape of the tensor 144 # of the first write. If `infer_shape` is true, all writes checks for 145 # shape equality. 146 self._element_shape = [tensor_shape.as_shape(element_shape)] 147 self._infer_shape = infer_shape 148 self._size = size 149 with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope: 150 if handle is not None: 151 self._handle = handle 152 if flow is None: 153 raise ValueError("flow must not be None if handle is not None.") 154 self._flow = flow 155 else: 156 # Construct the TensorArray with an empty device. The first 157 # write into the TensorArray from a Tensor with a set device 158 # will retroactively set the device value of this op. 
159 def create(): 160 """Create the TensorArray op.""" 161 return gen_data_flow_ops.tensor_array_v3( 162 dtype=dtype, 163 size=size, 164 element_shape=element_shape, 165 identical_element_shapes=infer_shape, 166 dynamic_size=self._dynamic_size, 167 clear_after_read=clear_after_read, 168 tensor_array_name=tensor_array_name, 169 name=scope) 170 171 if colocate_with_first_write_call: 172 with ops.device(None), ops.colocate_with(None, ignore_existing=True): 173 self._handle, self._flow = create() 174 else: 175 self._handle, self._flow = create() 176 177 @property 178 def flow(self): 179 return self._flow 180 181 @property 182 def dtype(self): 183 return self._dtype 184 185 @property 186 def handle(self): 187 return self._handle 188 189 @property 190 def element_shape(self): 191 return self._element_shape[0] 192 193 def _check_element_shape(self, shape): 194 """Changes the element shape of the array given a shape to merge with. 195 196 Args: 197 shape: A `TensorShape` object to merge with. 198 199 Raises: 200 ValueError: if the provided shape is incompatible with the current 201 element shape of the `TensorArray`. 202 """ 203 if not shape.is_compatible_with(self.element_shape): 204 raise ValueError("Inconsistent shapes: saw %s but expected %s " % 205 (shape, self.element_shape)) 206 if self._infer_shape: 207 self._element_shape[0] = self.element_shape.merge_with(shape) 208 209 @contextlib.contextmanager 210 def _maybe_colocate_with(self, value): 211 """Colocate operations with an internal colocation group or `value`. 212 213 Args: 214 value: `Tensor`, the tensor to try to colocate with. 215 216 Yields: 217 Does not yield anything, but the new context is a colocation context. 218 219 If no internal colocation group is set, colocate with `value` and set 220 the internal colocation group to be value. 
221 """ 222 if not self._colocate_with_first_write_call: 223 yield 224 else: 225 if not self._colocate_with: 226 self._colocate_with.append(value) 227 with ops.colocate_with(self._colocate_with[0]): 228 yield 229 230 def identity(self): 231 """See TensorArray.""" 232 flow = array_ops.identity(self._flow) 233 return build_ta_with_new_flow(self, flow) 234 235 def grad(self, source, flow=None, name=None): 236 """See TensorArray.""" 237 # tensor_array_grad requires a flow input when forward 238 # TensorArrays are dynamically sized. This forces the creation 239 # of the grad TensorArray only once the final forward array's size 240 # is fixed. 241 if flow is None: 242 flow = self.flow 243 with ops.name_scope(name, "TensorArrayGrad", [self._handle]): 244 with ops.colocate_with(self._handle): 245 g_handle, unused_flow = gen_data_flow_ops.tensor_array_grad_v3( 246 handle=self._handle, source=source, flow_in=flow, name=name) 247 with ops.control_dependencies([g_handle]): 248 flow = array_ops.identity(flow, name="gradient_flow") 249 g = TensorArray( 250 dtype=self._dtype, 251 handle=g_handle, 252 flow=flow, 253 infer_shape=self._infer_shape, 254 colocate_with_first_write_call=False) 255 # pylint: disable=protected-access 256 g._implementation._element_shape = self._element_shape 257 # pylint: enable=protected-access 258 return g 259 260 def read(self, index, name=None): 261 """See TensorArray.""" 262 value = gen_data_flow_ops.tensor_array_read_v3( 263 handle=self._handle, 264 index=index, 265 flow_in=self._flow, 266 dtype=self._dtype, 267 name=name) 268 if self._element_shape: 269 value.set_shape(self._element_shape[0].dims) 270 return value 271 272 def write(self, index, value, name=None): 273 """See TensorArray.""" 274 with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]): 275 # TODO(b/129870929): Fix after all callers provide proper init dtype. 
276 value = ops.convert_to_tensor( 277 value, preferred_dtype=self._dtype, name="value") 278 _check_dtypes(value, self._dtype) 279 self._check_element_shape(value.shape) 280 with self._maybe_colocate_with(value): 281 flow_out = gen_data_flow_ops.tensor_array_write_v3( 282 handle=self._handle, 283 index=index, 284 value=value, 285 flow_in=self._flow, 286 name=name) 287 return build_ta_with_new_flow(self, flow_out) 288 289 def stack(self, name=None): 290 """See TensorArray.""" 291 with ops.colocate_with(self._handle): 292 with ops.name_scope(name, "TensorArrayStack", [self._handle]): 293 value = self.gather(math_ops.range(0, self.size()), name=name) 294 if (self.element_shape and not self._dynamic_size and 295 self._size is not None): 296 value.set_shape([tensor_util.constant_value(self._size)] + 297 self.element_shape.dims) 298 return value 299 300 def gather(self, indices, name=None): 301 """See TensorArray.""" 302 if self._element_shape: 303 element_shape = self._element_shape[0] 304 else: 305 element_shape = tensor_shape.unknown_shape(None) 306 value = gen_data_flow_ops.tensor_array_gather_v3( 307 handle=self._handle, 308 indices=indices, 309 flow_in=self._flow, 310 dtype=self._dtype, 311 name=name, 312 element_shape=element_shape) 313 if self.element_shape: 314 value.set_shape([None] + self.element_shape.dims) 315 return value 316 317 def concat(self, name=None): 318 """See TensorArray.""" 319 value, _ = gen_data_flow_ops.tensor_array_concat_v3( 320 handle=self._handle, 321 flow_in=self._flow, 322 dtype=self._dtype, 323 name=name, 324 element_shape_except0=self.element_shape[1:]) 325 if self.element_shape: 326 value.set_shape([None] + self.element_shape.dims[1:]) 327 return value 328 329 @tf_should_use.should_use_result 330 def unstack(self, value, name=None): 331 """See TensorArray.""" 332 with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]): 333 num_elements = array_ops.shape(value)[0] 334 return self.scatter( 335 indices=math_ops.range(0, 
num_elements), value=value, name=name) 336 337 @tf_should_use.should_use_result 338 def scatter(self, indices, value, name=None): 339 """See TensorArray.""" 340 with ops.name_scope(name, "TensorArrayScatter", 341 [self._handle, value, indices]): 342 # TODO(b/129870929): Fix after all callers provide proper init dtype. 343 value = ops.convert_to_tensor( 344 value, preferred_dtype=self._dtype, name="value") 345 _check_dtypes(value, self._dtype) 346 if not context.executing_eagerly(): 347 self._check_element_shape(value.shape[1:]) 348 with self._maybe_colocate_with(value): 349 flow_out = gen_data_flow_ops.tensor_array_scatter_v3( 350 handle=self._handle, 351 indices=indices, 352 value=value, 353 flow_in=self._flow, 354 name=name) 355 return build_ta_with_new_flow(self, flow_out) 356 357 @tf_should_use.should_use_result 358 def split(self, value, lengths, name=None): 359 """See TensorArray.""" 360 with ops.name_scope(name, "TensorArraySplit", 361 [self._handle, value, lengths]): 362 value = ops.convert_to_tensor(value, dtype=self._dtype, name="value") 363 with self._maybe_colocate_with(value): 364 lengths_64 = math_ops.cast(lengths, dtypes.int64) 365 if not context.executing_eagerly(): 366 clengths = tensor_util.constant_value(lengths_64) 367 if value.shape.dims is not None and clengths is not None: 368 if clengths.shape and clengths.max() == clengths.min(): 369 self._check_element_shape( 370 tensor_shape.TensorShape([clengths[0] 371 ]).concatenate(value.shape[1:])) 372 flow_out = gen_data_flow_ops.tensor_array_split_v3( 373 handle=self._handle, 374 value=value, 375 lengths=lengths_64, 376 flow_in=self._flow, 377 name=name) 378 return build_ta_with_new_flow(self, flow_out) 379 380 def size(self, name=None): 381 """See TensorArray.""" 382 if not self._dynamic_size and self._size is not None: 383 return ops.convert_to_tensor(self._size, dtype=dtypes.int32) 384 else: 385 return gen_data_flow_ops.tensor_array_size_v3( 386 handle=self._handle, flow_in=self.flow, name=name) 
387 388 @tf_should_use.should_use_result 389 def close(self, name=None): 390 """See TensorArray.""" 391 return gen_data_flow_ops.tensor_array_close_v3( 392 handle=self._handle, name=name) 393 394 395class _GraphTensorArrayV2(object): 396 """Graph-mode implementation of TensorArray backed by TensorLists. 397 398 The backing tensor of this TensorArray is a TensorList variant tensor which is 399 stored in the `flow`. The `handle` is always none here. The reason we use the 400 `flow` field and not the `handle` field is to ensure backwards compatibility 401 with legacy control flow. 402 """ 403 404 def __init__(self, 405 dtype, 406 size=None, 407 dynamic_size=None, 408 clear_after_read=None, 409 tensor_array_name=None, 410 handle=None, 411 flow=None, 412 infer_shape=True, 413 element_shape=None, 414 colocate_with_first_write_call=True, 415 name=None): 416 """Constructs a graph mode TensorArray. 417 418 Args: 419 dtype: (required) data type of the TensorArray. 420 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 421 Required if flow is not provided. 422 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 423 can grow the TensorArray past its initial size. Default: False. 424 clear_after_read: (optional) unused. Not supported in TensorLists. 425 tensor_array_name: (optional) unused. 426 handle: (optional) Must always be None. 427 flow: (optional) A variant `Tensor` scalar for a TensorList. 428 infer_shape: (optional, default: True) If True, shape inference is 429 enabled. In this case, all elements must have the same shape. 430 element_shape: (optional, default: None) A `TensorShape` object specifying 431 the shape constraints of each of the elements of the TensorArray. Need 432 not be fully defined. 433 colocate_with_first_write_call: (optional). unused. 434 name: (optional) A name for the operation. 435 436 Raises: 437 ValueError: if both handle and tensor_array_name are provided. 
438 TypeError: if handle is provided but is not a Tensor. 439 """ 440 assert handle is None 441 del handle 442 del clear_after_read 443 del tensor_array_name 444 del colocate_with_first_write_call 445 446 self._dynamic_size = dynamic_size 447 self._size = size 448 449 if (flow is not None and 450 (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)): 451 raise TypeError( 452 f"Expected `flow` to be a variant tensor, but received `{flow.dtype}` " 453 f"instead.") 454 if flow is None and size is None: 455 raise ValueError("Argument `size` must be provided if argument `flow` " 456 "is not provided.") 457 if flow is not None and size is not None: 458 raise ValueError("Cannot provide both `flow` and `size` arguments " 459 "at the same time.") 460 if flow is not None and element_shape is not None: 461 raise ValueError( 462 "Cannot provide both `flow` and `element_shape` arguments" 463 "at the same time.") 464 465 self._dtype = dtypes.as_dtype(dtype).base_dtype 466 467 # Record the current static shape for the array elements. The element 468 # shape is defined either by `element_shape` or the shape of the tensor 469 # of the first write. If `infer_shape` is true, all writes checks for 470 # shape equality. 471 self._element_shape = [tensor_shape.as_shape(element_shape)] 472 self._infer_shape = infer_shape 473 with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope: 474 if flow is None: 475 self._flow = list_ops.tensor_list_reserve( 476 element_shape=element_shape, 477 num_elements=size, 478 element_dtype=dtype, 479 name=scope) 480 else: 481 self._flow = flow 482 483 # For backwards compatibility. 
484 self._colocate_with_first_write_call = None 485 self._colocate_with = None 486 487 @property 488 def flow(self): 489 return self._flow 490 491 @property 492 def dtype(self): 493 return self._dtype 494 495 @property 496 def element_shape(self): 497 return self._element_shape[0] 498 499 @property 500 def handle(self): 501 # We intentionally do not raise an error so that legacy while_loop does not 502 # complain. 503 return None 504 505 def _check_element_shape(self, shape): 506 """Changes the element shape of the array given a shape to merge with. 507 508 Args: 509 shape: A `TensorShape` object to merge with. 510 511 Raises: 512 ValueError: if the provided shape is incompatible with the current 513 element shape of the `TensorArray`. 514 """ 515 if not shape.is_compatible_with(self.element_shape): 516 raise ValueError("Inconsistent shapes: saw %s but expected %s " % 517 (shape, self.element_shape)) 518 if self._infer_shape: 519 self._element_shape[0] = self.element_shape.merge_with(shape) 520 521 def identity(self): 522 """See TensorArray.""" 523 flow = array_ops.identity(self._flow) 524 return build_ta_with_new_flow(self, flow) 525 526 def grad(self, source, flow=None, name=None): 527 """Not supported.""" 528 raise NotImplementedError() 529 530 def read(self, index, name=None): 531 """See TensorArray.""" 532 with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]): 533 value = list_ops.tensor_list_get_item( 534 input_handle=self._flow, 535 index=index, 536 element_dtype=self._dtype, 537 element_shape=self.element_shape, 538 name=name) 539 return value 540 541 def write(self, index, value, name=None): 542 """See TensorArray.""" 543 with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]): 544 # TODO(b/129870929): Fix after all callers provide proper init dtype. 
545 value = ops.convert_to_tensor( 546 value, preferred_dtype=self._dtype, name="value") 547 _check_dtypes(value, self._dtype) 548 self._check_element_shape(value.shape) 549 flow_out = list_ops.tensor_list_set_item( 550 input_handle=self._flow, 551 index=index, 552 item=value, 553 resize_if_index_out_of_bounds=self._dynamic_size, 554 name=name) 555 return build_ta_with_new_flow(self, flow_out) 556 557 def stack(self, name=None): 558 """See TensorArray.""" 559 with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]): 560 # TODO(b/139941163): remove constant_value after changing num_elements to regular input 561 if not self._dynamic_size and self._size is not None: 562 ta_size = tensor_util.constant_value(self._size) 563 else: 564 ta_size = -1 565 value = list_ops.tensor_list_stack( 566 input_handle=self._flow, 567 element_dtype=self._dtype, 568 num_elements=ta_size, 569 element_shape=self.element_shape) 570 return value 571 572 def gather(self, indices, name=None): 573 """See TensorArray.""" 574 value = list_ops.tensor_list_gather( 575 input_handle=self._flow, 576 indices=indices, 577 element_dtype=self._dtype, 578 element_shape=self.element_shape, 579 name=name) 580 return value 581 582 def concat(self, name=None): 583 """See TensorArray.""" 584 if self.element_shape: 585 element_shape = [None] + self.element_shape.dims[1:] 586 else: 587 element_shape = None 588 589 value = list_ops.tensor_list_concat( 590 input_handle=self._flow, 591 element_dtype=self._dtype, 592 element_shape=element_shape, 593 name=name) 594 return value 595 596 @tf_should_use.should_use_result 597 def unstack(self, value, name=None): 598 """See TensorArray.""" 599 with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]): 600 # TODO(b/129870929): Fix after all callers provide proper init dtype. 
601 value = ops.convert_to_tensor( 602 value, preferred_dtype=self._dtype, name="value") 603 _check_dtypes(value, self._dtype) 604 self._check_element_shape(value.shape[1:]) 605 flow_out = list_ops.tensor_list_from_tensor( 606 tensor=value, element_shape=value.shape[1:]) 607 return build_ta_with_new_flow(self, flow_out) 608 609 @tf_should_use.should_use_result 610 def scatter(self, indices, value, name=None): 611 """See TensorArray.""" 612 with ops.name_scope(name, "TensorArrayScatter", 613 [self._flow, value, indices]): 614 # TODO(b/129870929): Fix after all callers provide proper init dtype. 615 value = ops.convert_to_tensor( 616 value, preferred_dtype=self._dtype, name="value") 617 _check_dtypes(value, self._dtype) 618 self._check_element_shape(value.shape[1:]) 619 flow_out = list_ops.tensor_list_scatter( 620 tensor=value, 621 indices=indices, 622 element_shape=self.element_shape, 623 input_handle=self._flow) 624 return build_ta_with_new_flow(self, flow_out) 625 626 @tf_should_use.should_use_result 627 def split(self, value, lengths, name=None): 628 """See TensorArray.""" 629 with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]): 630 # TODO(b/129870929): Fix after all callers provide proper init dtype. 
631 value = ops.convert_to_tensor( 632 value, preferred_dtype=self._dtype, name="value") 633 _check_dtypes(value, self._dtype) 634 lengths_64 = math_ops.cast(lengths, dtypes.int64) 635 if not context.executing_eagerly(): 636 clengths = tensor_util.constant_value(lengths_64) 637 if value.shape.dims is not None and clengths is not None: 638 if clengths.shape and clengths.max() == clengths.min(): 639 self._check_element_shape( 640 tensor_shape.TensorShape([clengths[0] 641 ]).concatenate(value.shape[1:])) 642 flow_out = list_ops.tensor_list_split( 643 tensor=value, 644 lengths=lengths_64, 645 element_shape=self.element_shape, 646 name=name) 647 return build_ta_with_new_flow(self, flow_out) 648 649 def size(self, name=None): 650 """See TensorArray.""" 651 if not self._dynamic_size and self._size is not None: 652 return ops.convert_to_tensor(self._size, dtype=dtypes.int32) 653 else: 654 return list_ops.tensor_list_length(input_handle=self._flow, name=name) 655 656 def close(self, name=None): 657 """See TensorArray.""" 658 return gen_control_flow_ops.no_op(name=name) 659 660 661# pylint: enable=protected-access 662 663 664class _EagerTensorArray(object): 665 """Eager-compatible implementation of TensorArray.""" 666 667 def __init__(self, 668 dtype, 669 size=None, 670 dynamic_size=None, 671 clear_after_read=None, 672 tensor_array_name=None, 673 handle=None, 674 flow=None, 675 infer_shape=True, 676 element_shape=None, 677 colocate_with_first_write_call=True, 678 name=None): 679 """Constructs a TensorArray compatible with eager execution. 680 681 Args: 682 dtype: (required) data type of the TensorArray. 683 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 684 Required if handle is not provided. 685 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 686 can grow the TensorArray past its initial size. Default: False. 687 clear_after_read: Boolean (optional, default: True). If True, clear 688 TensorArray values after reading them. 
This disables read-many 689 semantics, but allows early release of memory. 690 tensor_array_name: unused. 691 handle: unsupported. 692 flow: unsupported. 693 infer_shape: used for error checking, same semantics as TensorArray. 694 element_shape: used for error checking, same semantics as TensorArray. 695 colocate_with_first_write_call: unsupported. 696 name: unsupported. 697 698 Raises: 699 ValueError: handle or flow are supplied, or if size is not supplied. 700 """ 701 702 del (flow, tensor_array_name, name) # Unused. 703 704 if handle is not None: 705 raise ValueError("TensorArray handles are not supported when eager " 706 "execution is enabled.") 707 if size is None: 708 raise ValueError("Size must be declared for TensorArrays when eager " 709 "execution is enabled.") 710 711 # These attributes are not meaningful when eager is enabled, but some 712 # library functions (e.g., those in control_flow_ops.py) access them to 713 # create new tensor arrays; as such, we define them for the sake of 714 # compatibility. 
715 self._handle = None 716 # we assign a dummy value to _flow in case other code assumes it to be 717 # a Tensor 718 self._flow = constant_op.constant(0, dtype=dtypes.int32) 719 self._infer_shape = infer_shape 720 self._element_shape = tensor_shape.as_shape(element_shape) 721 self._colocate_with_first_write_call = colocate_with_first_write_call 722 723 self._dtype = dtypes.as_dtype(dtype).base_dtype 724 self._dynamic_size = dynamic_size or False 725 self._clear_after_read = (True 726 if clear_after_read is None else clear_after_read) 727 self._previously_read_indices = [] 728 729 if isinstance(size, ops.EagerTensor): 730 size = size.numpy() 731 self._tensor_array = [None for _ in range(size)] 732 733 @property 734 def flow(self): 735 """For compatibility; flows are not meaningful when eager is enabled.""" 736 return self._flow 737 738 @property 739 def dtype(self): 740 return self._dtype 741 742 @property 743 def handle(self): 744 """For compatibility; handles are not meaningful when eager is enabled.""" 745 return self._handle 746 747 @property 748 def element_shape(self): 749 return self._element_shape 750 751 def identity(self): 752 """See TensorArray.""" 753 return self.parent() 754 755 def grad(self, source, flow=None, name=None): 756 raise NotImplementedError( 757 "TensorArray.grad is not supported when executing eagerly; eager's " 758 "gradient implementation does not use/need this function to compute " 759 "gradients of operations that use TensorArrays.") 760 761 def read(self, index, name=None): 762 """See TensorArray.""" 763 del name # not meaningful when executing eagerly. 764 765 if isinstance(index, ops.EagerTensor): 766 index = index.numpy() 767 768 if index < 0: 769 raise errors_impl.OutOfRangeError( 770 None, None, 771 "Reading from negative indices (index %d) is not allowed." 
% index) 772 773 if index >= len(self._tensor_array): 774 raise errors_impl.OutOfRangeError( 775 None, None, "Tried to read from index %d but array size is: %d " % 776 (index, len(self._tensor_array))) 777 778 tensor = self._tensor_array[index] 779 if tensor is None: 780 if index in self._previously_read_indices: 781 raise errors_impl.InvalidArgumentError( 782 None, None, 783 "Could not read index %d twice because it was cleared after " 784 "a previous read (perhaps try setting clear_after_read = false?)" % 785 index) 786 else: 787 tensor = self._maybe_zero(index) 788 789 if self._clear_after_read: 790 self._tensor_array[index] = None 791 self._previously_read_indices.append(index) 792 return tensor 793 794 def _write(self, index, value): 795 """Writes `value` into index named by `index`. 796 797 Args: 798 index: 0-D. int32 scalar with the index to write to. 799 value: N-D. Tensor of type `dtype`. The `Tensor` to write to `index`. 800 801 Raises: 802 errors_impl.InvalidArgumentError: `value` dtype does not match dtype. 803 errors_impl.OutOfRangeError: `index` is out of bounds. 804 ValueError: shape of `value` is not consistent with inferred shape. 805 """ 806 807 if isinstance(index, ops.EagerTensor): 808 index = index.numpy() 809 810 if index < 0: 811 raise errors_impl.OutOfRangeError( 812 None, None, 813 "Writing to negative indices (index %d) is not allowed." % index) 814 815 size = len(self._tensor_array) 816 if index >= size: 817 if not self._dynamic_size: 818 raise errors_impl.OutOfRangeError( 819 None, None, 820 "Tried to write to index %d but array is not resizeable and size " 821 "is: %d " % (index, size)) 822 self._tensor_array.extend(None for _ in range(index - size + 1)) 823 824 if not isinstance(value, ops.EagerTensor): 825 # TODO(b/129870929): Fix after all callers provide proper init dtype. 
# NOTE(review): the statements below, up to the first `def`, are the tail of a
# method that begins before this chunk -- presumably `_EagerTensorArray._write`,
# the helper that `write` and `scatter` below call. It validates `value` and
# stores it at `index`.
    value = ops.convert_to_tensor(
        value, preferred_dtype=self._dtype, name="value")

    # `preferred_dtype` is only a preference: a tensor that already has another
    # dtype is passed through unchanged, so it is rejected explicitly here.
    if self._dtype != value.dtype:
      raise errors_impl.InvalidArgumentError(
          None, None,
          "TensorArray dtype is %s but Op is trying to write dtype %s " %
          (self._dtype.name, value.dtype.name))

    if not self._element_shape.is_compatible_with(value.shape):
      raise ValueError("Incompatible shape for value (%s), expected (%s)" %
                       (value.shape, self._element_shape))

    # With shape inference on, every write narrows the recorded element shape.
    if self._infer_shape:
      self._element_shape = self._element_shape.merge_with(value.shape)

    self._tensor_array[index] = value

  def write(self, index, value, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    self._write(index, value)
    return self.parent()

  def _maybe_zero(self, ix):
    """Returns element `ix`, filling an unwritten (`None`) slot with zeros.

    The zeros tensor uses `self._element_shape` and `self._dtype` and is stored
    back into the slot, so subsequent accesses return the same tensor.
    """
    val = self._tensor_array[ix]
    if val is None:
      val = self._tensor_array[ix] = array_ops.zeros(
          shape=self._element_shape, dtype=self._dtype)
    return val

  def stack(self, name=None):
    """See TensorArray."""
    # Materialize zeros for any slot that was never written.
    if self._tensor_array:
      for ix in range(len(self._tensor_array)):
        self._maybe_zero(ix)
    # An empty array with a fully-defined element shape still produces a
    # tensor of shape [0] + element_shape rather than a shapeless empty one.
    if not self._tensor_array and self._element_shape.is_fully_defined():
      return ops.convert_to_tensor(
          np.ndarray([0] + self._element_shape), name=name, dtype=self._dtype)
    else:
      return ops.convert_to_tensor(
          self._tensor_array, name=name, dtype=self._dtype)

  def gather(self, indices, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    # Eager index tensors are converted to numpy so they can be iterated here.
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    return array_ops.stack([self._maybe_zero(i) for i in indices])

  def concat(self, name=None):
    """See TensorArray."""
    try:
      return array_ops.concat(
          [self._maybe_zero(ix) for ix in range(len(self._tensor_array))],
          0,
          name=name)
    except errors_impl.OpError:
      # Reproduce a subset of the error-handling for graph-mode TensorArrays.
      shapes = [t.shape for t in self._tensor_array]
      ndims = [s.ndims for s in shapes]
      if 0 in ndims:
        idx = ndims.index(0)
        raise errors_impl.InvalidArgumentError(
            None, None, "Concat saw a scalar shape at index %d but requires "
            "at least vectors." % idx)
      else:
        raise

  def unstack(self, value, name=None):
    """See TensorArray."""
    tensors = array_ops.unstack(value, name=name)
    # A static-size array may not grow; on success the whole element list is
    # replaced by the unstacked tensors.
    if len(tensors) > len(self._tensor_array) and not self._dynamic_size:
      raise ValueError(
          "Cannot unstack %d tensors into a TensorArray of static size %d " %
          (len(tensors), len(self._tensor_array)))
    self._tensor_array = tensors
    return self.parent()

  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    # Each row of `value` goes through the same validation as a single write.
    for index, val in zip(indices, array_ops.unstack(value)):
      self._write(index, val)  # pylint: disable=protected-access
    return self.parent()

  def split(self, value, lengths, name=None):
    """See TensorArray."""
    # TODO(b/129870929): Fix after all callers provide proper init dtype.
    value = ops.convert_to_tensor(
        value, preferred_dtype=self._dtype, name="value")
    _check_dtypes(value, self._dtype)
    lengths = ops.convert_to_tensor(lengths)
    sum_lengths = math_ops.reduce_sum(lengths)
    # Python-side validation mirroring the graph-mode op's error checks.
    if lengths.shape.ndims != 1:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected lengths to be a vector, received shape: %s " %
          lengths.shape.as_list())
    elif value.shape.ndims == 0:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected value to be at least a vector, "
          "but received shape: %s " % value.shape.as_list())
    elif sum_lengths.numpy() != value.shape.as_list()[0]:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected sum of lengths to be equal to "
          "values.shape[0], but sum of lengths is %d and "
          "value's shape is: %s " % (sum_lengths.numpy(),
                                     value.shape.as_list()))
    elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array):
      raise errors_impl.InvalidArgumentError(
          None, None, "TensorArray's size is not equal to the size of "
          "lengths (%d vs. %d), and the TensorArray is not marked as "
          "dynamically resizeable." %
          (len(self._tensor_array), lengths.shape[0]))
    else:
      self._tensor_array = array_ops.split(value, lengths, name=name)
      return self.parent()

  def size(self, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    return constant_op.constant(len(self._tensor_array))

  def close(self, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    # Empties the element list in place.
    del self._tensor_array[:]


# TensorArray is designed to hide an underlying implementation object
# and as such accesses many of that object's hidden fields.
# pylint: disable=protected-access
# pylint:disable=line-too-long
@tf_export("TensorArray")
class TensorArray(object):
  """Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.

  This class is meant to be used with dynamic iteration primitives such as
  `while_loop` and `map_fn`.  It supports gradient back-propagation via special
  "flow" control flow dependencies.

  Example 1: Plain reading and writing.

  >>> ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False)
  >>> ta = ta.write(0, 10)
  >>> ta = ta.write(1, 20)
  >>> ta = ta.write(2, 30)
  >>>
  >>> ta.read(0)
  <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
  >>> ta.read(1)
  <tf.Tensor: shape=(), dtype=float32, numpy=20.0>
  >>> ta.read(2)
  <tf.Tensor: shape=(), dtype=float32, numpy=30.0>
  >>> ta.stack()
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 20., 30.],
  dtype=float32)>

  Example 2: Fibonacci sequence algorithm that writes in a loop then returns.

  >>> @tf.function
  ... def fibonacci(n):
  ...   ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  ...   ta = ta.unstack([0., 1.])
  ...
  ...   for i in range(2, n):
  ...     ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2))
  ...
  ...   return ta.stack()
  >>>
  >>> fibonacci(7)
  <tf.Tensor: shape=(7,), dtype=float32,
  numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)>

  Example 3: A simple loop interacting with a `tf.Variable`.

  >>> v = tf.Variable(1)
  >>> @tf.function
  ... def f(x):
  ...   ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  ...   for i in tf.range(x):
  ...     v.assign_add(i)
  ...     ta = ta.write(i, v)
  ...   return ta.stack()
  >>> f(5)
  <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1,  2,  4,  7, 11],
  dtype=int32)>
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Construct a new TensorArray or wrap an existing TensorArray handle.

    A note about the parameter `name`:

    The name of the `TensorArray` (even if passed in) is uniquified: each time
    a new `TensorArray` is created at runtime it is assigned its own name for
    the duration of the run.  This avoids name collisions if a `TensorArray`
    is created within a `while_loop`.

    The backing implementation is chosen here:
    - eager mode without a variant `flow` -> eager (Python list) implementation;
    - a variant `flow`, or control flow v2 enabled in the default graph ->
      TF2 graph (variant) implementation;
    - otherwise -> TF1 graph implementation.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the device
        context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if (context.executing_eagerly() and
        (flow is None or flow.dtype != dtypes.variant)):
      # It is possible to create a Variant-style TensorArray even in eager mode,
      # and this is fine but can have performance implications in eager.
      # An example of when this happens is if a tf.function returns a
      # TensorArray in its output; its flow variant object is returned to Eager.
      # This can be wrapped back up in a Variant-style TensorArray.
      implementation = _EagerTensorArray
    elif (flow is not None and flow.dtype == dtypes.variant or
          control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
      implementation = _GraphTensorArrayV2
    else:
      implementation = _GraphTensorArray
    self._implementation = implementation(
        dtype,
        size=size,
        dynamic_size=dynamic_size,
        clear_after_read=clear_after_read,
        tensor_array_name=tensor_array_name,
        handle=handle,
        flow=flow,
        infer_shape=infer_shape,
        element_shape=element_shape,
        colocate_with_first_write_call=colocate_with_first_write_call,
        name=name)

    # The implementation holds only a weak reference back to this wrapper;
    # its mutating methods return `self.parent()`.
    self._implementation.parent = weakref.ref(self)

  @property
  def flow(self):
    """The flow `Tensor` forcing ops leading to this TensorArray state."""
    return self._implementation._flow

  @property
  def dtype(self):
    """The data type of this TensorArray."""
    return self._implementation._dtype

  @property
  def handle(self):
    """The reference to the TensorArray."""
    return self._implementation.handle

  @property
  def element_shape(self):
    """The `tf.TensorShape` of elements in this TensorArray."""
    return self._implementation.element_shape

  @property
  def dynamic_size(self):
    """Python bool; if `True` the TensorArray can grow dynamically."""
    return self._implementation._dynamic_size

  @property
  def _infer_shape(self):
    # TODO(slebedev): consider making public or changing TensorArrayStructure
    # to access _implementation directly. Note that dynamic_size is also
    # only used by TensorArrayStructure.
    return self._implementation._infer_shape

  def identity(self):
    """Returns a TensorArray with the same content and properties.

    Returns:
      A new TensorArray object with flow that ensures the control dependencies
      from the contexts will become control dependencies for writes, reads, etc.
      Use this object for all subsequent operations.
    """
    return self._implementation.identity()

  def grad(self, source, flow=None, name=None):
    """Delegates to the underlying implementation's `grad`.

    NOTE(review): presumably returns the gradient `TensorArray` associated with
    `source`; the actual semantics live in the implementation classes --
    confirm there.

    Args:
      source: Passed through to the implementation's `grad`.
      flow: (optional) Passed through to the implementation's `grad`.
      name: A name for the operation (optional).

    Returns:
      Whatever the implementation's `grad` returns.
    """
    return self._implementation.grad(source, flow=flow, name=name)

  def read(self, index, name=None):
    """Read the value at location `index` in the TensorArray.

    Args:
      index: 0-D.  int32 tensor with the index to read from.
      name: A name for the operation (optional).

    Returns:
      The tensor at index `index`.
    """
    return self._implementation.read(index, name=name)

  @tf_should_use.should_use_result(warn_in_eager=True)
  def write(self, index, value, name=None):
    """Write `value` into index `index` of the TensorArray.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the write occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if there are more writers than specified.
    """
    return self._implementation.write(index, value, name=name)

  def stack(self, name=None):
    """Return the values in the TensorArray as a stacked `Tensor`.

    All of the values must have been written and their shapes must all match.
    If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.

    For example (note that each `write` result is reassigned to `ta`, as the
    `write` contract requires):

    >>> ta = tf.TensorArray(tf.int32, size=3)
    >>> ta = ta.write(0, tf.constant([1, 2]))
    >>> ta = ta.write(1, tf.constant([3, 4]))
    >>> ta = ta.write(2, tf.constant([5, 6]))
    >>> ta.stack()
    <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4],
           [5, 6]], dtype=int32)>

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray stacked into one tensor.
    """
    return self._implementation.stack(name=name)

  def gather(self, indices, name=None):
    """Return selected values in the TensorArray as a packed `Tensor`.

    All of the selected values must have been written and their shapes
    must all match.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If the
        `TensorArray` is not dynamic, `max_value=size()`.
      name: A name for the operation (optional).

    Returns:
      The tensors in the `TensorArray` selected by `indices`, packed into one
      tensor.
    """
    return self._implementation.gather(indices, name=name)

  def concat(self, name=None):
    """Return the values in the TensorArray as a concatenated `Tensor`.

    All of the values must have been written, their ranks must match, and
    their shapes must all match for all dimensions except the first.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray concatenated into one tensor.
    """
    return self._implementation.concat(name=name)

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """Unstack the values of a `Tensor` in the TensorArray.

    If input value shapes have rank-`R`, then the output TensorArray will
    contain elements whose shapes are rank-`(R-1)`.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unstack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the unstack occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.unstack(value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """Scatter the values of a `Tensor` in specific indices of a `TensorArray`.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If the
        `TensorArray` is not dynamic, `max_value=size()`.
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unpack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the scatter occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.scatter(indices, value, name=name)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """Split the values of a `Tensor` into the TensorArray.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to split.
      lengths: 1-D.  int32 vector with the lengths to use when splitting
        `value` along its first dimension.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the split occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.split(value, lengths, name=name)

  def size(self, name=None):
    """Return the size of the TensorArray."""
    return self._implementation.size(name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """Close the current TensorArray."""
    return self._implementation.close(name=name)


def build_ta_with_new_flow(old_ta, flow):
  """Builds a TensorArray with a new `flow` tensor.

  The new TensorArray reuses the handle of `old_ta`'s implementation. It copies
  that implementation's dtype, infer_shape, colocation, dynamic_size and size
  settings. Its `_element_shape` object is shared with the old implementation,
  not copied.

  Args:
    old_ta: Either a `TensorArray` wrapper or one of its implementation objects.
    flow: The flow tensor for the new TensorArray.

  Returns:
    A new `TensorArray`.

  Raises:
    NotImplementedError: when not executing eagerly, control flow v2 is enabled
      in the default graph, and the implementation is not a
      `_GraphTensorArrayV2`.
  """
  # Sometimes we get old_ta as the implementation, sometimes it's the
  # TensorArray wrapper object.
  impl = (old_ta._implementation if isinstance(old_ta, TensorArray)
          else old_ta)

  if not context.executing_eagerly():
    if (not isinstance(impl, _GraphTensorArrayV2) and
        control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
      raise NotImplementedError("Attempting to build a graph-mode TF2-style "
                                "TensorArray from either an eager-mode "
                                "TensorArray or a TF1-style TensorArray. "
                                "This is not currently supported. You may be "
                                "attempting to capture a TensorArray "
                                "inside a tf.function or tf.data map function. "
                                "Instead, construct a new TensorArray inside "
                                "the function.")
  new_ta = TensorArray(
      dtype=impl.dtype,
      handle=impl.handle,
      flow=flow,
      infer_shape=impl._infer_shape,
      colocate_with_first_write_call=impl._colocate_with_first_write_call)
  new_impl = new_ta._implementation
  # Carry over settings that the constructor above does not accept directly.
  new_impl._dynamic_size = impl._dynamic_size
  new_impl._size = impl._size
  new_impl._colocate_with = impl._colocate_with
  new_impl._element_shape = impl._element_shape  # Share _element_shape.
  return new_ta


# pylint: enable=protected-access


def _check_dtypes(value, dtype):
  """Logs an error (does not raise) if `value.dtype` differs from `dtype`.

  The logged message includes the current Python stack trace to help locate
  the caller that passed the mismatched value.
  """
  if value.dtype != dtype:
    logging.error("Error: Input value {} has dtype {}, but expected dtype {}. "
                  "This leads to undefined behavior and will be an error "
                  "in future versions of TensorFlow. Traceback:\n{}".format(
                      value, str(value.dtype), str(dtype),
                      "".join(traceback.format_stack())))


@tf_export("TensorArraySpec")
@type_spec.register("tf.TensorArraySpec")
class TensorArraySpec(type_spec.TypeSpec):
  """Type specification for a `tf.TensorArray`."""

  __slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"]

  value_type = property(lambda self: TensorArray)

  def __init__(self,
               element_shape=None,
               dtype=dtypes.float32,
               dynamic_size=False,
               infer_shape=True):
    """Constructs a type specification for a `tf.TensorArray`.

    Args:
      element_shape: The shape of each element in the `TensorArray`.
      dtype: Data type of the `TensorArray`.
      dynamic_size: Whether the `TensorArray` can grow past its initial size.
      infer_shape: Whether shape inference is enabled.
    """
    self._element_shape = tensor_shape.as_shape(element_shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._dynamic_size = dynamic_size
    self._infer_shape = infer_shape

  def is_compatible_with(self, other):
    """Returns True if `other` is a compatible `TensorArraySpec`.

    `other` may be a `TypeSpec` or a value; values are first converted with
    `type_spec.type_spec_from_value`. Compares dtype, element shape and
    dynamic_size; infer_shape is deliberately ignored.
    """
    # pylint: disable=protected-access
    if not isinstance(other, type_spec.TypeSpec):
      other = type_spec.type_spec_from_value(other)

    # Note: we intentionally exclude infer_shape in this check.
    return (isinstance(other, TensorArraySpec) and
            self._dtype.is_compatible_with(other._dtype) and
            self._element_shape.is_compatible_with(other._element_shape) and
            self._dynamic_size == other._dynamic_size)

  def most_specific_compatible_type(self, other):
    """Returns the most specific `TensorArraySpec` compatible with both specs.

    The result's infer_shape is True only if both specs have it enabled.

    Raises:
      ValueError: if `other` is not compatible with this spec.
    """
    # pylint: disable=protected-access
    # NOTE(review): unlike is_compatible_with, this reads `other`'s private
    # fields directly, so `other` must already be a TensorArraySpec; a raw
    # value would pass the compatibility check but raise AttributeError below.
    if not self.is_compatible_with(other):
      raise ValueError(f"Type `{self}` is not compatible with `{other}`.")
    infer_shape = self._infer_shape and other._infer_shape
    return TensorArraySpec(
        self._element_shape.most_specific_compatible_shape(
            other._element_shape),
        self._dtype, self._dynamic_size, infer_shape)

  def _serialize(self):
    """Returns the constructor arguments that identify this spec."""
    return (self._element_shape, self._dtype, self._dynamic_size,
            self._infer_shape)

  @property
  def _component_specs(self):
    """A TensorArray is encoded as a single scalar variant tensor."""
    return [tensor_spec.TensorSpec([], dtypes.variant)]

  def _to_components(self, value):
    """Encodes `value` as a one-element list holding a variant flow tensor.

    A TensorArray whose flow is already a variant is passed through; any other
    TensorArray is stacked and converted into a tensor list.

    Raises:
      TypeError: if `value` is not a `TensorArray`.
    """
    if not isinstance(value, TensorArray):
      raise TypeError("Expected value to be a TensorArray, but got: `{}`".format(
          type(value)))
    if value.flow is not None and value.flow.dtype == dtypes.variant:
      return [value.flow]
    else:
      # Convert to a TF2-style TensorArray.
      # TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
      # "implementation / as_variant" arg to TensorArray constructor.
      with ops.name_scope("convert_tensor_array"):
        flow = list_ops.tensor_list_from_tensor(
            tensor=value.stack(), element_shape=value.element_shape)
      return [flow]

  def _from_components(self, tensor_list):
    """Rebuilds a `TensorArray` from the variant produced by `_to_components`."""
    # This will return a TF2 Graph-style TensorArray because tensor_list[0] is
    # a variant object.  size == -1 implies unknown size.
    ret = TensorArray(
        dtype=self._dtype,
        flow=tensor_list[0],
        dynamic_size=self._dynamic_size,
        infer_shape=self._infer_shape)
    # The element shape is stored wrapped in a one-element list, presumably so
    # it can be shared between implementations (cf. build_ta_with_new_flow).
    ret._implementation._element_shape = [self._element_shape]  # pylint: disable=protected-access
    return ret

  @staticmethod
  def from_value(value):
    """Returns the `TensorArraySpec` describing `value`.

    Raises:
      TypeError: if `value` is not a `TensorArray`.
    """
    if not isinstance(value, TensorArray):
      raise TypeError("Expected value to be a TensorArray, but got: `{}`".format(
          type(value)))

    return TensorArraySpec(
        dtype=value.dtype,
        element_shape=value.element_shape,
        dynamic_size=value.dynamic_size,
        infer_shape=value._infer_shape)  # pylint: disable=protected-access

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    # Sneak the dynamic_size and infer_shape values into the legacy shape.
    return (tensor_shape.TensorShape([self._dynamic_size, self._infer_shape
                                     ]).concatenate(self._element_shape))

  def _to_legacy_output_classes(self):
    return TensorArray


# Register the TypeSpec for TensorArray.  If TensorArray is updated to be a
# CompositeTensor, then this registration can be deleted.
type_spec.register_type_spec_from_value_converter(
    TensorArray, TensorArraySpec.from_value, allow_subclass=True)