1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TensorArray: a dynamically sized array of Tensors.""" 16# Mixture of pep8 and non-pep8 names, so disable pylint bad-name 17# pylint: disable=g-bad-name 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import contextlib 23import weakref 24 25from tensorflow.python.eager import context 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors_impl 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.framework import tensor_util 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import control_flow_util 34from tensorflow.python.ops import gen_control_flow_ops 35from tensorflow.python.ops import gen_data_flow_ops 36from tensorflow.python.ops import list_ops 37from tensorflow.python.ops import math_ops 38from tensorflow.python.util import tf_should_use 39from tensorflow.python.util.tf_export import tf_export 40 41 42# _GraphTensorArray accesses many of the hidden generated ops, but is in 43# fact built to wrap these methods. 44# pylint: disable=protected-access 45class _GraphTensorArray(object): 46 """Graph-mode implementation of TensorArray. 47 """ 48 49 def __init__(self, 50 dtype, 51 size=None, 52 dynamic_size=None, 53 clear_after_read=None, 54 tensor_array_name=None, 55 handle=None, 56 flow=None, 57 infer_shape=True, 58 element_shape=None, 59 colocate_with_first_write_call=True, 60 name=None): 61 """Constructs a graph mode TensorArray. 62 63 Args: 64 dtype: (required) data type of the TensorArray. 65 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 66 Required if handle is not provided. 67 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 68 can grow the TensorArray past its initial size. Default: False. 69 clear_after_read: Boolean (optional, default: True). If True, clear 70 TensorArray values after reading them. This disables read-many 71 semantics, but allows early release of memory. 72 tensor_array_name: (optional) Python string: the name of the TensorArray. 73 This is used when creating the TensorArray handle. If this value is 74 set, handle should be None. 75 handle: (optional) A `Tensor` handle to an existing TensorArray. If this 76 is set, tensor_array_name should be None. Only supported in graph mode. 77 flow: (optional) A float `Tensor` scalar coming from an existing 78 `TensorArray.flow`. Only supported in graph mode. 79 infer_shape: (optional, default: True) If True, shape inference 80 is enabled. In this case, all elements must have the same shape. 81 element_shape: (optional, default: None) A `TensorShape` object specifying 82 the shape constraints of each of the elements of the TensorArray. 83 Need not be fully defined. 84 colocate_with_first_write_call: If `True`, the TensorArray will be 85 colocated on the same device as the Tensor used on its first write 86 (write operations include `write`, `unstack`, and `split`). If `False`, 87 the TensorArray will be placed on the device determined by the 88 device context available during its initialization. 89 name: A name for the operation (optional). 90 91 Raises: 92 ValueError: if both handle and tensor_array_name are provided. 93 TypeError: if handle is provided but is not a Tensor. 94 """ 95 if handle is not None and tensor_array_name: 96 raise ValueError( 97 "Cannot construct with both handle and tensor_array_name") 98 if handle is not None and not isinstance(handle, ops.Tensor): 99 raise TypeError("Handle must be a Tensor") 100 if handle is None and size is None: 101 raise ValueError("Size must be provided if handle is not provided") 102 if handle is not None and size is not None: 103 raise ValueError("Cannot provide both a handle and size " 104 "at the same time") 105 if handle is not None and element_shape is not None: 106 raise ValueError("Cannot provide both a handle and element_shape " 107 "at the same time") 108 if handle is not None and dynamic_size is not None: 109 raise ValueError("Cannot provide both a handle and dynamic_size " 110 "at the same time") 111 if handle is not None and clear_after_read is not None: 112 raise ValueError("Cannot provide both a handle and clear_after_read " 113 "at the same time") 114 115 if clear_after_read is None: 116 clear_after_read = True 117 self._dynamic_size = None 118 dynamic_size = dynamic_size or False 119 120 self._dtype = dtype 121 122 # Used to keep track of what tensors the TensorArray should be 123 # colocated with. We choose to colocate the TensorArray with the 124 # first tensor written to it. 125 self._colocate_with_first_write_call = colocate_with_first_write_call 126 if colocate_with_first_write_call: 127 self._colocate_with = [] 128 else: 129 self._colocate_with = None 130 131 # Record the current static shape for the array elements. The element 132 # shape is defined either by `element_shape` or the shape of the tensor 133 # of the first write. If `infer_shape` is true, all writes checks for 134 # shape equality. 135 if element_shape is None: 136 self._infer_shape = infer_shape 137 self._element_shape = [] 138 else: 139 self._infer_shape = True 140 self._element_shape = [tensor_shape.TensorShape(element_shape)] 141 with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope: 142 if handle is not None: 143 self._handle = handle 144 if flow is None: 145 raise ValueError("flow must not be None if handle is not None.") 146 self._flow = flow 147 else: 148 # Construct the TensorArray with an empty device. The first 149 # write into the TensorArray from a Tensor with a set device 150 # will retroactively set the device value of this op. 151 def create(): 152 """Create the TensorArray op.""" 153 return gen_data_flow_ops.tensor_array_v3( 154 dtype=dtype, 155 size=size, 156 element_shape=element_shape, 157 identical_element_shapes=infer_shape, 158 dynamic_size=dynamic_size, 159 clear_after_read=clear_after_read, 160 tensor_array_name=tensor_array_name, 161 name=scope) 162 if colocate_with_first_write_call: 163 with ops.device(None), ops.colocate_with(None, ignore_existing=True): 164 self._handle, self._flow = create() 165 else: 166 self._handle, self._flow = create() 167 168 @property 169 def flow(self): 170 return self._flow 171 172 @property 173 def dtype(self): 174 return self._dtype 175 176 @property 177 def handle(self): 178 return self._handle 179 180 def _merge_element_shape(self, shape): 181 """Changes the element shape of the array given a shape to merge with. 182 183 Args: 184 shape: A `TensorShape` object to merge with. 185 186 Raises: 187 ValueError: if the provided shape is incompatible with the current 188 element shape of the `TensorArray`. 189 """ 190 191 if self._element_shape: 192 if not shape.is_compatible_with(self._element_shape[0]): 193 raise ValueError( 194 "Inconsistent shapes: saw %s but expected %s " 195 "(and infer_shape=True)" % (shape, self._element_shape[0])) 196 self._element_shape[0] = self._element_shape[0].merge_with(shape) 197 else: 198 self._element_shape.append(shape) 199 200 @contextlib.contextmanager 201 def _maybe_colocate_with(self, value): 202 """Colocate operations with an internal colocation group or `value`. 203 204 Args: 205 value: `Tensor`, the tensor to try to colocate with. 206 207 Yields: 208 Does not yield anything, but the new context is a colocation context. 209 210 If no internal colocation group is set, colocate with `value` and set 211 the internal colocation group to be value. 212 """ 213 if not self._colocate_with_first_write_call: 214 yield 215 else: 216 if not self._colocate_with: 217 self._colocate_with.append(value) 218 with ops.colocate_with(self._colocate_with[0]): 219 yield 220 221 def identity(self): 222 """See TensorArray.""" 223 flow = array_ops.identity(self._flow) 224 ta = TensorArray( 225 dtype=self._dtype, 226 handle=self._handle, 227 flow=flow, 228 infer_shape=self._infer_shape, 229 colocate_with_first_write_call=self._colocate_with_first_write_call) 230 ta._element_shape = self._element_shape 231 ta._colocate_with = self._colocate_with 232 return ta 233 234 def grad(self, source, flow=None, name=None): 235 """See TensorArray.""" 236 # tensor_array_grad requires a flow input when forward 237 # TensorArrays are dynamically sized. This forces the creation 238 # of the grad TensorArray only once the final forward array's size 239 # is fixed. 240 if flow is None: 241 flow = self.flow 242 with ops.name_scope(name, "TensorArrayGrad", [self._handle]): 243 with ops.colocate_with(self._handle): 244 g_handle, unused_flow = gen_data_flow_ops.tensor_array_grad_v3( 245 handle=self._handle, source=source, flow_in=flow, name=name) 246 with ops.control_dependencies([g_handle]): 247 flow = array_ops.identity(flow, name="gradient_flow") 248 g = TensorArray( 249 dtype=self._dtype, 250 handle=g_handle, 251 flow=flow, 252 infer_shape=self._infer_shape, 253 colocate_with_first_write_call=False) 254 g._element_shape = self._element_shape 255 return g 256 257 def read(self, index, name=None): 258 """See TensorArray.""" 259 value = gen_data_flow_ops.tensor_array_read_v3( 260 handle=self._handle, 261 index=index, 262 flow_in=self._flow, 263 dtype=self._dtype, 264 name=name) 265 if self._element_shape: 266 value.set_shape(self._element_shape[0].dims) 267 return value 268 269 @tf_should_use.should_use_result 270 def write(self, index, value, name=None): 271 """See TensorArray.""" 272 with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]): 273 value = ops.convert_to_tensor(value, name="value") 274 if self._infer_shape: 275 self._merge_element_shape(value.shape) 276 with self._maybe_colocate_with(value): 277 flow_out = gen_data_flow_ops.tensor_array_write_v3( 278 handle=self._handle, 279 index=index, 280 value=value, 281 flow_in=self._flow, 282 name=name) 283 ta = TensorArray( 284 dtype=self._dtype, 285 handle=self._handle, 286 flow=flow_out, 287 colocate_with_first_write_call=self._colocate_with_first_write_call) 288 ta._infer_shape = self._infer_shape 289 ta._element_shape = self._element_shape 290 ta._colocate_with = self._colocate_with 291 return ta 292 293 def stack(self, name=None): 294 """See TensorArray.""" 295 with ops.colocate_with(self._handle): 296 with ops.name_scope(name, "TensorArrayStack", [self._handle]): 297 return self.gather(math_ops.range(0, self.size()), name=name) 298 299 def gather(self, indices, name=None): 300 """See TensorArray.""" 301 if self._element_shape: 302 element_shape = self._element_shape[0] 303 else: 304 element_shape = tensor_shape.TensorShape(None) 305 value = gen_data_flow_ops.tensor_array_gather_v3( 306 handle=self._handle, 307 indices=indices, 308 flow_in=self._flow, 309 dtype=self._dtype, 310 name=name, 311 element_shape=element_shape) 312 if self._element_shape and self._element_shape[0].dims is not None: 313 value.set_shape([None] + self._element_shape[0].dims) 314 return value 315 316 def concat(self, name=None): 317 """See TensorArray.""" 318 if self._element_shape and self._element_shape[0].dims is not None: 319 element_shape_except0 = ( 320 tensor_shape.TensorShape(self._element_shape[0].dims[1:])) 321 else: 322 element_shape_except0 = tensor_shape.TensorShape(None) 323 value, _ = gen_data_flow_ops.tensor_array_concat_v3( 324 handle=self._handle, 325 flow_in=self._flow, 326 dtype=self._dtype, 327 name=name, 328 element_shape_except0=element_shape_except0) 329 if self._element_shape and self._element_shape[0].dims is not None: 330 value.set_shape([None] + self._element_shape[0].dims[1:]) 331 return value 332 333 @tf_should_use.should_use_result 334 def unstack(self, value, name=None): 335 """See TensorArray.""" 336 with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]): 337 num_elements = array_ops.shape(value)[0] 338 return self.scatter( 339 indices=math_ops.range(0, num_elements), value=value, name=name) 340 341 @tf_should_use.should_use_result 342 def scatter(self, indices, value, name=None): 343 """See TensorArray.""" 344 with ops.name_scope(name, "TensorArrayScatter", 345 [self._handle, value, indices]): 346 value = ops.convert_to_tensor(value, name="value") 347 if self._infer_shape and not context.executing_eagerly(): 348 self._merge_element_shape(value.shape[1:]) 349 with self._maybe_colocate_with(value): 350 flow_out = gen_data_flow_ops.tensor_array_scatter_v3( 351 handle=self._handle, 352 indices=indices, 353 value=value, 354 flow_in=self._flow, 355 name=name) 356 ta = TensorArray( 357 dtype=self._dtype, 358 handle=self._handle, 359 flow=flow_out, 360 colocate_with_first_write_call=self._colocate_with_first_write_call) 361 ta._infer_shape = self._infer_shape 362 ta._element_shape = self._element_shape 363 ta._colocate_with = self._colocate_with 364 return ta 365 366 @tf_should_use.should_use_result 367 def split(self, value, lengths, name=None): 368 """See TensorArray.""" 369 with ops.name_scope(name, "TensorArraySplit", 370 [self._handle, value, lengths]): 371 value = ops.convert_to_tensor(value, name="value") 372 with self._maybe_colocate_with(value): 373 lengths_64 = math_ops.cast(lengths, dtypes.int64) 374 if self._infer_shape and not context.executing_eagerly(): 375 clengths = tensor_util.constant_value(lengths_64) 376 if value.shape.dims is not None: 377 if clengths is not None and clengths.max() == clengths.min(): 378 self._merge_element_shape( 379 tensor_shape.TensorShape([clengths[0]]).concatenate( 380 value.shape[1:])) 381 flow_out = gen_data_flow_ops.tensor_array_split_v3( 382 handle=self._handle, 383 value=value, 384 lengths=lengths_64, 385 flow_in=self._flow, 386 name=name) 387 ta = TensorArray( 388 dtype=self._dtype, 389 handle=self._handle, 390 flow=flow_out, 391 colocate_with_first_write_call=self._colocate_with_first_write_call) 392 ta._infer_shape = self._infer_shape 393 ta._element_shape = self._element_shape 394 ta._colocate_with = self._colocate_with 395 return ta 396 397 def size(self, name=None): 398 """See TensorArray.""" 399 return gen_data_flow_ops.tensor_array_size_v3( 400 handle=self._handle, flow_in=self.flow, name=name) 401 402 @tf_should_use.should_use_result 403 def close(self, name=None): 404 """See TensorArray.""" 405 return gen_data_flow_ops.tensor_array_close_v3( 406 handle=self._handle, name=name) 407 408 409class _GraphTensorArrayV2(object): 410 """Graph-mode implementation of TensorArray backed by TensorLists. 411 412 The backing tensor of this TensorArray is a TensorList variant tensor which is 413 stored in the `flow`. The `handle` is always none here. The reason we use the 414 `flow` field and not the `handle` field is to ensure backwards compatibility 415 with legacy control flow. 416 """ 417 418 def __init__(self, 419 dtype, 420 size=None, 421 dynamic_size=None, 422 clear_after_read=None, 423 tensor_array_name=None, 424 handle=None, 425 flow=None, 426 infer_shape=True, 427 element_shape=None, 428 colocate_with_first_write_call=True, 429 name=None): 430 """Constructs a graph mode TensorArray. 431 432 Args: 433 dtype: (required) data type of the TensorArray. 434 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 435 Required if flow is not provided. 436 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 437 can grow the TensorArray past its initial size. Default: False. 438 clear_after_read: (optional) unused. Not supported in TensorLists. 439 tensor_array_name: (optional) unused. 440 handle: (optional) Must always be None. 441 flow: (optional) A variant `Tensor` scalar for a TensorList. 442 infer_shape: (optional, default: True) If True, shape inference is 443 enabled. In this case, all elements must have the same shape. 444 element_shape: (optional, default: None) A `TensorShape` object specifying 445 the shape constraints of each of the elements of the TensorArray. Need 446 not be fully defined. 447 colocate_with_first_write_call: (optional). unused. 448 name: (optional) A name for the operation. 449 450 Raises: 451 ValueError: if both handle and tensor_array_name are provided. 452 TypeError: if handle is provided but is not a Tensor. 453 """ 454 assert handle is None 455 del handle 456 del clear_after_read 457 del tensor_array_name 458 del colocate_with_first_write_call 459 460 self._dynamic_size = dynamic_size 461 462 if (flow is not None and 463 (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)): 464 raise TypeError("flow must be a variant tensor") 465 if flow is None and size is None: 466 raise ValueError("Size must be provided if flow is not provided") 467 if flow is not None and size is not None: 468 raise ValueError("Cannot provide both a flow and size " 469 "at the same time") 470 if flow is not None and element_shape is not None: 471 raise ValueError("Cannot provide both a flow and element_shape " 472 "at the same time") 473 474 self._dtype = dtype 475 476 # Record the current static shape for the array elements. The element 477 # shape is defined either by `element_shape` or the shape of the tensor 478 # of the first write. If `infer_shape` is true, all writes checks for 479 # shape equality. 480 if element_shape is None: 481 self._infer_shape = infer_shape 482 self._element_shape = [] 483 else: 484 self._infer_shape = True 485 self._element_shape = [tensor_shape.TensorShape(element_shape)] 486 with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope: 487 if flow is None: 488 self._flow = list_ops.tensor_list_reserve( 489 element_shape=element_shape, 490 num_elements=size, 491 element_dtype=dtype, 492 name=scope) 493 else: 494 self._flow = flow 495 496 # For backwards compatibility. 497 self._colocate_with_first_write_call = None 498 self._colocate_with = None 499 500 @property 501 def flow(self): 502 return self._flow 503 504 @property 505 def dtype(self): 506 return self._dtype 507 508 @property 509 def handle(self): 510 # We intentionally do not raise an error so that legacy while_loop does not 511 # complain. 512 return None 513 514 def _merge_element_shape(self, shape): 515 """Changes the element shape of the array given a shape to merge with. 516 517 Args: 518 shape: A `TensorShape` object to merge with. 519 520 Raises: 521 ValueError: if the provided shape is incompatible with the current 522 element shape of the `TensorArray`. 523 """ 524 525 if self._element_shape: 526 if not shape.is_compatible_with(self._element_shape[0]): 527 raise ValueError( 528 "Inconsistent shapes: saw %s but expected %s " 529 "(and infer_shape=True)" % (shape, self._element_shape[0])) 530 self._element_shape[0] = self._element_shape[0].merge_with(shape) 531 else: 532 self._element_shape.append(shape) 533 534 def identity(self): 535 """See TensorArray.""" 536 flow = array_ops.identity(self._flow) 537 return build_ta_with_new_flow(self, flow) 538 539 def grad(self, source, flow=None, name=None): 540 """Not supported.""" 541 raise NotImplementedError() 542 543 def read(self, index, name=None): 544 """See TensorArray.""" 545 with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]): 546 if self._element_shape: 547 element_shape = self._element_shape[0] 548 else: 549 element_shape = tensor_shape.TensorShape(None) 550 value = list_ops.tensor_list_get_item( 551 input_handle=self._flow, 552 index=index, 553 element_dtype=self._dtype, 554 element_shape=element_shape, 555 name=name) 556 if self._element_shape: 557 value.set_shape(self._element_shape[0].dims) 558 return value 559 560 @tf_should_use.should_use_result 561 def write(self, index, value, name=None): 562 """See TensorArray.""" 563 with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]): 564 value = ops.convert_to_tensor(value, name="value") 565 if self._infer_shape: 566 self._merge_element_shape(value.shape) 567 flow_out = list_ops.tensor_list_set_item( 568 input_handle=self._flow, 569 index=index, 570 item=value, 571 resize_if_index_out_of_bounds=self._dynamic_size, 572 name=name) 573 return build_ta_with_new_flow(self, flow_out) 574 575 def stack(self, name=None): 576 """See TensorArray.""" 577 with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]): 578 if self._element_shape: 579 element_shape = self._element_shape[0] 580 else: 581 element_shape = tensor_shape.TensorShape(None) 582 value = list_ops.tensor_list_stack( 583 input_handle=self._flow, 584 element_dtype=self._dtype, 585 element_shape=element_shape) 586 if self._element_shape and self._element_shape[0].dims is not None: 587 value.set_shape([None] + self._element_shape[0].dims) 588 return value 589 590 def gather(self, indices, name=None): 591 """See TensorArray.""" 592 if self._element_shape: 593 element_shape = self._element_shape[0] 594 else: 595 element_shape = tensor_shape.TensorShape(None) 596 value = list_ops.tensor_list_gather( 597 input_handle=self._flow, 598 indices=indices, 599 element_dtype=self._dtype, 600 element_shape=element_shape, 601 name=name) 602 if self._element_shape and self._element_shape[0].dims is not None: 603 value.set_shape([None] + self._element_shape[0].dims) 604 return value 605 606 def concat(self, name=None): 607 """See TensorArray.""" 608 if self._element_shape and self._element_shape[0].dims is not None: 609 element_shape = [None] + self._element_shape[0].dims[1:] 610 else: 611 element_shape = None 612 613 value = list_ops.tensor_list_concat( 614 input_handle=self._flow, 615 element_dtype=self._dtype, 616 element_shape=element_shape, 617 name=name) 618 return value 619 620 @tf_should_use.should_use_result 621 def unstack(self, value, name=None): 622 """See TensorArray.""" 623 with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]): 624 value = ops.convert_to_tensor(value, name="value") 625 if self._infer_shape and not context.executing_eagerly(): 626 self._merge_element_shape(value.shape[1:]) 627 flow_out = list_ops.tensor_list_from_tensor( 628 tensor=value, element_shape=value.shape[1:]) 629 return build_ta_with_new_flow(self, flow_out) 630 631 @tf_should_use.should_use_result 632 def scatter(self, indices, value, name=None): 633 """See TensorArray.""" 634 with ops.name_scope(name, "TensorArrayScatter", 635 [self._flow, value, indices]): 636 value = ops.convert_to_tensor(value, name="value") 637 if self._infer_shape and not context.executing_eagerly(): 638 self._merge_element_shape(value.shape[1:]) 639 element_shape = self._element_shape[0] if self._element_shape else None 640 flow_out = list_ops.tensor_list_scatter( 641 tensor=value, indices=indices, input_handle=self._flow) 642 return build_ta_with_new_flow(self, flow_out) 643 644 @tf_should_use.should_use_result 645 def split(self, value, lengths, name=None): 646 """See TensorArray.""" 647 with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]): 648 value = ops.convert_to_tensor(value, name="value") 649 lengths_64 = math_ops.cast(lengths, dtypes.int64) 650 if self._infer_shape and not context.executing_eagerly(): 651 clengths = tensor_util.constant_value(lengths_64) 652 if value.shape.dims is not None: 653 if clengths is not None and clengths.max() == clengths.min(): 654 self._merge_element_shape( 655 tensor_shape.TensorShape([clengths[0]]).concatenate( 656 value.shape[1:])) 657 flow_out = list_ops.tensor_list_split( 658 tensor=value, 659 lengths=lengths_64, 660 element_shape=self._element_shape[0] if self._element_shape else None, 661 name=name) 662 return build_ta_with_new_flow(self, flow_out) 663 664 def size(self, name=None): 665 """See TensorArray.""" 666 return list_ops.tensor_list_length(input_handle=self._flow, name=name) 667 668 @tf_should_use.should_use_result 669 def close(self, name=None): 670 """See TensorArray.""" 671 return gen_control_flow_ops.no_op(name=name) 672 673# pylint: enable=protected-access 674 675 676class _EagerTensorArray(object): 677 """Eager-compatible implementation of TensorArray. 678 """ 679 680 def __init__(self, 681 dtype, 682 size=None, 683 dynamic_size=None, 684 clear_after_read=None, 685 tensor_array_name=None, 686 handle=None, 687 flow=None, 688 infer_shape=True, 689 element_shape=None, 690 colocate_with_first_write_call=True, 691 name=None): 692 """Constructs a TensorArray compatible with eager execution. 693 694 Args: 695 dtype: (required) data type of the TensorArray. 696 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 697 Required if handle is not provided. 698 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 699 can grow the TensorArray past its initial size. Default: False. 700 clear_after_read: Boolean (optional, default: True). If True, clear 701 TensorArray values after reading them. This disables read-many 702 semantics, but allows early release of memory. 703 tensor_array_name: unused. 704 handle: unsupported. 705 flow: unsupported. 706 infer_shape: used for error checking, same semantics as TensorArray. 707 element_shape: used for error checking, same semantics as TensorArray. 708 colocate_with_first_write_call: unsupported. 709 name: unsupported. 710 711 Raises: 712 ValueError: handle or flow are supplied, or if size is not supplied. 713 """ 714 715 del (flow, tensor_array_name, name) # Unused. 716 717 if handle is not None: 718 raise ValueError("TensorArray handles are not supported when eager " 719 "execution is enabled.") 720 if size is None: 721 raise ValueError("Size must be declared for TensorArrays when eager " 722 "execution is enabled.") 723 724 # These attributes are not meaningful when eager is enabled, but some 725 # library functions (e.g., those in control_flow_ops.py) access them to 726 # create new tensor arrays; as such, we define them for the sake of 727 # compatibility. 728 self._handle = None 729 # we assign a dummy value to _flow in case other code assumes it to be 730 # a Tensor 731 self._flow = constant_op.constant(0, dtype=dtypes.int32) 732 self._infer_shape = infer_shape 733 self._element_shape = element_shape 734 self._colocate_with_first_write_call = colocate_with_first_write_call 735 736 self._dtype = dtype 737 self._dynamic_size = dynamic_size or False 738 self._clear_after_read = ( 739 True if clear_after_read is None else clear_after_read) 740 self._previously_read_indices = [] 741 742 if isinstance(size, ops.EagerTensor): 743 size = size.numpy() 744 self._tensor_array = [None for _ in range(size)] 745 746 @property 747 def flow(self): 748 """For compatibility; flows are not meaningful when eager is enabled.""" 749 return self._flow 750 751 @property 752 def dtype(self): 753 return self._dtype 754 755 @property 756 def handle(self): 757 """For compatibility; handles are not meaningful when eager is enabled.""" 758 return self._handle 759 760 def identity(self): 761 """See TensorArray.""" 762 return self.parent() 763 764 def grad(self, source, flow=None, name=None): 765 raise NotImplementedError( 766 "TensorArray.grad is not supported when executing eagerly; eager's " 767 "gradient implementation does not use/need this function to compute " 768 "gradients of operations that use TensorArrays.") 769 770 def read(self, index, name=None): 771 """See TensorArray.""" 772 del name # not meaningful when executing eagerly. 773 774 if isinstance(index, ops.EagerTensor): 775 index = index.numpy() 776 777 if index < 0: 778 raise errors_impl.OutOfRangeError( 779 None, None, 780 "Reading from negative indices (index %d) is not allowed." % index) 781 782 if index >= len(self._tensor_array): 783 raise errors_impl.OutOfRangeError( 784 None, None, "Tried to read from index %d but array size is: %d" % 785 (index, len(self._tensor_array))) 786 787 tensor = self._tensor_array[index] 788 if tensor is None: 789 if index in self._previously_read_indices: 790 raise errors_impl.InvalidArgumentError( 791 None, None, 792 "Could not read index %d twice because it was cleared after " 793 "a previous read (perhaps try setting clear_after_read = false?)" % 794 index) 795 else: 796 tensor = self._maybe_zero(index) 797 798 if self._clear_after_read: 799 self._tensor_array[index] = None 800 self._previously_read_indices.append(index) 801 return tensor 802 803 def _write(self, index, value): 804 """Writes `value` into index named by `index`. 805 806 Args: 807 index: 0-D. int32 scalar with the index to write to. 808 value: N-D. Tensor of type `dtype`. The `Tensor` to write to `index`. 809 810 Raises: 811 errors_impl.InvalidArgumentError: `value` dtype does not match dtype. 812 errors_impl.OutOfRangeError: `index` is out of bounds. 813 ValueError: shape of `value` is not consistent with inferred shape. 814 """ 815 816 if isinstance(index, ops.EagerTensor): 817 index = index.numpy() 818 819 if index < 0: 820 raise errors_impl.OutOfRangeError( 821 None, None, 822 "Writing to negative indices (index %d) is not allowed." % index) 823 824 size = len(self._tensor_array) 825 if index >= size: 826 if not self._dynamic_size: 827 raise errors_impl.OutOfRangeError( 828 None, None, 829 "Tried to write to index %d but array is not resizeable and size " 830 "is: %d" % (index, size)) 831 self._tensor_array.extend([None for _ in range(index - size + 1)]) 832 833 if not isinstance(value, ops.EagerTensor): 834 value = ops.convert_to_tensor(value) 835 836 if self._infer_shape: 837 if self._element_shape is None: 838 self._element_shape = value.shape 839 elif not self._element_shape.is_compatible_with(value.shape): 840 raise ValueError("Incompatible shape for value (%s), expected (%s)" % 841 (value.shape.as_list(), self._element_shape.as_list())) 842 843 if self._dtype != value.dtype: 844 raise errors_impl.InvalidArgumentError( 845 None, None, 846 "TensorArray dtype is %s but Op is trying to write dtype %s" % 847 (self._dtype.name, value.dtype.name)) 848 self._tensor_array[index] = value 849 850 def write(self, index, value, name=None): 851 """See TensorArray.""" 852 del name # not meaningful when executing eagerly. 853 self._write(index, value) 854 return self.parent() 855 856 def _maybe_zero(self, ix): 857 val = self._tensor_array[ix] 858 if val is None: 859 val = self._tensor_array[ix] = array_ops.zeros( 860 shape=self._element_shape, dtype=self._dtype) 861 return val 862 863 def stack(self, name=None): 864 """See TensorArray.""" 865 if self._tensor_array: 866 for ix in range(len(self._tensor_array)): 867 self._maybe_zero(ix) 868 return ops.convert_to_tensor( 869 self._tensor_array, name=name, dtype=self._dtype) 870 871 def gather(self, indices, name=None): 872 """See TensorArray.""" 873 del name # not meaningful when executing eagerly. 874 if isinstance(indices, ops.EagerTensor): 875 indices = indices.numpy() 876 return array_ops.stack([self._maybe_zero(i) for i in indices]) 877 878 def concat(self, name=None): 879 """See TensorArray.""" 880 try: 881 return array_ops.concat( 882 [self._maybe_zero(ix) for ix in range(len(self._tensor_array))], 883 0, name=name) 884 except errors_impl.OpError: 885 # Reproduce a subset of the error-handling for graph-mode TensorArrays. 886 shapes = [t.shape for t in self._tensor_array] 887 ndims = [s.ndims for s in shapes] 888 if 0 in ndims: 889 idx = ndims.index(0) 890 raise errors_impl.InvalidArgumentError( 891 None, None, "Concat saw a scalar shape at index %d but requires " 892 "at least vectors." % idx) 893 else: 894 raise 895 896 def unstack(self, value, name=None): 897 """See TensorArray.""" 898 tensors = array_ops.unstack(value, name=name) 899 if len(tensors) > len(self._tensor_array) and not self._dynamic_size: 900 raise ValueError( 901 "Cannot unstack %d tensors into a TensorArray of static size %d" % 902 (len(tensors), len(self._tensor_array))) 903 self._tensor_array = tensors 904 return self.parent() 905 906 def scatter(self, indices, value, name=None): 907 """See TensorArray.""" 908 del name # not meaningful when executing eagerly. 909 if isinstance(indices, ops.EagerTensor): 910 indices = indices.numpy() 911 for index, val in zip(indices, array_ops.unstack(value)): 912 self._write(index, val) # pylint: disable=protected-access 913 return self.parent() 914 915 def split(self, value, lengths, name=None): 916 """See TensorArray.""" 917 # error checking to match graph-mode errors 918 value = ops.convert_to_tensor(value) 919 lengths = ops.convert_to_tensor(lengths) 920 sum_lengths = math_ops.reduce_sum(lengths) 921 if lengths.shape.ndims != 1: 922 raise errors_impl.InvalidArgumentError( 923 None, None, "Expected lengths to be a vector, received shape: %s" % 924 lengths.shape.as_list()) 925 elif value.shape.ndims == 0: 926 raise errors_impl.InvalidArgumentError( 927 None, None, "Expected value to be at least a vector, " 928 "but received shape: %s" % value.shape.as_list()) 929 elif sum_lengths.numpy() != value.shape.as_list()[0]: 930 raise errors_impl.InvalidArgumentError( 931 None, None, "Expected sum of lengths to be equal to " 932 "values.shape[0], but sum of lengths is %d and " 933 "value's shape is: %s " % (sum_lengths.numpy(), 934 value.shape.as_list())) 935 elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array): 936 raise errors_impl.InvalidArgumentError( 937 None, None, "TensorArray's size is not equal to the size of " 938 "lengths (%d vs. %d), and the TensorArray is not marked as " 939 "dynamically resizeable" % (len(self._tensor_array), 940 lengths.shape[0])) 941 else: 942 self._tensor_array = array_ops.split(value, lengths, name=name) 943 return self.parent() 944 945 def size(self, name=None): 946 """See TensorArray.""" 947 del name # not meaningful when executing eagerly. 948 return constant_op.constant(len(self._tensor_array)) 949 950 def close(self, name=None): 951 del name # not meaningful when executing eagerly. 952 del self._tensor_array[:] 953 954 955# TensorArray is designed to hide an underlying implementation object 956# and as such accesses many of that object's hidden fields. 957# pylint: disable=protected-access 958@tf_export("TensorArray") 959class TensorArray(object): 960 """Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays. 961 962 This class is meant to be used with dynamic iteration primitives such as 963 `while_loop` and `map_fn`. It supports gradient back-propagation via special 964 "flow" control flow dependencies. 965 """ 966 967 def __init__(self, 968 dtype, 969 size=None, 970 dynamic_size=None, 971 clear_after_read=None, 972 tensor_array_name=None, 973 handle=None, 974 flow=None, 975 infer_shape=True, 976 element_shape=None, 977 colocate_with_first_write_call=True, 978 name=None): 979 """Construct a new TensorArray or wrap an existing TensorArray handle. 980 981 A note about the parameter `name`: 982 983 The name of the `TensorArray` (even if passed in) is uniquified: each time 984 a new `TensorArray` is created at runtime it is assigned its own name for 985 the duration of the run. This avoids name collisions if a `TensorArray` 986 is created within a `while_loop`. 987 988 Args: 989 dtype: (required) data type of the TensorArray. 990 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 991 Required if handle is not provided. 992 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 993 can grow the TensorArray past its initial size. Default: False. 994 clear_after_read: Boolean (optional, default: True). If True, clear 995 TensorArray values after reading them. This disables read-many 996 semantics, but allows early release of memory. 997 tensor_array_name: (optional) Python string: the name of the TensorArray. 998 This is used when creating the TensorArray handle. If this value is 999 set, handle should be None. 1000 handle: (optional) A `Tensor` handle to an existing TensorArray. If this 1001 is set, tensor_array_name should be None. Only supported in graph mode. 1002 flow: (optional) A float `Tensor` scalar coming from an existing 1003 `TensorArray.flow`. Only supported in graph mode. 1004 infer_shape: (optional, default: True) If True, shape inference 1005 is enabled. In this case, all elements must have the same shape. 1006 element_shape: (optional, default: None) A `TensorShape` object specifying 1007 the shape constraints of each of the elements of the TensorArray. 1008 Need not be fully defined. 1009 colocate_with_first_write_call: If `True`, the TensorArray will be 1010 colocated on the same device as the Tensor used on its first write 1011 (write operations include `write`, `unstack`, and `split`). If `False`, 1012 the TensorArray will be placed on the device determined by the 1013 device context available during its initialization. 1014 name: A name for the operation (optional). 1015 1016 Raises: 1017 ValueError: if both handle and tensor_array_name are provided. 1018 TypeError: if handle is provided but is not a Tensor. 1019 """ 1020 if context.executing_eagerly(): 1021 implementation = _EagerTensorArray 1022 else: 1023 if control_flow_util.EnableControlFlowV2(ops.get_default_graph()): 1024 implementation = _GraphTensorArrayV2 1025 else: 1026 implementation = _GraphTensorArray 1027 self._implementation = implementation( 1028 dtype, 1029 size=size, 1030 dynamic_size=dynamic_size, 1031 clear_after_read=clear_after_read, 1032 tensor_array_name=tensor_array_name, 1033 handle=handle, 1034 flow=flow, 1035 infer_shape=infer_shape, 1036 element_shape=element_shape, 1037 colocate_with_first_write_call=colocate_with_first_write_call, 1038 name=name) 1039 1040 self._implementation.parent = weakref.ref(self) 1041 1042 @property 1043 def flow(self): 1044 """The flow `Tensor` forcing ops leading to this TensorArray state.""" 1045 return self._implementation._flow 1046 1047 @property 1048 def dtype(self): 1049 """The data type of this TensorArray.""" 1050 return self._implementation._dtype 1051 1052 @property 1053 def handle(self): 1054 """The reference to the TensorArray.""" 1055 return self._implementation.handle 1056 1057 @property 1058 def _dynamic_size(self): 1059 return self._implementation._dynamic_size 1060 1061 @property 1062 def _infer_shape(self): 1063 return self._implementation._infer_shape 1064 1065 @_infer_shape.setter 1066 def _infer_shape(self, infer_shape): 1067 self._implementation._infer_shape = infer_shape 1068 1069 @property 1070 def _element_shape(self): 1071 return self._implementation._element_shape 1072 1073 @_element_shape.setter 1074 def _element_shape(self, element_shape): 1075 self._implementation._element_shape = element_shape 1076 1077 @property 1078 def _colocate_with_first_write_call(self): 1079 return self._implementation._colocate_with_first_write_call 1080 1081 @property 1082 def _colocate_with(self): 1083 return self._implementation._colocate_with 1084 1085 @_colocate_with.setter 1086 def _colocate_with(self, colocate_with): 1087 self._implementation._colocate_with = colocate_with 1088 1089 def identity(self): 1090 """Returns a TensorArray with the same content and properties. 1091 1092 Returns: 1093 A new TensorArray object with flow that ensures the control dependencies 1094 from the contexts will become control dependencies for writes, reads, etc. 1095 Use this object all for subsequent operations. 1096 """ 1097 return self._implementation.identity() 1098 1099 def grad(self, source, flow=None, name=None): 1100 return self._implementation.grad(source, flow=flow, name=name) 1101 1102 def read(self, index, name=None): 1103 """Read the value at location `index` in the TensorArray. 1104 1105 Args: 1106 index: 0-D. int32 tensor with the index to read from. 1107 name: A name for the operation (optional). 1108 1109 Returns: 1110 The tensor at index `index`. 1111 """ 1112 return self._implementation.read(index, name=name) 1113 1114 @tf_should_use.should_use_result 1115 def write(self, index, value, name=None): 1116 """Write `value` into index `index` of the TensorArray. 1117 1118 Args: 1119 index: 0-D. int32 scalar with the index to write to. 1120 value: N-D. Tensor of type `dtype`. The Tensor to write to this index. 1121 name: A name for the operation (optional). 1122 1123 Returns: 1124 A new TensorArray object with flow that ensures the write occurs. 1125 Use this object all for subsequent operations. 1126 1127 Raises: 1128 ValueError: if there are more writers than specified. 1129 """ 1130 return self._implementation.write(index, value, name=name) 1131 1132 def stack(self, name=None): 1133 """Return the values in the TensorArray as a stacked `Tensor`. 1134 1135 All of the values must have been written and their shapes must all match. 1136 If input shapes have rank-`R`, then output shape will have rank-`(R+1)`. 1137 1138 Args: 1139 name: A name for the operation (optional). 1140 1141 Returns: 1142 All the tensors in the TensorArray stacked into one tensor. 1143 """ 1144 return self._implementation.stack(name=name) 1145 1146 def gather(self, indices, name=None): 1147 """Return selected values in the TensorArray as a packed `Tensor`. 1148 1149 All of selected values must have been written and their shapes 1150 must all match. 1151 1152 Args: 1153 indices: A `1-D` `Tensor` taking values in `[0, max_value)`. If 1154 the `TensorArray` is not dynamic, `max_value=size()`. 1155 name: A name for the operation (optional). 1156 1157 Returns: 1158 The tensors in the `TensorArray` selected by `indices`, packed into one 1159 tensor. 1160 """ 1161 return self._implementation.gather(indices, name=name) 1162 1163 def concat(self, name=None): 1164 """Return the values in the TensorArray as a concatenated `Tensor`. 1165 1166 All of the values must have been written, their ranks must match, and 1167 and their shapes must all match for all dimensions except the first. 1168 1169 Args: 1170 name: A name for the operation (optional). 1171 1172 Returns: 1173 All the tensors in the TensorArray concatenated into one tensor. 1174 """ 1175 return self._implementation.concat(name=name) 1176 1177 @tf_should_use.should_use_result 1178 def unstack(self, value, name=None): 1179 """Unstack the values of a `Tensor` in the TensorArray. 1180 1181 If input value shapes have rank-`R`, then the output TensorArray will 1182 contain elements whose shapes are rank-`(R-1)`. 1183 1184 Args: 1185 value: (N+1)-D. Tensor of type `dtype`. The Tensor to unstack. 1186 name: A name for the operation (optional). 1187 1188 Returns: 1189 A new TensorArray object with flow that ensures the unstack occurs. 1190 Use this object all for subsequent operations. 1191 1192 Raises: 1193 ValueError: if the shape inference fails. 1194 """ 1195 return self._implementation.unstack(value, name=name) 1196 1197 @tf_should_use.should_use_result 1198 def scatter(self, indices, value, name=None): 1199 """Scatter the values of a `Tensor` in specific indices of a `TensorArray`. 1200 1201 Args: 1202 indices: A `1-D` `Tensor` taking values in `[0, max_value)`. If 1203 the `TensorArray` is not dynamic, `max_value=size()`. 1204 value: (N+1)-D. Tensor of type `dtype`. The Tensor to unpack. 1205 name: A name for the operation (optional). 1206 1207 Returns: 1208 A new TensorArray object with flow that ensures the scatter occurs. 1209 Use this object all for subsequent operations. 1210 1211 Raises: 1212 ValueError: if the shape inference fails. 1213 """ 1214 return self._implementation.scatter(indices, value, name=name) 1215 1216 @tf_should_use.should_use_result 1217 def split(self, value, lengths, name=None): 1218 """Split the values of a `Tensor` into the TensorArray. 1219 1220 Args: 1221 value: (N+1)-D. Tensor of type `dtype`. The Tensor to split. 1222 lengths: 1-D. int32 vector with the lengths to use when splitting 1223 `value` along its first dimension. 1224 name: A name for the operation (optional). 1225 1226 Returns: 1227 A new TensorArray object with flow that ensures the split occurs. 1228 Use this object all for subsequent operations. 1229 1230 Raises: 1231 ValueError: if the shape inference fails. 1232 """ 1233 return self._implementation.split(value, lengths, name=name) 1234 1235 def size(self, name=None): 1236 """Return the size of the TensorArray.""" 1237 return self._implementation.size(name=name) 1238 1239 @tf_should_use.should_use_result 1240 def close(self, name=None): 1241 """Close the current TensorArray.""" 1242 return self._implementation.close(name=name) 1243 1244 1245def build_ta_with_new_flow(old_ta, flow): 1246 """Builds a TensorArray with a new `flow` tensor.""" 1247 ta = TensorArray( 1248 dtype=old_ta.dtype, 1249 dynamic_size=old_ta._dynamic_size, 1250 handle=old_ta.handle, 1251 flow=flow, 1252 infer_shape=old_ta._infer_shape, 1253 colocate_with_first_write_call=old_ta._colocate_with_first_write_call) 1254 ta._colocate_with = old_ta._colocate_with 1255 ta._element_shape = old_ta._element_shape 1256 return ta 1257 1258# pylint: enable=protected-access 1259