# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorArray: a dynamically sized array of Tensors."""
# Mixture of pep8 and non-pep8 names, so disable pylint bad-name
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

import traceback
import weakref

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export


# _GraphTensorArray accesses many of the hidden generated ops, but is in
# fact built to wrap these methods.
# pylint: disable=protected-access
class _GraphTensorArray(object):
  """Graph-mode implementation of TensorArray.

  Wraps the `*_v3` TensorArray ops from `gen_data_flow_ops`. The array is
  addressed through `handle`; every mutating op consumes the current `flow`
  tensor (`flow_in`) and produces a new one (`flow_out`), which is wrapped in
  a new TensorArray via `build_ta_with_new_flow` (a name this class uses but
  does not define; presumably a module-level helper elsewhere in this file).
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Either creates a new TensorArray op (when `handle` is None) or wraps an
    existing `handle`/`flow` pair.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size. Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode. Required whenever
        `handle` is provided.
      infer_shape: (optional, default: True) If True, shape inference
        is enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray.
        Need not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`/`scatter`, and `split`).
        If `False`, the TensorArray will be placed on the device determined by
        the device context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided; if
        neither handle nor size is provided; if handle is combined with any of
        size, element_shape, dynamic_size or clear_after_read; or if handle is
        provided without flow.
      TypeError: if handle is provided but is not a Tensor.
    """
    if handle is not None and tensor_array_name:
      raise ValueError(
          "Cannot construct with both handle and tensor_array_name")
    if handle is not None and not isinstance(handle, ops.Tensor):
      raise TypeError("Handle must be a Tensor")
    if handle is None and size is None:
      raise ValueError("Size must be provided if handle is not provided")
    if handle is not None and size is not None:
      raise ValueError("Cannot provide both a handle and size "
                       "at the same time")
    if handle is not None and element_shape is not None:
      raise ValueError("Cannot provide both a handle and element_shape "
                       "at the same time")
    if handle is not None and dynamic_size is not None:
      raise ValueError("Cannot provide both a handle and dynamic_size "
                       "at the same time")
    if handle is not None and clear_after_read is not None:
      raise ValueError("Cannot provide both a handle and clear_after_read "
                       "at the same time")

    if clear_after_read is None:
      clear_after_read = True
    self._dynamic_size = dynamic_size or False
    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Used to keep track of what tensors the TensorArray should be
    # colocated with.  We choose to colocate the TensorArray with the
    # first tensor written to it.  A list (rather than None) marks that
    # colocation tracking is enabled; `_maybe_colocate_with` fills it.
    self._colocate_with_first_write_call = colocate_with_first_write_call
    if colocate_with_first_write_call:
      self._colocate_with = []
    else:
      self._colocate_with = None

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, all writes check for
    # shape equality.  The shape lives in a one-element list so the same
    # list object can be shared by reference (see `grad`) and updated in
    # place by `_check_element_shape`.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    self._size = size
    with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
      if handle is not None:
        self._handle = handle
        if flow is None:
          raise ValueError("flow must not be None if handle is not None.")
        self._flow = flow
      else:
        # Construct the TensorArray with an empty device.  The first
        # write into the TensorArray from a Tensor with a set device
        # will retroactively set the device value of this op.
        def create():
          """Create the TensorArray op."""
          return gen_data_flow_ops.tensor_array_v3(
              dtype=dtype,
              size=size,
              element_shape=element_shape,
              identical_element_shapes=infer_shape,
              dynamic_size=self._dynamic_size,
              clear_after_read=clear_after_read,
              tensor_array_name=tensor_array_name,
              name=scope)
        if colocate_with_first_write_call:
          # Clear any enclosing device/colocation so placement is decided by
          # the first write instead.
          with ops.device(None), ops.colocate_with(None, ignore_existing=True):
            self._handle, self._flow = create()
        else:
          self._handle, self._flow = create()

  @property
  def flow(self):
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    return self._handle

  @property
  def element_shape(self):
    return self._element_shape[0]

  def _check_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    The shape is always checked for compatibility; it is merged into the
    stored element shape only when `infer_shape` is enabled.  The merge
    writes into the shared one-element list, so every object holding the
    same list observes the refined shape.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
        element shape of the `TensorArray`.
    """
    if not shape.is_compatible_with(self.element_shape):
      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
                       (shape, self.element_shape))
    if self._infer_shape:
      self._element_shape[0] = self.element_shape.merge_with(shape)

  @contextlib.contextmanager
  def _maybe_colocate_with(self, value):
    """Colocate operations with an internal colocation group or `value`.

    When colocation tracking is disabled this is a no-op context.  Otherwise
    the first `value` ever passed is recorded, and every later call colocates
    with that first recorded tensor (not with its own `value`).

    Args:
      value: `Tensor`, the tensor to try to colocate with.

    Yields:
      Does not yield anything, but the new context is a colocation context.
    """
    if not self._colocate_with_first_write_call:
      yield
    else:
      if not self._colocate_with:
        self._colocate_with.append(value)
      with ops.colocate_with(self._colocate_with[0]):
        yield

  def identity(self):
    """See TensorArray.

    Wraps an `array_ops.identity` of the current flow in a new TensorArray.
    """
    flow = array_ops.identity(self._flow)
    return build_ta_with_new_flow(self, flow)

  def grad(self, source, flow=None, name=None):
    """See TensorArray.

    Builds the gradient TensorArray with `tensor_array_grad_v3`, colocated
    with this array's handle.  The gradient array is given this array's
    element-shape list object, so shape refinements on either are shared.
    """
    # tensor_array_grad requires a flow input when forward
    # TensorArrays are dynamically sized.  This forces the creation
    # of the grad TensorArray only once the final forward array's size
    # is fixed.
    if flow is None:
      flow = self.flow
    with ops.name_scope(name, "TensorArrayGrad", [self._handle]):
      with ops.colocate_with(self._handle):
        g_handle, unused_flow = gen_data_flow_ops.tensor_array_grad_v3(
            handle=self._handle, source=source, flow_in=flow, name=name)
        # Make the gradient flow depend on the grad handle being created.
        with ops.control_dependencies([g_handle]):
          flow = array_ops.identity(flow, name="gradient_flow")
        g = TensorArray(
            dtype=self._dtype,
            handle=g_handle,
            flow=flow,
            infer_shape=self._infer_shape,
            colocate_with_first_write_call=False)
        # pylint: disable=protected-access
        g._implementation._element_shape = self._element_shape
        # pylint: enable=protected-access
        return g

  def read(self, index, name=None):
    """See TensorArray.

    The stored (possibly partial) element shape is applied as the static
    shape of the returned tensor.
    """
    value = gen_data_flow_ops.tensor_array_read_v3(
        handle=self._handle,
        index=index,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name)
    # NOTE(review): `_element_shape` is a one-element list in this class, so
    # this test appears to be always true.
    if self._element_shape:
      value.set_shape(self._element_shape[0].dims)
    return value

  def write(self, index, value, name=None):
    """See TensorArray.

    Converts `value`, checks its dtype and shape against the array, and
    returns a new TensorArray wrapping the write op's output flow.
    """
    with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      # `_check_dtypes` is not defined in this class; presumably a
      # module-level helper elsewhere in this file.
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape)
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops.tensor_array_write_v3(
            handle=self._handle,
            index=index,
            value=value,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  def stack(self, name=None):
    """See TensorArray.

    Gathers indices `[0, size())`.  When the array is not dynamically sized
    and a size was given, the leading static dimension is set to
    `tensor_util.constant_value(self._size)`.
    """
    with ops.colocate_with(self._handle):
      with ops.name_scope(name, "TensorArrayStack", [self._handle]):
        value = self.gather(math_ops.range(0, self.size()), name=name)
        if (self.element_shape and not self._dynamic_size and
            self._size is not None):
          value.set_shape([tensor_util.constant_value(self._size)] +
                          self.element_shape.dims)
        return value

  def gather(self, indices, name=None):
    """See TensorArray."""
    # NOTE(review): `_element_shape` is a one-element list in this class, so
    # the `else` branch below appears to be unreachable.
    if self._element_shape:
      element_shape = self._element_shape[0]
    else:
      element_shape = tensor_shape.unknown_shape(None)
    value = gen_data_flow_ops.tensor_array_gather_v3(
        handle=self._handle,
        indices=indices,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape=element_shape)
    if self.element_shape:
      value.set_shape([None] + self.element_shape.dims)
    return value

  def concat(self, name=None):
    """See TensorArray.

    Only the non-leading dimensions of the element shape are used; the
    leading dimension of the result is left unknown.
    """
    value, _ = gen_data_flow_ops.tensor_array_concat_v3(
        handle=self._handle,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape_except0=self.element_shape[1:])
    if self.element_shape:
      value.set_shape([None] + self.element_shape.dims[1:])
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray.

    Implemented as a `scatter` of `value` to indices `[0, shape(value)[0])`.
    """
    with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
      num_elements = array_ops.shape(value)[0]
      return self.scatter(
          indices=math_ops.range(0, num_elements), value=value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray.

    The per-element shape check (`value.shape[1:]`) is only performed when
    not executing eagerly.
    """
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._handle, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      if not context.executing_eagerly():
        self._check_element_shape(value.shape[1:])
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
            handle=self._handle,
            indices=indices,
            value=value,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray.

    `value` is converted with `dtype=self._dtype` (not `preferred_dtype` as
    in `write`/`scatter`).  The element shape can only be checked statically
    when `lengths` is a constant vector whose entries are all equal.
    """
    with ops.name_scope(name, "TensorArraySplit",
                        [self._handle, value, lengths]):
      value = ops.convert_to_tensor(value, dtype=self._dtype, name="value")
      with self._maybe_colocate_with(value):
        lengths_64 = math_ops.cast(lengths, dtypes.int64)
        if not context.executing_eagerly():
          clengths = tensor_util.constant_value(lengths_64)
          if value.shape.dims is not None and clengths is not None:
            if clengths.shape and clengths.max() == clengths.min():
              self._check_element_shape(
                  tensor_shape.TensorShape([clengths[0]]).concatenate(
                      value.shape[1:]))
        flow_out = gen_data_flow_ops.tensor_array_split_v3(
            handle=self._handle,
            value=value,
            lengths=lengths_64,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  def size(self, name=None):
    """See TensorArray.

    Returns the construction-time `size` as an int32 tensor when the array
    is not dynamically sized and a size was given; otherwise emits a size op
    on the handle.
    """
    if not self._dynamic_size and self._size is not None:
      return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
    else:
      return gen_data_flow_ops.tensor_array_size_v3(
          handle=self._handle, flow_in=self.flow, name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """See TensorArray.

    Returns the `tensor_array_close_v3` op for this array's handle.
    """
    return gen_data_flow_ops.tensor_array_close_v3(
        handle=self._handle, name=name)


class _GraphTensorArrayV2(object):
  """Graph-mode implementation of TensorArray backed by TensorLists.

  The backing tensor of this TensorArray is a TensorList variant tensor which is
  stored in the `flow`. The `handle` is always None here. The reason we use the
  `flow` field and not the `handle` field is to ensure backwards compatibility
  with legacy control flow.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Either reserves a new TensorList of `size` elements (when `flow` is None)
    or wraps an existing TensorList `flow`.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if flow is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size. Default: False.
      clear_after_read: (optional) unused. Not supported in TensorLists.
      tensor_array_name: (optional) unused.
      handle: (optional) Must always be None.
      flow: (optional) A variant `Tensor` scalar for a TensorList.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: (optional). unused.
      name: (optional) A name for the operation.

    Raises:
      AssertionError: if handle is not None (checked with `assert`, so the
        check is skipped under `python -O`).
      TypeError: if flow is provided but is not a variant `Tensor`.
      ValueError: if neither flow nor size is provided, or if flow is
        combined with size or element_shape.
    """
    assert handle is None
    del handle
    del clear_after_read
    del tensor_array_name
    del colocate_with_first_write_call

    # NOTE(review): unlike the graph and eager classes, `dynamic_size` is
    # stored as given (it may stay None rather than being coerced to False).
    self._dynamic_size = dynamic_size
    self._size = size

    if (flow is not None and
        (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)):
      raise TypeError("flow must be a variant tensor")
    if flow is None and size is None:
      raise ValueError("Size must be provided if flow is not provided")
    if flow is not None and size is not None:
      raise ValueError("Cannot provide both a flow and size "
                       "at the same time")
    if flow is not None and element_shape is not None:
      raise ValueError("Cannot provide both a flow and element_shape "
                       "at the same time")

    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, all writes check for
    # shape equality.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
      if flow is None:
        self._flow = list_ops.tensor_list_reserve(
            element_shape=element_shape,
            num_elements=size,
            element_dtype=dtype,
            name=scope)
      else:
        self._flow = flow

    # For backwards compatibility.  These attributes are unused by this
    # class; presumably kept because other code reads them from any
    # TensorArray implementation -- confirm before removing.
    self._colocate_with_first_write_call = None
    self._colocate_with = None

  @property
  def flow(self):
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def element_shape(self):
    return self._element_shape[0]

  @property
  def handle(self):
    # We intentionally do not raise an error so that legacy while_loop does not
    # complain.
    return None

  def _check_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    The shape is always checked for compatibility; it is merged into the
    stored element shape only when `infer_shape` is enabled.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
        element shape of the `TensorArray`.
    """
    if not shape.is_compatible_with(self.element_shape):
      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
                       (shape, self.element_shape))
    if self._infer_shape:
      self._element_shape[0] = self.element_shape.merge_with(shape)

  def identity(self):
    """See TensorArray.

    Wraps an `array_ops.identity` of the TensorList flow in a new TensorArray.
    """
    flow = array_ops.identity(self._flow)
    return build_ta_with_new_flow(self, flow)

  def grad(self, source, flow=None, name=None):
    """Not supported."""
    raise NotImplementedError()

  def read(self, index, name=None):
    """See TensorArray.

    Reads one item from the TensorList via `tensor_list_get_item`.
    """
    with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
      value = list_ops.tensor_list_get_item(
          input_handle=self._flow,
          index=index,
          element_dtype=self._dtype,
          element_shape=self.element_shape,
          name=name)
      return value

  def write(self, index, value, name=None):
    """See TensorArray.

    Sets one TensorList item; the list is allowed to grow past its current
    length only when `dynamic_size` is truthy.
    """
    with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape)
      flow_out = list_ops.tensor_list_set_item(
          input_handle=self._flow,
          index=index,
          item=value,
          resize_if_index_out_of_bounds=self._dynamic_size,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def stack(self, name=None):
    """See TensorArray.

    Passes the static size as `num_elements` when it is known (not dynamic
    and given); otherwise passes -1.
    """
    with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
      # TODO(b/139941163): remove constant_value after changing num_elements to regular input
      if not self._dynamic_size and self._size is not None:
        ta_size = tensor_util.constant_value(self._size)
      else:
        ta_size = -1
      value = list_ops.tensor_list_stack(
          input_handle=self._flow,
          element_dtype=self._dtype,
          num_elements=ta_size,
          element_shape=self.element_shape)
      return value

  def gather(self, indices, name=None):
    """See TensorArray."""
    value = list_ops.tensor_list_gather(
        input_handle=self._flow,
        indices=indices,
        element_dtype=self._dtype,
        element_shape=self.element_shape,
        name=name)
    return value

  def concat(self, name=None):
    """See TensorArray.

    The leading dimension of the element shape is replaced with None, since
    concatenated elements may differ along it.
    """
    if self.element_shape:
      element_shape = [None] + self.element_shape.dims[1:]
    else:
      element_shape = None

    value = list_ops.tensor_list_concat(
        input_handle=self._flow,
        element_dtype=self._dtype,
        element_shape=element_shape,
        name=name)
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray.

    Builds a brand-new TensorList from `value`; the current flow is not used.
    """
    with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_from_tensor(
          tensor=value, element_shape=value.shape[1:])
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray.

    Scatters the rows of `value` into the existing TensorList at `indices`.
    """
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._flow, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_scatter(
          tensor=value, indices=indices, element_shape=self.element_shape,
          input_handle=self._flow)
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray.

    Builds a new TensorList with `tensor_list_split`; the current flow is not
    used.  The element shape can only be checked statically when `lengths`
    is a constant vector whose entries are all equal.
    """
    with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      lengths_64 = math_ops.cast(lengths, dtypes.int64)
      if not context.executing_eagerly():
        clengths = tensor_util.constant_value(lengths_64)
        if value.shape.dims is not None and clengths is not None:
          if clengths.shape and clengths.max() == clengths.min():
            self._check_element_shape(
                tensor_shape.TensorShape([clengths[0]]).concatenate(
                    value.shape[1:]))
      flow_out = list_ops.tensor_list_split(
          tensor=value,
          lengths=lengths_64,
          element_shape=self.element_shape,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def size(self, name=None):
    """See TensorArray.

    Returns the construction-time `size` as an int32 tensor when the array
    is not dynamically sized and a size was given; otherwise the TensorList
    length.
    """
    if not self._dynamic_size and self._size is not None:
      return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
    else:
      return list_ops.tensor_list_length(input_handle=self._flow, name=name)

  def close(self, name=None):
    """See TensorArray.

    TensorLists need no explicit release, so this returns a `no_op`.
    """
    return gen_control_flow_ops.no_op(name=name)

# pylint: enable=protected-access


class _EagerTensorArray(object):
  """Eager-compatible implementation of TensorArray.

  Elements are held in a Python list (`self._tensor_array`) in which `None`
  marks a slot that has not been written or was cleared after a read.  Unlike
  the graph implementations, `_element_shape` is a plain `TensorShape`, not a
  one-element list.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a TensorArray compatible with eager execution.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size. Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: unused.
      handle: unsupported.
      flow: unsupported.
      infer_shape: used for error checking, same semantics as TensorArray.
      element_shape: used for error checking, same semantics as TensorArray.
      colocate_with_first_write_call: unsupported.
      name: unsupported.

    Raises:
      ValueError: handle or flow are supplied, or if size is not supplied.
        (Only a non-None `handle` is actually rejected; `flow` is discarded.)
    """

    del (flow, tensor_array_name, name)  # Unused.

    if handle is not None:
      raise ValueError("TensorArray handles are not supported when eager "
                       "execution is enabled.")
    if size is None:
      raise ValueError("Size must be declared for TensorArrays when eager "
                       "execution is enabled.")

    # These attributes are not meaningful when eager is enabled, but some
    # library functions (e.g., those in control_flow_ops.py) access them to
    # create new tensor arrays; as such, we define them for the sake of
    # compatibility.
    self._handle = None
    # we assign a dummy value to _flow in case other code assumes it to be
    # a Tensor
    self._flow = constant_op.constant(0, dtype=dtypes.int32)
    self._infer_shape = infer_shape
    self._element_shape = tensor_shape.as_shape(element_shape)
    self._colocate_with_first_write_call = colocate_with_first_write_call

    self._dtype = dtypes.as_dtype(dtype).base_dtype
    self._dynamic_size = dynamic_size or False
    self._clear_after_read = (
        True if clear_after_read is None else clear_after_read)
    # Indices cleared by a read; lets `read` tell "cleared" apart from
    # "never written" (the latter is zero-filled instead of raising).
    self._previously_read_indices = []

    if isinstance(size, ops.EagerTensor):
      size = size.numpy()
    self._tensor_array = [None for _ in range(size)]

  @property
  def flow(self):
    """For compatibility; flows are not meaningful when eager is enabled."""
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    """For compatibility; handles are not meaningful when eager is enabled."""
    return self._handle

  @property
  def element_shape(self):
    return self._element_shape

  def identity(self):
    """See TensorArray.

    NOTE(review): `parent` is not defined in this class; presumably it is
    attached by the wrapping `TensorArray` and returns that wrapper.
    """
    return self.parent()

  def grad(self, source, flow=None, name=None):
    raise NotImplementedError(
        "TensorArray.grad is not supported when executing eagerly; eager's "
        "gradient implementation does not use/need this function to compute "
        "gradients of operations that use TensorArrays.")

  def read(self, index, name=None):
    """See TensorArray.

    Never-written slots read as zeros of the element shape.  With
    `clear_after_read`, the slot is emptied and a second read of it raises.

    Raises:
      errors_impl.OutOfRangeError: `index` is negative or past the end.
      errors_impl.InvalidArgumentError: `index` was already cleared by a
        previous read.
    """
    del name  # not meaningful when executing eagerly.

    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Reading from negative indices (index %d) is not allowed." % index)

    if index >= len(self._tensor_array):
      raise errors_impl.OutOfRangeError(
          None, None, "Tried to read from index %d but array size is: %d" %
          (index, len(self._tensor_array)))

    tensor = self._tensor_array[index]
    if tensor is None:
      if index in self._previously_read_indices:
        raise errors_impl.InvalidArgumentError(
            None, None,
            "Could not read index %d twice because it was cleared after "
            "a previous read (perhaps try setting clear_after_read = false?)" %
            index)
      else:
        tensor = self._maybe_zero(index)

    if self._clear_after_read:
      self._tensor_array[index] = None
      self._previously_read_indices.append(index)
    return tensor

  def _write(self, index, value):
    """Writes `value` into index named by `index`.

    Grows the backing list with `None` slots when `index` is past the end and
    the array is dynamically sized.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The `Tensor` to write to `index`.

    Raises:
      errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
      errors_impl.OutOfRangeError: `index` is out of bounds.
      ValueError: shape of `value` is not consistent with inferred shape.
    """

    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Writing to negative indices (index %d) is not allowed." % index)

    size = len(self._tensor_array)
    if index >= size:
      if not self._dynamic_size:
        raise errors_impl.OutOfRangeError(
            None, None,
            "Tried to write to index %d but array is not resizeable and size "
            "is: %d" % (index, size))
      self._tensor_array.extend(None for _ in range(index - size + 1))

    if not isinstance(value, ops.EagerTensor):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")

    if self._dtype != value.dtype:
      raise errors_impl.InvalidArgumentError(
          None, None,
          "TensorArray dtype is %s but Op is trying to write dtype %s" %
          (self._dtype.name, value.dtype.name))

    if not self._element_shape.is_compatible_with(value.shape):
      raise ValueError("Incompatible shape for value (%s), expected (%s)" %
                       (value.shape, self._element_shape))

    if self._infer_shape:
      self._element_shape = self._element_shape.merge_with(value.shape)

    self._tensor_array[index] = value

  def write(self, index, value, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    self._write(index, value)
    return self.parent()

  def _maybe_zero(self, ix):
    """Returns slot `ix`, first filling it with zeros if it is empty.

    The zeros tensor uses the current element shape and dtype and is stored
    back into the slot.
    """
    val = self._tensor_array[ix]
    if val is None:
      val = self._tensor_array[ix] = array_ops.zeros(
          shape=self._element_shape, dtype=self._dtype)
    return val

  def stack(self, name=None):
    """See TensorArray.

    Empty slots are zero-filled first.  An empty array with a fully defined
    element shape stacks to a tensor of shape `[0] + element_shape`.
    """
    if self._tensor_array:
      for ix in range(len(self._tensor_array)):
        self._maybe_zero(ix)
    if not self._tensor_array and self._element_shape.is_fully_defined():
      return ops.convert_to_tensor(
          np.ndarray([0] + self._element_shape), name=name, dtype=self._dtype)
    else:
      return ops.convert_to_tensor(
          self._tensor_array, name=name, dtype=self._dtype)

  def gather(self, indices, name=None):
    """See TensorArray.

    Empty slots are zero-filled.  NOTE(review): unlike `read`, gathered slots
    are not cleared even when `clear_after_read` is set.
    """
    del name  # not meaningful when executing eagerly.
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    return array_ops.stack([self._maybe_zero(i) for i in indices])

  def concat(self, name=None):
    """See TensorArray.

    Raises:
      errors_impl.InvalidArgumentError: the concat failed and some element is
        a scalar.  Any other concat failure is re-raised unchanged.
    """
    try:
      return array_ops.concat(
          [self._maybe_zero(ix) for ix in range(len(self._tensor_array))],
          0, name=name)
    except errors_impl.OpError:
      # Reproduce a subset of the error-handling for graph-mode TensorArrays.
      shapes = [t.shape for t in self._tensor_array]
      ndims = [s.ndims for s in shapes]
      if 0 in ndims:
        idx = ndims.index(0)
        raise errors_impl.InvalidArgumentError(
            None, None, "Concat saw a scalar shape at index %d but requires "
            "at least vectors." % idx)
      else:
        raise

  def unstack(self, value, name=None):
    """See TensorArray.

    Replaces the whole backing list with the unstacked rows of `value`, so the
    array's size becomes the number of rows (it can shrink as well as grow).
    NOTE(review): this bypasses `_write`, so no dtype or element-shape checks
    are applied here.

    Raises:
      ValueError: `value` has more rows than a non-dynamic array's size.
    """
    tensors = array_ops.unstack(value, name=name)
    if len(tensors) > len(self._tensor_array) and not self._dynamic_size:
      raise ValueError(
          "Cannot unstack %d tensors into a TensorArray of static size %d" %
          (len(tensors), len(self._tensor_array)))
    self._tensor_array = tensors
    return self.parent()

  def scatter(self, indices, value, name=None):
    """See TensorArray.

    Writes each row of `value` to the matching index via `_write`, so the
    usual dtype, shape, and bounds checks apply per row.
    """
    del name  # not meaningful when executing eagerly.
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    for index, val in zip(indices, array_ops.unstack(value)):
      self._write(index, val)  # pylint: disable=protected-access
    return self.parent()

  def split(self, value, lengths, name=None):
    """See TensorArray.

    Replaces the backing list with the pieces of `value` split along axis 0.

    Raises:
      errors_impl.InvalidArgumentError: `lengths` is not a vector, `value` is
        a scalar, `sum(lengths) != value.shape[0]`, or a non-dynamic array's
        size differs from `len(lengths)`.
    """
    # TODO(b/129870929): Fix after all callers provide proper init dtype.
    value = ops.convert_to_tensor(
        value, preferred_dtype=self._dtype, name="value")
    _check_dtypes(value, self._dtype)
    lengths = ops.convert_to_tensor(lengths)
    sum_lengths = math_ops.reduce_sum(lengths)
    if lengths.shape.ndims != 1:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected lengths to be a vector, received shape: %s" %
          lengths.shape.as_list())
    elif value.shape.ndims == 0:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected value to be at least a vector, "
          "but received shape: %s" % value.shape.as_list())
    elif sum_lengths.numpy() != value.shape.as_list()[0]:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected sum of lengths to be equal to "
          "values.shape[0], but sum of lengths is %d and "
          "value's shape is: %s " % (sum_lengths.numpy(),
                                     value.shape.as_list()))
    elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array):
      raise errors_impl.InvalidArgumentError(
          None, None, "TensorArray's size is not equal to the size of "
          "lengths (%d vs. %d), and the TensorArray is not marked as "
          "dynamically resizeable" % (len(self._tensor_array),
                                      lengths.shape[0]))
    else:
      self._tensor_array = array_ops.split(value, lengths, name=name)
      return self.parent()

  def size(self, name=None):
    """See TensorArray.

    Returns the current length of the backing list as a constant tensor.
    """
    del name  # not meaningful when executing eagerly.
    return constant_op.constant(len(self._tensor_array))

  def close(self, name=None):
    # Empties the backing list in place; returns None.
    del name  # not meaningful when executing eagerly.
    del self._tensor_array[:]


# TensorArray is designed to hide an underlying implementation object
# and as such accesses many of that object's hidden fields.
# pylint: disable=protected-access
# pylint:disable=line-too-long
@tf_export("TensorArray")
class TensorArray(object):
  """Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.
949 950 This class is meant to be used with dynamic iteration primitives such as 951 `while_loop` and `map_fn`. It supports gradient back-propagation via special 952 "flow" control flow dependencies. 953 954 Example 1: Plain reading and writing. 955 956 >>> ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False) 957 >>> ta = ta.write(0, 10) 958 >>> ta = ta.write(1, 20) 959 >>> ta = ta.write(2, 30) 960 >>> 961 >>> ta.read(0) 962 <tf.Tensor: shape=(), dtype=float32, numpy=10.0> 963 >>> ta.read(1) 964 <tf.Tensor: shape=(), dtype=float32, numpy=20.0> 965 >>> ta.read(2) 966 <tf.Tensor: shape=(), dtype=float32, numpy=30.0> 967 >>> ta.stack() 968 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 20., 30.], 969 dtype=float32)> 970 971 Example 2: Fibonacci sequence algorithm that writes in a loop then returns. 972 973 >>> @tf.function 974 ... def fibonacci(n): 975 ... ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True) 976 ... ta = ta.unstack([0., 1.]) 977 ... 978 ... for i in range(2, n): 979 ... ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2)) 980 ... 981 ... return ta.stack() 982 >>> 983 >>> fibonacci(7) 984 <tf.Tensor: shape=(7,), dtype=float32, 985 numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)> 986 987 Example 3: A simple loop interacting with a `tf.Variable`. 988 989 >>> v = tf.Variable(1) 990 >>> @tf.function 991 ... def f(x): 992 ... ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True) 993 ... for i in tf.range(x): 994 ... v.assign_add(i) 995 ... ta = ta.write(i, v) 996 ... 
return ta.stack() 997 >>> f(5) 998 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1, 2, 4, 7, 11], 999 dtype=int32)> 1000 """ 1001 1002 def __init__(self, 1003 dtype, 1004 size=None, 1005 dynamic_size=None, 1006 clear_after_read=None, 1007 tensor_array_name=None, 1008 handle=None, 1009 flow=None, 1010 infer_shape=True, 1011 element_shape=None, 1012 colocate_with_first_write_call=True, 1013 name=None): 1014 """Construct a new TensorArray or wrap an existing TensorArray handle. 1015 1016 A note about the parameter `name`: 1017 1018 The name of the `TensorArray` (even if passed in) is uniquified: each time 1019 a new `TensorArray` is created at runtime it is assigned its own name for 1020 the duration of the run. This avoids name collisions if a `TensorArray` 1021 is created within a `while_loop`. 1022 1023 Args: 1024 dtype: (required) data type of the TensorArray. 1025 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 1026 Required if handle is not provided. 1027 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 1028 can grow the TensorArray past its initial size. Default: False. 1029 clear_after_read: Boolean (optional, default: True). If True, clear 1030 TensorArray values after reading them. This disables read-many 1031 semantics, but allows early release of memory. 1032 tensor_array_name: (optional) Python string: the name of the TensorArray. 1033 This is used when creating the TensorArray handle. If this value is 1034 set, handle should be None. 1035 handle: (optional) A `Tensor` handle to an existing TensorArray. If this 1036 is set, tensor_array_name should be None. Only supported in graph mode. 1037 flow: (optional) A float `Tensor` scalar coming from an existing 1038 `TensorArray.flow`. Only supported in graph mode. 1039 infer_shape: (optional, default: True) If True, shape inference 1040 is enabled. In this case, all elements must have the same shape. 
1041 element_shape: (optional, default: None) A `TensorShape` object specifying 1042 the shape constraints of each of the elements of the TensorArray. 1043 Need not be fully defined. 1044 colocate_with_first_write_call: If `True`, the TensorArray will be 1045 colocated on the same device as the Tensor used on its first write 1046 (write operations include `write`, `unstack`, and `split`). If `False`, 1047 the TensorArray will be placed on the device determined by the 1048 device context available during its initialization. 1049 name: A name for the operation (optional). 1050 1051 Raises: 1052 ValueError: if both handle and tensor_array_name are provided. 1053 TypeError: if handle is provided but is not a Tensor. 1054 """ 1055 if (context.executing_eagerly() and 1056 (flow is None or flow.dtype != dtypes.variant)): 1057 # It is possible to create a Variant-style TensorArray even in eager mode, 1058 # and this is fine but can have performance implications in eager. 1059 # An example of when this happens is if a tf.function returns a 1060 # TensorArray in its output; its flow variant object is returned to Eager. 1061 # This can be wrapped back up in a Variant-style TensorArray. 
1062 implementation = _EagerTensorArray 1063 elif (flow is not None and flow.dtype == dtypes.variant or 1064 control_flow_util.EnableControlFlowV2(ops.get_default_graph())): 1065 implementation = _GraphTensorArrayV2 1066 else: 1067 implementation = _GraphTensorArray 1068 self._implementation = implementation( 1069 dtype, 1070 size=size, 1071 dynamic_size=dynamic_size, 1072 clear_after_read=clear_after_read, 1073 tensor_array_name=tensor_array_name, 1074 handle=handle, 1075 flow=flow, 1076 infer_shape=infer_shape, 1077 element_shape=element_shape, 1078 colocate_with_first_write_call=colocate_with_first_write_call, 1079 name=name) 1080 1081 self._implementation.parent = weakref.ref(self) 1082 1083 @property 1084 def flow(self): 1085 """The flow `Tensor` forcing ops leading to this TensorArray state.""" 1086 return self._implementation._flow 1087 1088 @property 1089 def dtype(self): 1090 """The data type of this TensorArray.""" 1091 return self._implementation._dtype 1092 1093 @property 1094 def handle(self): 1095 """The reference to the TensorArray.""" 1096 return self._implementation.handle 1097 1098 @property 1099 def element_shape(self): 1100 """The `tf.TensorShape` of elements in this TensorArray.""" 1101 return self._implementation.element_shape 1102 1103 @property 1104 def dynamic_size(self): 1105 """Python bool; if `True` the TensorArray can grow dynamically.""" 1106 return self._implementation._dynamic_size 1107 1108 @property 1109 def _infer_shape(self): 1110 # TODO(slebedev): consider making public or changing TensorArrayStructure 1111 # to access _implementation directly. Note that dynamic_size is also 1112 # only used by TensorArrayStructure. 1113 return self._implementation._infer_shape 1114 1115 def identity(self): 1116 """Returns a TensorArray with the same content and properties. 
1117 1118 Returns: 1119 A new TensorArray object with flow that ensures the control dependencies 1120 from the contexts will become control dependencies for writes, reads, etc. 1121 Use this object for all subsequent operations. 1122 """ 1123 return self._implementation.identity() 1124 1125 def grad(self, source, flow=None, name=None): 1126 return self._implementation.grad(source, flow=flow, name=name) 1127 1128 def read(self, index, name=None): 1129 """Read the value at location `index` in the TensorArray. 1130 1131 Args: 1132 index: 0-D. int32 tensor with the index to read from. 1133 name: A name for the operation (optional). 1134 1135 Returns: 1136 The tensor at index `index`. 1137 """ 1138 return self._implementation.read(index, name=name) 1139 1140 @tf_should_use.should_use_result(warn_in_eager=True) 1141 def write(self, index, value, name=None): 1142 """Write `value` into index `index` of the TensorArray. 1143 1144 Args: 1145 index: 0-D. int32 scalar with the index to write to. 1146 value: N-D. Tensor of type `dtype`. The Tensor to write to this index. 1147 name: A name for the operation (optional). 1148 1149 Returns: 1150 A new TensorArray object with flow that ensures the write occurs. 1151 Use this object for all subsequent operations. 1152 1153 Raises: 1154 ValueError: if there are more writers than specified. 1155 """ 1156 return self._implementation.write(index, value, name=name) 1157 1158 def stack(self, name=None): 1159 """Return the values in the TensorArray as a stacked `Tensor`. 1160 1161 All of the values must have been written and their shapes must all match. 1162 If input shapes have rank-`R`, then output shape will have rank-`(R+1)`. 1163 1164 Args: 1165 name: A name for the operation (optional). 1166 1167 Returns: 1168 All the tensors in the TensorArray stacked into one tensor. 
1169 """ 1170 return self._implementation.stack(name=name) 1171 1172 def gather(self, indices, name=None): 1173 """Return selected values in the TensorArray as a packed `Tensor`. 1174 1175 All of selected values must have been written and their shapes 1176 must all match. 1177 1178 Args: 1179 indices: A `1-D` `Tensor` taking values in `[0, max_value)`. If 1180 the `TensorArray` is not dynamic, `max_value=size()`. 1181 name: A name for the operation (optional). 1182 1183 Returns: 1184 The tensors in the `TensorArray` selected by `indices`, packed into one 1185 tensor. 1186 """ 1187 return self._implementation.gather(indices, name=name) 1188 1189 def concat(self, name=None): 1190 """Return the values in the TensorArray as a concatenated `Tensor`. 1191 1192 All of the values must have been written, their ranks must match, and 1193 and their shapes must all match for all dimensions except the first. 1194 1195 Args: 1196 name: A name for the operation (optional). 1197 1198 Returns: 1199 All the tensors in the TensorArray concatenated into one tensor. 1200 """ 1201 return self._implementation.concat(name=name) 1202 1203 @tf_should_use.should_use_result 1204 def unstack(self, value, name=None): 1205 """Unstack the values of a `Tensor` in the TensorArray. 1206 1207 If input value shapes have rank-`R`, then the output TensorArray will 1208 contain elements whose shapes are rank-`(R-1)`. 1209 1210 Args: 1211 value: (N+1)-D. Tensor of type `dtype`. The Tensor to unstack. 1212 name: A name for the operation (optional). 1213 1214 Returns: 1215 A new TensorArray object with flow that ensures the unstack occurs. 1216 Use this object for all subsequent operations. 1217 1218 Raises: 1219 ValueError: if the shape inference fails. 1220 """ 1221 return self._implementation.unstack(value, name=name) 1222 1223 @tf_should_use.should_use_result 1224 def scatter(self, indices, value, name=None): 1225 """Scatter the values of a `Tensor` in specific indices of a `TensorArray`. 
1226 1227 Args: 1228 indices: A `1-D` `Tensor` taking values in `[0, max_value)`. If 1229 the `TensorArray` is not dynamic, `max_value=size()`. 1230 value: (N+1)-D. Tensor of type `dtype`. The Tensor to unpack. 1231 name: A name for the operation (optional). 1232 1233 Returns: 1234 A new TensorArray object with flow that ensures the scatter occurs. 1235 Use this object for all subsequent operations. 1236 1237 Raises: 1238 ValueError: if the shape inference fails. 1239 """ 1240 return self._implementation.scatter(indices, value, name=name) 1241 1242 @tf_should_use.should_use_result 1243 def split(self, value, lengths, name=None): 1244 """Split the values of a `Tensor` into the TensorArray. 1245 1246 Args: 1247 value: (N+1)-D. Tensor of type `dtype`. The Tensor to split. 1248 lengths: 1-D. int32 vector with the lengths to use when splitting 1249 `value` along its first dimension. 1250 name: A name for the operation (optional). 1251 1252 Returns: 1253 A new TensorArray object with flow that ensures the split occurs. 1254 Use this object for all subsequent operations. 1255 1256 Raises: 1257 ValueError: if the shape inference fails. 1258 """ 1259 return self._implementation.split(value, lengths, name=name) 1260 1261 def size(self, name=None): 1262 """Return the size of the TensorArray.""" 1263 return self._implementation.size(name=name) 1264 1265 @tf_should_use.should_use_result 1266 def close(self, name=None): 1267 """Close the current TensorArray.""" 1268 return self._implementation.close(name=name) 1269 1270 1271def build_ta_with_new_flow(old_ta, flow): 1272 """Builds a TensorArray with a new `flow` tensor.""" 1273 # Sometimes we get old_ta as the implementation, sometimes it's the 1274 # TensorArray wrapper object. 
1275 impl = (old_ta._implementation if isinstance(old_ta, TensorArray) 1276 else old_ta) 1277 1278 if not context.executing_eagerly(): 1279 if (not isinstance(impl, _GraphTensorArrayV2) and 1280 control_flow_util.EnableControlFlowV2(ops.get_default_graph())): 1281 raise NotImplementedError("Attempting to build a graph-mode TF2-style " 1282 "TensorArray from either an eager-mode " 1283 "TensorArray or a TF1-style TensorArray. " 1284 "This is not currently supported. You may be " 1285 "attempting to capture a TensorArray " 1286 "inside a tf.function or tf.data map function. " 1287 "Instead, construct a new TensorArray inside " 1288 "the function.") 1289 new_ta = TensorArray( 1290 dtype=impl.dtype, 1291 handle=impl.handle, 1292 flow=flow, 1293 infer_shape=impl._infer_shape, 1294 colocate_with_first_write_call=impl._colocate_with_first_write_call) 1295 new_impl = new_ta._implementation 1296 new_impl._dynamic_size = impl._dynamic_size 1297 new_impl._size = impl._size 1298 new_impl._colocate_with = impl._colocate_with 1299 new_impl._element_shape = impl._element_shape # Share _element_shape. 1300 return new_ta 1301 1302# pylint: enable=protected-access 1303 1304 1305def _check_dtypes(value, dtype): 1306 if value.dtype != dtype: 1307 logging.error( 1308 "Error: Input value {} has dtype {}, but expected dtype {}. " 1309 "This leads to undefined behavior and will be an error " 1310 "in future versions of TensorFlow. 
Traceback:\n{}".format( 1311 value, str(value.dtype), str(dtype), 1312 "".join(traceback.format_stack()))) 1313 1314 1315@tf_export("TensorArraySpec") 1316@type_spec.register("tf.TensorArraySpec") 1317class TensorArraySpec(type_spec.TypeSpec): 1318 """Type specification for a `tf.TensorArray`.""" 1319 1320 __slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"] 1321 1322 value_type = property(lambda self: TensorArray) 1323 1324 def __init__(self, element_shape=None, dtype=dtypes.float32, 1325 dynamic_size=False, infer_shape=True): 1326 """Constructs a type specification for a `tf.TensorArray`. 1327 1328 Args: 1329 element_shape: The shape of each element in the `TensorArray`. 1330 dtype: Data type of the `TensorArray`. 1331 dynamic_size: Whether the `TensorArray` can grow past its initial size. 1332 infer_shape: Whether shape inference is enabled. 1333 """ 1334 self._element_shape = tensor_shape.as_shape(element_shape) 1335 self._dtype = dtypes.as_dtype(dtype) 1336 self._dynamic_size = dynamic_size 1337 self._infer_shape = infer_shape 1338 1339 def is_compatible_with(self, other): 1340 # pylint: disable=protected-access 1341 if not isinstance(other, type_spec.TypeSpec): 1342 other = type_spec.type_spec_from_value(other) 1343 1344 # Note: we intentionally exclude infer_shape in this check. 
1345 return (isinstance(other, TensorArraySpec) and 1346 self._dtype.is_compatible_with(other._dtype) and 1347 self._element_shape.is_compatible_with(other._element_shape) and 1348 self._dynamic_size == other._dynamic_size) 1349 1350 def most_specific_compatible_type(self, other): 1351 # pylint: disable=protected-access 1352 if not self.is_compatible_with(other): 1353 raise ValueError("Types are not compatible") 1354 infer_shape = self._infer_shape and other._infer_shape 1355 return TensorArraySpec( 1356 self._element_shape.most_specific_compatible_shape( 1357 other._element_shape), 1358 self._dtype, self._dynamic_size, infer_shape) 1359 1360 def _serialize(self): 1361 return (self._element_shape, self._dtype, self._dynamic_size, 1362 self._infer_shape) 1363 1364 @property 1365 def _component_specs(self): 1366 return [tensor_spec.TensorSpec([], dtypes.variant)] 1367 1368 def _to_components(self, value): 1369 if not isinstance(value, TensorArray): 1370 raise TypeError("value must be a TensorArray, but saw: {}" 1371 .format(type(value))) 1372 if value.flow is not None and value.flow.dtype == dtypes.variant: 1373 return [value.flow] 1374 else: 1375 # Convert to a TF2-style TensorArray. 1376 # TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or 1377 # "implementation / as_variant" arg to TensorArray constructor. 1378 with ops.name_scope("convert_tensor_array"): 1379 flow = list_ops.tensor_list_from_tensor( 1380 tensor=value.stack(), element_shape=value.element_shape) 1381 return [flow] 1382 1383 def _from_components(self, tensor_list): 1384 # This will return a TF2 Graph-style TensorArray because tensor_list[0] is 1385 # a variant object. size == -1 implies unknown size. 
1386 ret = TensorArray( 1387 dtype=self._dtype, 1388 flow=tensor_list[0], 1389 dynamic_size=self._dynamic_size, 1390 infer_shape=self._infer_shape) 1391 ret._implementation._element_shape = [self._element_shape] # pylint: disable=protected-access 1392 return ret 1393 1394 @staticmethod 1395 def from_value(value): 1396 if not isinstance(value, TensorArray): 1397 raise TypeError("Expected value to be a TensorArray, but saw: {}". 1398 format(type(value))) 1399 1400 return TensorArraySpec( 1401 dtype=value.dtype, 1402 element_shape=value.element_shape, 1403 dynamic_size=value.dynamic_size, 1404 infer_shape=value._infer_shape) # pylint: disable=protected-access 1405 1406 def _to_legacy_output_types(self): 1407 return self._dtype 1408 1409 def _to_legacy_output_shapes(self): 1410 # Sneak the dynamic_size and infer_shape values into the legacy shape. 1411 return (tensor_shape.TensorShape([self._dynamic_size, self._infer_shape 1412 ]).concatenate(self._element_shape)) 1413 1414 def _to_legacy_output_classes(self): 1415 return TensorArray 1416 1417 1418# Register the TypeSpec for TensorArray. If TensorArray is updated to be a 1419# CompositeTensor, then this registration can be deleted. 1420type_spec.register_type_spec_from_value_converter( 1421 TensorArray, TensorArraySpec.from_value, allow_subclass=True) 1422