# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorArray: a dynamically sized array of Tensors."""
# Mixture of pep8 and non-pep8 names, so disable pylint bad-name
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

import traceback
import weakref

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export


# _GraphTensorArray accesses many of the hidden generated ops, but is in
# fact built to wrap these methods.
# pylint: disable=protected-access
class _GraphTensorArray(object):
  """Graph-mode implementation of TensorArray."""

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the device
        context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if handle is not None and tensor_array_name:
      raise ValueError(
          "Cannot provide both `handle` and `tensor_array_name` arguments at "
          "the same time.")
    if handle is not None and not isinstance(handle, ops.Tensor):
      raise TypeError(
          f"Expected `handle` to be a Tensor, but got `{handle}` of type "
          f"`{type(handle)}` instead.")
    if handle is None and size is None:
      raise ValueError(
          "Argument `size` must be provided if handle is not provided.")
    if handle is not None and size is not None:
      raise ValueError("Cannot provide both `handle` and `size` arguments "
                       "at the same time.")
    if handle is not None and element_shape is not None:
      raise ValueError(
          "Cannot provide both `handle` and `element_shape` arguments "
          "at the same time.")
    if handle is not None and dynamic_size is not None:
      raise ValueError(
          "Cannot provide both `handle` and `dynamic_size` arguments "
          "at the same time.")
    if handle is not None and clear_after_read is not None:
      raise ValueError(
          "Cannot provide both `handle` and `clear_after_read` arguments "
          "at the same time.")

    if clear_after_read is None:
      clear_after_read = True
    self._dynamic_size = dynamic_size or False
    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Used to keep track of what tensors the TensorArray should be
    # colocated with.  We choose to colocate the TensorArray with the
    # first tensor written to it.
    self._colocate_with_first_write_call = colocate_with_first_write_call
    if colocate_with_first_write_call:
      self._colocate_with = []
    else:
      self._colocate_with = None
    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or by the shape of the first
    # tensor written. If `infer_shape` is true, every write is checked for
    # shape equality.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    self._size = size
    with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
      if handle is not None:
        self._handle = handle
        if flow is None:
          raise ValueError("flow must not be None if handle is not None.")
        self._flow = flow
      else:
        # Construct the TensorArray with an empty device.  The first
        # write into the TensorArray from a Tensor with a set device
        # will retroactively set the device value of this op.
        def create():
          """Create the TensorArray op."""
          return gen_data_flow_ops.tensor_array_v3(
              dtype=dtype,
              size=size,
              element_shape=element_shape,
              identical_element_shapes=infer_shape,
              dynamic_size=self._dynamic_size,
              clear_after_read=clear_after_read,
              tensor_array_name=tensor_array_name,
              name=scope)

        if colocate_with_first_write_call:
          with ops.device(None), ops.colocate_with(None, ignore_existing=True):
            self._handle, self._flow = create()
        else:
          self._handle, self._flow = create()

  @property
  def flow(self):
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    return self._handle

  @property
  def element_shape(self):
    return self._element_shape[0]

  def _check_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
          element shape of the `TensorArray`.
    """
    if not shape.is_compatible_with(self.element_shape):
      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
                       (shape, self.element_shape))
    if self._infer_shape:
      self._element_shape[0] = self.element_shape.merge_with(shape)

  @contextlib.contextmanager
  def _maybe_colocate_with(self, value):
    """Colocate operations with an internal colocation group or `value`.

    Args:
      value: `Tensor`, the tensor to try to colocate with.

    Yields:
      Does not yield anything, but the new context is a colocation context.

    If no internal colocation group is set, colocate with `value` and set
    the internal colocation group to be value.
    """
    if not self._colocate_with_first_write_call:
      yield
    else:
      if not self._colocate_with:
        self._colocate_with.append(value)
      with ops.colocate_with(self._colocate_with[0]):
        yield

  def identity(self):
    """See TensorArray."""
    flow = array_ops.identity(self._flow)
    return build_ta_with_new_flow(self, flow)

  def grad(self, source, flow=None, name=None):
    """See TensorArray."""
    # tensor_array_grad requires a flow input when forward
    # TensorArrays are dynamically sized.  This forces the creation
    # of the grad TensorArray only once the final forward array's size
    # is fixed.
    if flow is None:
      flow = self.flow
    with ops.name_scope(name, "TensorArrayGrad", [self._handle]):
      with ops.colocate_with(self._handle):
        g_handle, unused_flow = gen_data_flow_ops.tensor_array_grad_v3(
            handle=self._handle, source=source, flow_in=flow, name=name)
        with ops.control_dependencies([g_handle]):
          flow = array_ops.identity(flow, name="gradient_flow")
        g = TensorArray(
            dtype=self._dtype,
            handle=g_handle,
            flow=flow,
            infer_shape=self._infer_shape,
            colocate_with_first_write_call=False)
        # pylint: disable=protected-access
        g._implementation._element_shape = self._element_shape
        # pylint: enable=protected-access
        return g

  def read(self, index, name=None):
    """See TensorArray."""
    value = gen_data_flow_ops.tensor_array_read_v3(
        handle=self._handle,
        index=index,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name)
    if self._element_shape:
      value.set_shape(self._element_shape[0].dims)
    return value

  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape)
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops.tensor_array_write_v3(
            handle=self._handle,
            index=index,
            value=value,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  def stack(self, name=None):
    """See TensorArray."""
    with ops.colocate_with(self._handle):
      with ops.name_scope(name, "TensorArrayStack", [self._handle]):
        value = self.gather(math_ops.range(0, self.size()), name=name)
        if (self.element_shape and not self._dynamic_size and
            self._size is not None):
          value.set_shape([tensor_util.constant_value(self._size)] +
                          self.element_shape.dims)
        return value

  def gather(self, indices, name=None):
    """See TensorArray."""
    if self._element_shape:
      element_shape = self._element_shape[0]
    else:
      element_shape = tensor_shape.unknown_shape(None)
    value = gen_data_flow_ops.tensor_array_gather_v3(
        handle=self._handle,
        indices=indices,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape=element_shape)
    if self.element_shape:
      value.set_shape([None] + self.element_shape.dims)
    return value

  def concat(self, name=None):
    """See TensorArray."""
    value, _ = gen_data_flow_ops.tensor_array_concat_v3(
        handle=self._handle,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape_except0=self.element_shape[1:])
    if self.element_shape:
      value.set_shape([None] + self.element_shape.dims[1:])
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
      num_elements = array_ops.shape(value)[0]
      return self.scatter(
          indices=math_ops.range(0, num_elements), value=value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._handle, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      if not context.executing_eagerly():
        self._check_element_shape(value.shape[1:])
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
            handle=self._handle,
            indices=indices,
            value=value,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit",
                        [self._handle, value, lengths]):
      value = ops.convert_to_tensor(value, dtype=self._dtype, name="value")
      with self._maybe_colocate_with(value):
        lengths_64 = math_ops.cast(lengths, dtypes.int64)
        if not context.executing_eagerly():
          clengths = tensor_util.constant_value(lengths_64)
          if value.shape.dims is not None and clengths is not None:
            if clengths.shape and clengths.max() == clengths.min():
              self._check_element_shape(
                  tensor_shape.TensorShape([clengths[0]
                                           ]).concatenate(value.shape[1:]))
        flow_out = gen_data_flow_ops.tensor_array_split_v3(
            handle=self._handle,
            value=value,
            lengths=lengths_64,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  def size(self, name=None):
    """See TensorArray."""
    if not self._dynamic_size and self._size is not None:
      return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
    else:
      return gen_data_flow_ops.tensor_array_size_v3(
          handle=self._handle, flow_in=self.flow, name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """See TensorArray."""
    return gen_data_flow_ops.tensor_array_close_v3(
        handle=self._handle, name=name)


class _GraphTensorArrayV2(object):
  """Graph-mode implementation of TensorArray backed by TensorLists.

  The backing tensor of this TensorArray is a TensorList variant tensor which is
  stored in the `flow`. The `handle` is always None here. The reason we use the
  `flow` field and not the `handle` field is to ensure backwards compatibility
  with legacy control flow.
  """

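  # Note (a sketch, not part of the public API): inside a `tf.function` graph
  # with control flow v2 enabled (the TF2 default), a freshly constructed
  # `TensorArray` is backed by this class, and its `flow` is a variant tensor
  # holding the TensorList, e.g.
  #
  #   ta = TensorArray(dtypes.float32, size=2)
  #   assert ta.flow.dtype == dtypes.variant
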
  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if flow is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: (optional) unused. Not supported in TensorLists.
      tensor_array_name: (optional) unused.
      handle: (optional) Must always be None.
      flow: (optional) A variant `Tensor` scalar for a TensorList.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: (optional) unused.
      name: (optional) A name for the operation.

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    assert handle is None
    del handle
    del clear_after_read
    del tensor_array_name
    del colocate_with_first_write_call

    self._dynamic_size = dynamic_size
    self._size = size

    if (flow is not None and
        (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)):
      raise TypeError(
          f"Expected `flow` to be a variant tensor, but received `{flow.dtype}` "
          f"instead.")
    if flow is None and size is None:
      raise ValueError("Argument `size` must be provided if argument `flow` "
                       "is not provided.")
    if flow is not None and size is not None:
      raise ValueError("Cannot provide both `flow` and `size` arguments "
                       "at the same time.")
    if flow is not None and element_shape is not None:
      raise ValueError(
          "Cannot provide both `flow` and `element_shape` arguments "
          "at the same time.")

    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or by the shape of the first
    # tensor written. If `infer_shape` is true, every write is checked for
    # shape equality.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
      if flow is None:
        self._flow = list_ops.tensor_list_reserve(
            element_shape=element_shape,
            num_elements=size,
            element_dtype=dtype,
            name=scope)
      else:
        self._flow = flow

    # For backwards compatibility.
    self._colocate_with_first_write_call = None
    self._colocate_with = None

  @property
  def flow(self):
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def element_shape(self):
    return self._element_shape[0]

  @property
  def handle(self):
    # We intentionally do not raise an error so that legacy while_loop does not
    # complain.
    return None

  def _check_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
          element shape of the `TensorArray`.
    """
    if not shape.is_compatible_with(self.element_shape):
      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
                       (shape, self.element_shape))
    if self._infer_shape:
      self._element_shape[0] = self.element_shape.merge_with(shape)

  def identity(self):
    """See TensorArray."""
    flow = array_ops.identity(self._flow)
    return build_ta_with_new_flow(self, flow)

  def grad(self, source, flow=None, name=None):
    """Not supported."""
    raise NotImplementedError()

  def read(self, index, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
      value = list_ops.tensor_list_get_item(
          input_handle=self._flow,
          index=index,
          element_dtype=self._dtype,
          element_shape=self.element_shape,
          name=name)
      return value

  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape)
      flow_out = list_ops.tensor_list_set_item(
          input_handle=self._flow,
          index=index,
          item=value,
          resize_if_index_out_of_bounds=self._dynamic_size,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def stack(self, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
      # TODO(b/139941163): remove constant_value after changing num_elements
      # to a regular input.
      if not self._dynamic_size and self._size is not None:
        ta_size = tensor_util.constant_value(self._size)
      else:
        ta_size = -1
      value = list_ops.tensor_list_stack(
          input_handle=self._flow,
          element_dtype=self._dtype,
          num_elements=ta_size,
          element_shape=self.element_shape)
      return value

  def gather(self, indices, name=None):
    """See TensorArray."""
    value = list_ops.tensor_list_gather(
        input_handle=self._flow,
        indices=indices,
        element_dtype=self._dtype,
        element_shape=self.element_shape,
        name=name)
    return value

  def concat(self, name=None):
    """See TensorArray."""
    if self.element_shape:
      element_shape = [None] + self.element_shape.dims[1:]
    else:
      element_shape = None

    value = list_ops.tensor_list_concat(
        input_handle=self._flow,
        element_dtype=self._dtype,
        element_shape=element_shape,
        name=name)
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_from_tensor(
          tensor=value, element_shape=value.shape[1:])
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._flow, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_scatter(
          tensor=value,
          indices=indices,
          element_shape=self.element_shape,
          input_handle=self._flow)
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      lengths_64 = math_ops.cast(lengths, dtypes.int64)
      if not context.executing_eagerly():
        clengths = tensor_util.constant_value(lengths_64)
        if value.shape.dims is not None and clengths is not None:
          if clengths.shape and clengths.max() == clengths.min():
            self._check_element_shape(
                tensor_shape.TensorShape([clengths[0]
                                         ]).concatenate(value.shape[1:]))
      flow_out = list_ops.tensor_list_split(
          tensor=value,
          lengths=lengths_64,
          element_shape=self.element_shape,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def size(self, name=None):
    """See TensorArray."""
    if not self._dynamic_size and self._size is not None:
      return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
    else:
      return list_ops.tensor_list_length(input_handle=self._flow, name=name)

  def close(self, name=None):
    """See TensorArray."""
    return gen_control_flow_ops.no_op(name=name)


# pylint: enable=protected-access


class _EagerTensorArray(object):
  """Eager-compatible implementation of TensorArray."""

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a TensorArray compatible with eager execution.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: unused.
      handle: unsupported.
      flow: unsupported.
      infer_shape: used for error checking, same semantics as TensorArray.
      element_shape: used for error checking, same semantics as TensorArray.
      colocate_with_first_write_call: unsupported.
      name: unsupported.

    Raises:
      ValueError: if handle or flow are supplied, or if size is not supplied.
    """

    del (flow, tensor_array_name, name)  # Unused.

    if handle is not None:
      raise ValueError("TensorArray handles are not supported when eager "
                       "execution is enabled.")
    if size is None:
      raise ValueError("Size must be declared for TensorArrays when eager "
                       "execution is enabled.")

    # These attributes are not meaningful when eager is enabled, but some
    # library functions (e.g., those in control_flow_ops.py) access them to
    # create new tensor arrays; as such, we define them for the sake of
    # compatibility.
    self._handle = None
    # we assign a dummy value to _flow in case other code assumes it to be
    # a Tensor
    self._flow = constant_op.constant(0, dtype=dtypes.int32)
    self._infer_shape = infer_shape
    self._element_shape = tensor_shape.as_shape(element_shape)
    self._colocate_with_first_write_call = colocate_with_first_write_call

    self._dtype = dtypes.as_dtype(dtype).base_dtype
    self._dynamic_size = dynamic_size or False
    self._clear_after_read = (True
                              if clear_after_read is None else clear_after_read)
    self._previously_read_indices = []

    if isinstance(size, ops.EagerTensor):
      size = size.numpy()
    self._tensor_array = [None for _ in range(size)]

  @property
  def flow(self):
    """For compatibility; flows are not meaningful when eager is enabled."""
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    """For compatibility; handles are not meaningful when eager is enabled."""
    return self._handle

  @property
  def element_shape(self):
    return self._element_shape

  def identity(self):
    """See TensorArray."""
    return self.parent()

  def grad(self, source, flow=None, name=None):
    raise NotImplementedError(
        "TensorArray.grad is not supported when executing eagerly; eager's "
        "gradient implementation does not use/need this function to compute "
        "gradients of operations that use TensorArrays.")

  def read(self, index, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.

    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Reading from negative indices (index %d) is not allowed." % index)

    if index >= len(self._tensor_array):
      raise errors_impl.OutOfRangeError(
          None, None, "Tried to read from index %d but array size is: %d " %
          (index, len(self._tensor_array)))

    tensor = self._tensor_array[index]
    if tensor is None:
      if index in self._previously_read_indices:
        raise errors_impl.InvalidArgumentError(
            None, None,
            "Could not read index %d twice because it was cleared after "
            "a previous read (perhaps try setting clear_after_read = false?)" %
            index)
      else:
        tensor = self._maybe_zero(index)

    if self._clear_after_read:
      self._tensor_array[index] = None
      self._previously_read_indices.append(index)
    return tensor

  def _write(self, index, value):
    """Writes `value` into the array at location `index`.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The `Tensor` to write to `index`.

    Raises:
      errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
      errors_impl.OutOfRangeError: `index` is out of bounds.
      ValueError: shape of `value` is not consistent with inferred shape.
    """

    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Writing to negative indices (index %d) is not allowed." % index)

    size = len(self._tensor_array)
    if index >= size:
      if not self._dynamic_size:
        raise errors_impl.OutOfRangeError(
            None, None,
            "Tried to write to index %d but array is not resizeable and size "
            "is: %d " % (index, size))
      self._tensor_array.extend(None for _ in range(index - size + 1))

    if not isinstance(value, ops.EagerTensor):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")

    if self._dtype != value.dtype:
      raise errors_impl.InvalidArgumentError(
          None, None,
          "TensorArray dtype is %s but Op is trying to write dtype %s " %
          (self._dtype.name, value.dtype.name))

    if not self._element_shape.is_compatible_with(value.shape):
      raise ValueError("Incompatible shape for value (%s), expected (%s)" %
                       (value.shape, self._element_shape))

    if self._infer_shape:
      self._element_shape = self._element_shape.merge_with(value.shape)

    self._tensor_array[index] = value

  def write(self, index, value, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    self._write(index, value)
    return self.parent()

  def _maybe_zero(self, ix):
    val = self._tensor_array[ix]
    if val is None:
      val = self._tensor_array[ix] = array_ops.zeros(
          shape=self._element_shape, dtype=self._dtype)
    return val

  def stack(self, name=None):
    """See TensorArray."""
    if self._tensor_array:
      for ix in range(len(self._tensor_array)):
        self._maybe_zero(ix)
    if not self._tensor_array and self._element_shape.is_fully_defined():
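      # An empty array with a fully defined element shape stacks to an empty
      # tensor of shape [0] + element_shape; `[0] + self._element_shape` below
      # relies on TensorShape supporting `+` as concatenation.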
      return ops.convert_to_tensor(
          np.ndarray([0] + self._element_shape), name=name, dtype=self._dtype)
    else:
      return ops.convert_to_tensor(
          self._tensor_array, name=name, dtype=self._dtype)

  def gather(self, indices, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    return array_ops.stack([self._maybe_zero(i) for i in indices])

  def concat(self, name=None):
    """See TensorArray."""
    try:
      return array_ops.concat(
          [self._maybe_zero(ix) for ix in range(len(self._tensor_array))],
          0,
          name=name)
    except errors_impl.OpError:
      # Reproduce a subset of the error-handling for graph-mode TensorArrays.
      shapes = [t.shape for t in self._tensor_array]
      ndims = [s.ndims for s in shapes]
      if 0 in ndims:
        idx = ndims.index(0)
        raise errors_impl.InvalidArgumentError(
            None, None, "Concat saw a scalar shape at index %d but requires "
            "at least vectors." % idx)
      else:
        raise

  def unstack(self, value, name=None):
    """See TensorArray."""
    tensors = array_ops.unstack(value, name=name)
    if len(tensors) > len(self._tensor_array) and not self._dynamic_size:
      raise ValueError(
          "Cannot unstack %d tensors into a TensorArray of static size %d " %
          (len(tensors), len(self._tensor_array)))
    self._tensor_array = tensors
    return self.parent()

  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    for index, val in zip(indices, array_ops.unstack(value)):
      self._write(index, val)  # pylint: disable=protected-access
    return self.parent()

  def split(self, value, lengths, name=None):
    """See TensorArray."""
    # TODO(b/129870929): Fix after all callers provide proper init dtype.
    value = ops.convert_to_tensor(
        value, preferred_dtype=self._dtype, name="value")
    _check_dtypes(value, self._dtype)
    lengths = ops.convert_to_tensor(lengths)
    sum_lengths = math_ops.reduce_sum(lengths)
    if lengths.shape.ndims != 1:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected lengths to be a vector, received shape: %s " %
          lengths.shape.as_list())
    elif value.shape.ndims == 0:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected value to be at least a vector, "
          "but received shape: %s " % value.shape.as_list())
    elif sum_lengths.numpy() != value.shape.as_list()[0]:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected sum of lengths to be equal to "
          "values.shape[0], but sum of lengths is %d and "
          "value's shape is: %s " % (sum_lengths.numpy(),
                                     value.shape.as_list()))
    elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array):
      raise errors_impl.InvalidArgumentError(
          None, None, "TensorArray's size is not equal to the size of "
          "lengths (%d vs. %d), and the TensorArray is not marked as "
          "dynamically resizeable." %
          (len(self._tensor_array), lengths.shape[0]))
    else:
      self._tensor_array = array_ops.split(value, lengths, name=name)
      return self.parent()

  def size(self, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    return constant_op.constant(len(self._tensor_array))

  def close(self, name=None):
    del name  # not meaningful when executing eagerly.
    del self._tensor_array[:]


# TensorArray is designed to hide an underlying implementation object
# and as such accesses many of that object's hidden fields.
# pylint: disable=protected-access
# pylint:disable=line-too-long
@tf_export("TensorArray")
class TensorArray(object):
  """Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.

  This class is meant to be used with dynamic iteration primitives such as
  `while_loop` and `map_fn`.  It supports gradient back-propagation via special
  "flow" control flow dependencies.

  Example 1: Plain reading and writing.

  >>> ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False)
  >>> ta = ta.write(0, 10)
  >>> ta = ta.write(1, 20)
  >>> ta = ta.write(2, 30)
  >>>
  >>> ta.read(0)
  <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
  >>> ta.read(1)
  <tf.Tensor: shape=(), dtype=float32, numpy=20.0>
  >>> ta.read(2)
  <tf.Tensor: shape=(), dtype=float32, numpy=30.0>
  >>> ta.stack()
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 20., 30.],
  dtype=float32)>

  Example 2: Fibonacci sequence algorithm that writes in a loop then returns.

  >>> @tf.function
  ... def fibonacci(n):
  ...   ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  ...   ta = ta.unstack([0., 1.])
  ...
  ...   for i in range(2, n):
  ...     ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2))
  ...
  ...   return ta.stack()
  >>>
  >>> fibonacci(7)
  <tf.Tensor: shape=(7,), dtype=float32,
  numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)>

  Example 3: A simple loop interacting with a `tf.Variable`.

  >>> v = tf.Variable(1)
  >>> @tf.function
  ... def f(x):
  ...   ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  ...   for i in tf.range(x):
  ...     v.assign_add(i)
  ...     ta = ta.write(i, v)
  ...   return ta.stack()
  >>> f(5)
  <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1,  2,  4,  7, 11],
  dtype=int32)>
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Construct a new TensorArray or wrap an existing TensorArray handle.

    A note about the parameter `name`:

    The name of the `TensorArray` (even if passed in) is uniquified: each time
    a new `TensorArray` is created at runtime it is assigned its own name for
    the duration of the run.  This avoids name collisions if a `TensorArray`
    is created within a `while_loop`.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the device
        context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if (context.executing_eagerly() and
        (flow is None or flow.dtype != dtypes.variant)):
      # It is possible to create a Variant-style TensorArray even in eager mode,
      # and this is fine but can have performance implications in eager.
      # An example of when this happens is if a tf.function returns a
      # TensorArray in its output; its flow variant object is returned to Eager.
      # This can be wrapped back up in a Variant-style TensorArray.
      implementation = _EagerTensorArray
    elif (flow is not None and flow.dtype == dtypes.variant or
          control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
      implementation = _GraphTensorArrayV2
    else:
      implementation = _GraphTensorArray
    self._implementation = implementation(
        dtype,
        size=size,
        dynamic_size=dynamic_size,
        clear_after_read=clear_after_read,
        tensor_array_name=tensor_array_name,
        handle=handle,
        flow=flow,
        infer_shape=infer_shape,
        element_shape=element_shape,
        colocate_with_first_write_call=colocate_with_first_write_call,
        name=name)

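    # Use a weak reference so that the wrapper and its implementation do not
    # keep each other alive through a reference cycle.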
    self._implementation.parent = weakref.ref(self)

  @property
  def flow(self):
    """The flow `Tensor` forcing ops leading to this TensorArray state."""
    return self._implementation._flow

  @property
  def dtype(self):
    """The data type of this TensorArray."""
    return self._implementation._dtype

  @property
  def handle(self):
    """The reference to the TensorArray."""
    return self._implementation.handle

  @property
  def element_shape(self):
    """The `tf.TensorShape` of elements in this TensorArray."""
    return self._implementation.element_shape

  @property
  def dynamic_size(self):
    """Python bool; if `True` the TensorArray can grow dynamically."""
    return self._implementation._dynamic_size

  @property
  def _infer_shape(self):
    # TODO(slebedev): consider making public or changing TensorArrayStructure
    # to access _implementation directly. Note that dynamic_size is also
    # only used by TensorArrayStructure.
    return self._implementation._infer_shape

  def identity(self):
    """Returns a TensorArray with the same content and properties.

    Returns:
      A new TensorArray object with flow that ensures the control dependencies
      from the contexts will become control dependencies for writes, reads, etc.
      Use this object for all subsequent operations.
    """
    return self._implementation.identity()

  def grad(self, source, flow=None, name=None):
    return self._implementation.grad(source, flow=flow, name=name)

  def read(self, index, name=None):
    """Read the value at location `index` in the TensorArray.

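    For example, a minimal sketch assuming eager execution:

    >>> ta = tf.TensorArray(tf.float32, size=2)
    >>> ta = ta.write(0, 10.0)
    >>> ta.read(0)
    <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
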
    Args:
      index: 0-D.  int32 tensor with the index to read from.
      name: A name for the operation (optional).

    Returns:
      The tensor at index `index`.
    """
    return self._implementation.read(index, name=name)

  @tf_should_use.should_use_result(warn_in_eager=True)
  def write(self, index, value, name=None):
    """Write `value` into index `index` of the TensorArray.

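    For example, a minimal sketch assuming eager execution:

    >>> ta = tf.TensorArray(tf.float32, size=1)
    >>> ta = ta.write(0, tf.constant([1.0, 2.0]))
    >>> ta.read(0)
    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>
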
    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the write occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if there are more writers than specified.
    """
    return self._implementation.write(index, value, name=name)

  def stack(self, name=None):
    """Return the values in the TensorArray as a stacked `Tensor`.

    All of the values must have been written and their shapes must all match.
    If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.

    For example:

    >>> ta = tf.TensorArray(tf.int32, size=3)
    >>> ta = ta.write(0, tf.constant([1, 2]))
    >>> ta = ta.write(1, tf.constant([3, 4]))
    >>> ta = ta.write(2, tf.constant([5, 6]))
    >>> ta.stack()
    <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4],
           [5, 6]], dtype=int32)>

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray stacked into one tensor.
    """
    return self._implementation.stack(name=name)

  def gather(self, indices, name=None):
    """Return selected values in the TensorArray as a packed `Tensor`.

    All of the selected values must have been written and their shapes
    must all match.

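    For example, a minimal sketch assuming eager execution:

    >>> ta = tf.TensorArray(tf.int32, size=3)
    >>> ta = ta.write(0, 1)
    >>> ta = ta.write(1, 2)
    >>> ta = ta.write(2, 3)
    >>> ta.gather([2, 0])
    <tf.Tensor: shape=(2,), dtype=int32, numpy=array([3, 1], dtype=int32)>
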
    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If the
        `TensorArray` is not dynamic, `max_value=size()`.
      name: A name for the operation (optional).

    Returns:
      The tensors in the `TensorArray` selected by `indices`, packed into one
      tensor.
    """
    return self._implementation.gather(indices, name=name)

  def concat(self, name=None):
    """Return the values in the TensorArray as a concatenated `Tensor`.

    All of the values must have been written, their ranks must match, and
    their shapes must all match for all dimensions except the first.

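    For example, a minimal sketch assuming eager execution
    (`infer_shape=False` so rows of different lengths may be written):

    >>> ta = tf.TensorArray(tf.int32, size=2, infer_shape=False)
    >>> ta = ta.write(0, tf.constant([1, 2]))
    >>> ta = ta.write(1, tf.constant([3, 4, 5]))
    >>> ta.concat()
    <tf.Tensor: shape=(5,), dtype=int32, numpy=array([1, 2, 3, 4, 5], dtype=int32)>
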
    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray concatenated into one tensor.
    """
    return self._implementation.concat(name=name)

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """Unstack the values of a `Tensor` in the TensorArray.

    If input value shapes have rank-`R`, then the output TensorArray will
    contain elements whose shapes are rank-`(R-1)`.

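    For example, a minimal sketch assuming eager execution:

    >>> ta = tf.TensorArray(tf.int32, size=3)
    >>> ta = ta.unstack(tf.constant([[1, 2], [3, 4], [5, 6]]))
    >>> ta.read(1)
    <tf.Tensor: shape=(2,), dtype=int32, numpy=array([3, 4], dtype=int32)>
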
    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unstack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the unstack occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.unstack(value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """Scatter the values of a `Tensor` in specific indices of a `TensorArray`.

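    For example, a minimal sketch assuming eager execution:

    >>> ta = tf.TensorArray(tf.float32, size=4)
    >>> ta = ta.scatter([1, 3], tf.constant([[10.0], [20.0]]))
    >>> ta.read(3)
    <tf.Tensor: shape=(1,), dtype=float32, numpy=array([20.], dtype=float32)>
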
    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If the
        `TensorArray` is not dynamic, `max_value=size()`.
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unpack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the scatter occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.scatter(indices, value, name=name)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """Split the values of a `Tensor` into the TensorArray.

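    For example, a minimal sketch assuming eager execution
    (`infer_shape=False` so the splits may have different lengths):

    >>> ta = tf.TensorArray(tf.int32, size=2, infer_shape=False)
    >>> ta = ta.split(tf.constant([1, 2, 3]), lengths=[1, 2])
    >>> ta.read(1)
    <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 3], dtype=int32)>
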
    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to split.
      lengths: 1-D.  int32 vector with the lengths to use when splitting `value`
        along its first dimension.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the split occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
    """
    return self._implementation.split(value, lengths, name=name)

  def size(self, name=None):
    """Return the size of the TensorArray."""
    return self._implementation.size(name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """Close the current TensorArray."""
    return self._implementation.close(name=name)


def build_ta_with_new_flow(old_ta, flow):
  """Builds a TensorArray with a new `flow` tensor."""
  # Sometimes we get old_ta as the implementation, sometimes it's the
  # TensorArray wrapper object.
  impl = (old_ta._implementation if isinstance(old_ta, TensorArray) else old_ta)

  if not context.executing_eagerly():
    if (not isinstance(impl, _GraphTensorArrayV2) and
        control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
      raise NotImplementedError("Attempting to build a graph-mode TF2-style "
                                "TensorArray from either an eager-mode "
                                "TensorArray or a TF1-style TensorArray.  "
                                "This is not currently supported.  You may be "
                                "attempting to capture a TensorArray "
                                "inside a tf.function or tf.data map function. "
                                "Instead, construct a new TensorArray inside "
                                "the function.")
  new_ta = TensorArray(
      dtype=impl.dtype,
      handle=impl.handle,
      flow=flow,
      infer_shape=impl._infer_shape,
      colocate_with_first_write_call=impl._colocate_with_first_write_call)
  new_impl = new_ta._implementation
  new_impl._dynamic_size = impl._dynamic_size
  new_impl._size = impl._size
  new_impl._colocate_with = impl._colocate_with
  new_impl._element_shape = impl._element_shape  # Share _element_shape.
  return new_ta


# pylint: enable=protected-access


def _check_dtypes(value, dtype):
  if value.dtype != dtype:
    logging.error("Error: Input value {} has dtype {}, but expected dtype {}.  "
                  "This leads to undefined behavior and will be an error "
                  "in future versions of TensorFlow.  Traceback:\n{}".format(
                      value, str(value.dtype), str(dtype),
                      "".join(traceback.format_stack())))


@tf_export("TensorArraySpec")
@type_spec.register("tf.TensorArraySpec")
class TensorArraySpec(type_spec.TypeSpec):
  """Type specification for a `tf.TensorArray`."""

  __slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"]

  value_type = property(lambda self: TensorArray)

  def __init__(self,
               element_shape=None,
               dtype=dtypes.float32,
               dynamic_size=False,
               infer_shape=True):
    """Constructs a type specification for a `tf.TensorArray`.

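    For example, a minimal sketch assuming eager execution:

    >>> spec = tf.TensorArraySpec(element_shape=[None], dtype=tf.float32)
    >>> spec.is_compatible_with(tf.TensorArray(tf.float32, size=3))
    True
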
    Args:
      element_shape: The shape of each element in the `TensorArray`.
      dtype: Data type of the `TensorArray`.
      dynamic_size: Whether the `TensorArray` can grow past its initial size.
      infer_shape: Whether shape inference is enabled.
    """
    self._element_shape = tensor_shape.as_shape(element_shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._dynamic_size = dynamic_size
    self._infer_shape = infer_shape

  def is_compatible_with(self, other):
    # pylint: disable=protected-access
    if not isinstance(other, type_spec.TypeSpec):
      other = type_spec.type_spec_from_value(other)

    # Note: we intentionally exclude infer_shape in this check.
    return (isinstance(other, TensorArraySpec) and
            self._dtype.is_compatible_with(other._dtype) and
            self._element_shape.is_compatible_with(other._element_shape) and
            self._dynamic_size == other._dynamic_size)

  def most_specific_compatible_type(self, other):
    # pylint: disable=protected-access
    if not self.is_compatible_with(other):
      raise ValueError(f"Type `{self}` is not compatible with `{other}`.")
    infer_shape = self._infer_shape and other._infer_shape
    return TensorArraySpec(
        self._element_shape.most_specific_compatible_shape(
            other._element_shape), self._dtype, self._dynamic_size, infer_shape)

  def _serialize(self):
    return (self._element_shape, self._dtype, self._dynamic_size,
            self._infer_shape)

  @property
  def _component_specs(self):
    return [tensor_spec.TensorSpec([], dtypes.variant)]

  def _to_components(self, value):
    if not isinstance(value, TensorArray):
      raise TypeError("Expected value to be a TensorArray, but got: `{}`".format(
          type(value)))
    if value.flow is not None and value.flow.dtype == dtypes.variant:
      return [value.flow]
    else:
      # Convert to a TF2-style TensorArray.
      # TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
      # "implementation / as_variant" arg to TensorArray constructor.
      with ops.name_scope("convert_tensor_array"):
        flow = list_ops.tensor_list_from_tensor(
            tensor=value.stack(), element_shape=value.element_shape)
      return [flow]

  def _from_components(self, tensor_list):
    # This will return a TF2 Graph-style TensorArray because tensor_list[0] is
    # a variant object.  size == -1 implies unknown size.
    ret = TensorArray(
        dtype=self._dtype,
        flow=tensor_list[0],
        dynamic_size=self._dynamic_size,
        infer_shape=self._infer_shape)
    ret._implementation._element_shape = [self._element_shape]  # pylint: disable=protected-access
    return ret

  @staticmethod
  def from_value(value):
    if not isinstance(value, TensorArray):
      raise TypeError("Expected value to be a TensorArray, but got: `{}`".format(
          type(value)))

    return TensorArraySpec(
        dtype=value.dtype,
        element_shape=value.element_shape,
        dynamic_size=value.dynamic_size,
        infer_shape=value._infer_shape)  # pylint: disable=protected-access

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    # Sneak the dynamic_size and infer_shape values into the legacy shape.
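    # For example, an element shape of [2, 3] comes back as the trailing
    # dimensions of the returned shape, after two leading entries encoding
    # dynamic_size and infer_shape.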
    return (tensor_shape.TensorShape([self._dynamic_size, self._infer_shape
                                     ]).concatenate(self._element_shape))

  def _to_legacy_output_classes(self):
    return TensorArray


# Register the TypeSpec for TensorArray.  If TensorArray is updated to be a
# CompositeTensor, then this registration can be deleted.
type_spec.register_type_spec_from_value_converter(
    TensorArray, TensorArraySpec.from_value, allow_subclass=True)