# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorArray: a dynamically sized array of Tensors."""
# Mixture of pep8 and non-pep8 names, so disable pylint bad-name
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

import traceback
import weakref

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export


# _GraphTensorArray accesses many of the hidden generated ops, but is in
# fact built to wrap these methods.
# pylint: disable=protected-access
class _GraphTensorArray(object):
  """Graph-mode implementation of TensorArray.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference
        is enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray.
        Need not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the
        device context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if handle is not None and tensor_array_name:
      raise ValueError(
          "Cannot construct with both handle and tensor_array_name")
    if handle is not None and not isinstance(handle, ops.Tensor):
      raise TypeError("Handle must be a Tensor")
    if handle is None and size is None:
      raise ValueError("Size must be provided if handle is not provided")
    if handle is not None and size is not None:
      raise ValueError("Cannot provide both a handle and size "
                       "at the same time")
    if handle is not None and element_shape is not None:
      raise ValueError("Cannot provide both a handle and element_shape "
                       "at the same time")
    if handle is not None and dynamic_size is not None:
      raise ValueError("Cannot provide both a handle and dynamic_size "
                       "at the same time")
    if handle is not None and clear_after_read is not None:
      raise ValueError("Cannot provide both a handle and clear_after_read "
                       "at the same time")

    if clear_after_read is None:
      clear_after_read = True
    self._dynamic_size = dynamic_size or False
    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Used to keep track of what tensors the TensorArray should be
    # colocated with.  We choose to colocate the TensorArray with the
    # first tensor written to it.
    self._colocate_with_first_write_call = colocate_with_first_write_call
    if colocate_with_first_write_call:
      self._colocate_with = []
    else:
      self._colocate_with = None

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, all writes are checked
    # for shape equality.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    self._size = size
    with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
      if handle is not None:
        self._handle = handle
        if flow is None:
          raise ValueError("flow must not be None if handle is not None.")
        self._flow = flow
      else:
        # Construct the TensorArray with an empty device.  The first
        # write into the TensorArray from a Tensor with a set device
        # will retroactively set the device value of this op.
        def create():
          """Create the TensorArray op."""
          return gen_data_flow_ops.tensor_array_v3(
              dtype=dtype,
              size=size,
              element_shape=element_shape,
              identical_element_shapes=infer_shape,
              dynamic_size=self._dynamic_size,
              clear_after_read=clear_after_read,
              tensor_array_name=tensor_array_name,
              name=scope)
        if colocate_with_first_write_call:
          with ops.device(None), ops.colocate_with(None, ignore_existing=True):
            self._handle, self._flow = create()
        else:
          self._handle, self._flow = create()

  @property
  def flow(self):
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    return self._handle

  @property
  def element_shape(self):
    return self._element_shape[0]

  def _check_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
          element shape of the `TensorArray`.
    """
    if not shape.is_compatible_with(self.element_shape):
      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
                       (shape, self.element_shape))
    if self._infer_shape:
      self._element_shape[0] = self.element_shape.merge_with(shape)

  @contextlib.contextmanager
  def _maybe_colocate_with(self, value):
    """Colocate operations with an internal colocation group or `value`.

    Args:
      value: `Tensor`, the tensor to try to colocate with.

    Yields:
      Does not yield anything, but the new context is a colocation context.

    If no internal colocation group is set, colocate with `value` and set
    the internal colocation group to be value.
    """
    if not self._colocate_with_first_write_call:
      yield
    else:
      if not self._colocate_with:
        self._colocate_with.append(value)
      with ops.colocate_with(self._colocate_with[0]):
        yield

  def identity(self):
    """See TensorArray."""
    flow = array_ops.identity(self._flow)
    return build_ta_with_new_flow(self, flow)

  def grad(self, source, flow=None, name=None):
    """See TensorArray."""
    # tensor_array_grad requires a flow input when forward
    # TensorArrays are dynamically sized.  This forces the creation
    # of the grad TensorArray only once the final forward array's size
    # is fixed.
    if flow is None:
      flow = self.flow
    with ops.name_scope(name, "TensorArrayGrad", [self._handle]):
      with ops.colocate_with(self._handle):
        g_handle, unused_flow = gen_data_flow_ops.tensor_array_grad_v3(
            handle=self._handle, source=source, flow_in=flow, name=name)
        with ops.control_dependencies([g_handle]):
          flow = array_ops.identity(flow, name="gradient_flow")
        g = TensorArray(
            dtype=self._dtype,
            handle=g_handle,
            flow=flow,
            infer_shape=self._infer_shape,
            colocate_with_first_write_call=False)
        # pylint: disable=protected-access
        g._implementation._element_shape = self._element_shape
        # pylint: enable=protected-access
        return g

  def read(self, index, name=None):
    """See TensorArray."""
    value = gen_data_flow_ops.tensor_array_read_v3(
        handle=self._handle,
        index=index,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name)
    if self._element_shape:
      value.set_shape(self._element_shape[0].dims)
    return value

  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape)
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops.tensor_array_write_v3(
            handle=self._handle,
            index=index,
            value=value,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  def stack(self, name=None):
    """See TensorArray."""
    with ops.colocate_with(self._handle):
      with ops.name_scope(name, "TensorArrayStack", [self._handle]):
        value = self.gather(math_ops.range(0, self.size()), name=name)
        if (self.element_shape and not self._dynamic_size and
            self._size is not None):
          value.set_shape([tensor_util.constant_value(self._size)] +
                          self.element_shape.dims)
        return value

  def gather(self, indices, name=None):
    """See TensorArray."""
    if self._element_shape:
      element_shape = self._element_shape[0]
    else:
      element_shape = tensor_shape.unknown_shape(None)
    value = gen_data_flow_ops.tensor_array_gather_v3(
        handle=self._handle,
        indices=indices,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape=element_shape)
    if self.element_shape:
      value.set_shape([None] + self.element_shape.dims)
    return value

  def concat(self, name=None):
    """See TensorArray."""
    value, _ = gen_data_flow_ops.tensor_array_concat_v3(
        handle=self._handle,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape_except0=self.element_shape[1:])
    if self.element_shape:
      value.set_shape([None] + self.element_shape.dims[1:])
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
      num_elements = array_ops.shape(value)[0]
      return self.scatter(
          indices=math_ops.range(0, num_elements), value=value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._handle, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      if not context.executing_eagerly():
        self._check_element_shape(value.shape[1:])
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
            handle=self._handle,
            indices=indices,
            value=value,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit",
                        [self._handle, value, lengths]):
      value = ops.convert_to_tensor(value, dtype=self._dtype, name="value")
      with self._maybe_colocate_with(value):
        lengths_64 = math_ops.cast(lengths, dtypes.int64)
        if not context.executing_eagerly():
          clengths = tensor_util.constant_value(lengths_64)
          if value.shape.dims is not None and clengths is not None:
            if clengths.shape and clengths.max() == clengths.min():
              self._check_element_shape(
                  tensor_shape.TensorShape([clengths[0]]).concatenate(
                      value.shape[1:]))
        flow_out = gen_data_flow_ops.tensor_array_split_v3(
            handle=self._handle,
            value=value,
            lengths=lengths_64,
            flow_in=self._flow,
            name=name)
      return build_ta_with_new_flow(self, flow_out)

  def size(self, name=None):
    """See TensorArray."""
    if not self._dynamic_size and self._size is not None:
      return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
    else:
      return gen_data_flow_ops.tensor_array_size_v3(
          handle=self._handle, flow_in=self.flow, name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """See TensorArray."""
    return gen_data_flow_ops.tensor_array_close_v3(
        handle=self._handle, name=name)


class _GraphTensorArrayV2(object):
  """Graph-mode implementation of TensorArray backed by TensorLists.

  The backing tensor of this TensorArray is a TensorList variant tensor which
  is stored in the `flow`. The `handle` is always `None` here. The reason we
  use the `flow` field and not the `handle` field is to ensure backwards
  compatibility with legacy control flow.
  """
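
  # For orientation, a small illustrative sketch (assuming TF2 control flow is
  # enabled so that this class is selected): all of the array state lives in
  # the variant `flow` tensor, and every mutation produces a new flow:
  #
  #   ta = TensorArray(dtypes.float32, size=2)  # flow holds a TensorList
  #   ta = ta.write(0, [1.0])                   # write returns a new flow
  #   ta.flow.dtype == dtypes.variant           # True for this implementation
  #
  # The V1 `_GraphTensorArray` above instead keeps state behind a resource
  # `handle` and threads a float scalar `flow` only to order reads and writes.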

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if flow is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: (optional) unused. Not supported in TensorLists.
      tensor_array_name: (optional) unused.
      handle: (optional) Must always be None.
      flow: (optional) A variant `Tensor` scalar for a TensorList.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: (optional). unused.
      name: (optional) A name for the operation.

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    assert handle is None
    del handle
    del clear_after_read
    del tensor_array_name
    del colocate_with_first_write_call

    self._dynamic_size = dynamic_size
    self._size = size

    if (flow is not None and
        (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)):
      raise TypeError("flow must be a variant tensor")
    if flow is None and size is None:
      raise ValueError("Size must be provided if flow is not provided")
    if flow is not None and size is not None:
      raise ValueError("Cannot provide both a flow and size "
                       "at the same time")
    if flow is not None and element_shape is not None:
      raise ValueError("Cannot provide both a flow and element_shape "
                       "at the same time")

    self._dtype = dtypes.as_dtype(dtype).base_dtype

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, all writes are checked
    # for shape equality.
    self._element_shape = [tensor_shape.as_shape(element_shape)]
    self._infer_shape = infer_shape
    with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
      if flow is None:
        self._flow = list_ops.tensor_list_reserve(
            element_shape=element_shape,
            num_elements=size,
            element_dtype=dtype,
            name=scope)
      else:
        self._flow = flow

    # For backwards compatibility.
    self._colocate_with_first_write_call = None
    self._colocate_with = None

  @property
  def flow(self):
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def element_shape(self):
    return self._element_shape[0]

  @property
  def handle(self):
    # We intentionally do not raise an error so that legacy while_loop does not
    # complain.
    return None

  def _check_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
          element shape of the `TensorArray`.
    """
    if not shape.is_compatible_with(self.element_shape):
      raise ValueError("Inconsistent shapes: saw %s but expected %s " %
                       (shape, self.element_shape))
    if self._infer_shape:
      self._element_shape[0] = self.element_shape.merge_with(shape)

  def identity(self):
    """See TensorArray."""
    flow = array_ops.identity(self._flow)
    return build_ta_with_new_flow(self, flow)

  def grad(self, source, flow=None, name=None):
    """Not supported."""
    raise NotImplementedError()

  def read(self, index, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
      value = list_ops.tensor_list_get_item(
          input_handle=self._flow,
          index=index,
          element_dtype=self._dtype,
          element_shape=self.element_shape,
          name=name)
      return value

  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape)
      flow_out = list_ops.tensor_list_set_item(
          input_handle=self._flow,
          index=index,
          item=value,
          resize_if_index_out_of_bounds=self._dynamic_size,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def stack(self, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
      # TODO(b/139941163): remove constant_value after changing num_elements
      # to regular input.
      if not self._dynamic_size and self._size is not None:
        ta_size = tensor_util.constant_value(self._size)
      else:
        ta_size = -1
      value = list_ops.tensor_list_stack(
          input_handle=self._flow,
          element_dtype=self._dtype,
          num_elements=ta_size,
          element_shape=self.element_shape)
      return value

  def gather(self, indices, name=None):
    """See TensorArray."""
    value = list_ops.tensor_list_gather(
        input_handle=self._flow,
        indices=indices,
        element_dtype=self._dtype,
        element_shape=self.element_shape,
        name=name)
    return value

  def concat(self, name=None):
    """See TensorArray."""
    if self.element_shape:
      element_shape = [None] + self.element_shape.dims[1:]
    else:
      element_shape = None

    value = list_ops.tensor_list_concat(
        input_handle=self._flow,
        element_dtype=self._dtype,
        element_shape=element_shape,
        name=name)
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_from_tensor(
          tensor=value, element_shape=value.shape[1:])
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._flow, value, indices]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      self._check_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_scatter(
          tensor=value, indices=indices, element_shape=self.element_shape,
          input_handle=self._flow)
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")
      _check_dtypes(value, self._dtype)
      lengths_64 = math_ops.cast(lengths, dtypes.int64)
      if not context.executing_eagerly():
        clengths = tensor_util.constant_value(lengths_64)
        if value.shape.dims is not None and clengths is not None:
          if clengths.shape and clengths.max() == clengths.min():
            self._check_element_shape(
                tensor_shape.TensorShape([clengths[0]]).concatenate(
                    value.shape[1:]))
      flow_out = list_ops.tensor_list_split(
          tensor=value,
          lengths=lengths_64,
          element_shape=self.element_shape,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def size(self, name=None):
    """See TensorArray."""
    if not self._dynamic_size and self._size is not None:
      return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
    else:
      return list_ops.tensor_list_length(input_handle=self._flow, name=name)

  def close(self, name=None):
    """See TensorArray."""
    return gen_control_flow_ops.no_op(name=name)

# pylint: enable=protected-access


class _EagerTensorArray(object):
  """Eager-compatible implementation of TensorArray.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a TensorArray compatible with eager execution.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: unused.
      handle: unsupported.
      flow: unsupported.
      infer_shape: used for error checking, same semantics as TensorArray.
      element_shape: used for error checking, same semantics as TensorArray.
      colocate_with_first_write_call: unsupported.
      name: unsupported.

    Raises:
      ValueError: if handle or flow is supplied, or if size is not supplied.
    """

    del (flow, tensor_array_name, name)  # Unused.

    if handle is not None:
      raise ValueError("TensorArray handles are not supported when eager "
                       "execution is enabled.")
    if size is None:
      raise ValueError("Size must be declared for TensorArrays when eager "
                       "execution is enabled.")

    # These attributes are not meaningful when eager is enabled, but some
    # library functions (e.g., those in control_flow_ops.py) access them to
    # create new tensor arrays; as such, we define them for the sake of
    # compatibility.
    self._handle = None
    # we assign a dummy value to _flow in case other code assumes it to be
    # a Tensor
    self._flow = constant_op.constant(0, dtype=dtypes.int32)
    self._infer_shape = infer_shape
    self._element_shape = tensor_shape.as_shape(element_shape)
    self._colocate_with_first_write_call = colocate_with_first_write_call

    self._dtype = dtypes.as_dtype(dtype).base_dtype
    self._dynamic_size = dynamic_size or False
    self._clear_after_read = (
        True if clear_after_read is None else clear_after_read)
    self._previously_read_indices = []

    if isinstance(size, ops.EagerTensor):
      size = size.numpy()
    self._tensor_array = [None for _ in range(size)]

  @property
  def flow(self):
    """For compatibility; flows are not meaningful when eager is enabled."""
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    """For compatibility; handles are not meaningful when eager is enabled."""
    return self._handle

  @property
  def element_shape(self):
    return self._element_shape

  def identity(self):
    """See TensorArray."""
    return self.parent()

  def grad(self, source, flow=None, name=None):
    raise NotImplementedError(
        "TensorArray.grad is not supported when executing eagerly; eager's "
        "gradient implementation does not use/need this function to compute "
        "gradients of operations that use TensorArrays.")

  def read(self, index, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.

    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Reading from negative indices (index %d) is not allowed." % index)

    if index >= len(self._tensor_array):
      raise errors_impl.OutOfRangeError(
          None, None, "Tried to read from index %d but array size is: %d" %
          (index, len(self._tensor_array)))

    tensor = self._tensor_array[index]
    if tensor is None:
      if index in self._previously_read_indices:
        raise errors_impl.InvalidArgumentError(
            None, None,
            "Could not read index %d twice because it was cleared after "
            "a previous read (perhaps try setting clear_after_read = false?)" %
            index)
      else:
        tensor = self._maybe_zero(index)

    if self._clear_after_read:
      self._tensor_array[index] = None
      self._previously_read_indices.append(index)
    return tensor

  def _write(self, index, value):
    """Writes `value` into the array at location `index`.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The `Tensor` to write to `index`.

    Raises:
      errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
      errors_impl.OutOfRangeError: `index` is out of bounds.
      ValueError: shape of `value` is not consistent with inferred shape.
    """

    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Writing to negative indices (index %d) is not allowed." % index)

    size = len(self._tensor_array)
    if index >= size:
      if not self._dynamic_size:
        raise errors_impl.OutOfRangeError(
            None, None,
            "Tried to write to index %d but array is not resizeable and size "
            "is: %d" % (index, size))
      self._tensor_array.extend(None for _ in range(index - size + 1))

    if not isinstance(value, ops.EagerTensor):
      # TODO(b/129870929): Fix after all callers provide proper init dtype.
      value = ops.convert_to_tensor(
          value, preferred_dtype=self._dtype, name="value")

    if self._dtype != value.dtype:
      raise errors_impl.InvalidArgumentError(
          None, None,
          "TensorArray dtype is %s but Op is trying to write dtype %s" %
          (self._dtype.name, value.dtype.name))

    if not self._element_shape.is_compatible_with(value.shape):
      raise ValueError("Incompatible shape for value (%s), expected (%s)" %
                       (value.shape, self._element_shape))

    if self._infer_shape:
      self._element_shape = self._element_shape.merge_with(value.shape)

    self._tensor_array[index] = value

  def write(self, index, value, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    self._write(index, value)
    return self.parent()

  def _maybe_zero(self, ix):
    val = self._tensor_array[ix]
    if val is None:
      val = self._tensor_array[ix] = array_ops.zeros(
          shape=self._element_shape, dtype=self._dtype)
    return val

  def stack(self, name=None):
    """See TensorArray."""
    if self._tensor_array:
      for ix in range(len(self._tensor_array)):
        self._maybe_zero(ix)
    if not self._tensor_array and self._element_shape.is_fully_defined():
      return ops.convert_to_tensor(
          np.ndarray([0] + self._element_shape), name=name, dtype=self._dtype)
    else:
      return ops.convert_to_tensor(
          self._tensor_array, name=name, dtype=self._dtype)

  def gather(self, indices, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    return array_ops.stack([self._maybe_zero(i) for i in indices])

  def concat(self, name=None):
    """See TensorArray."""
    try:
      return array_ops.concat(
          [self._maybe_zero(ix) for ix in range(len(self._tensor_array))],
          0, name=name)
    except errors_impl.OpError:
      # Reproduce a subset of the error-handling for graph-mode TensorArrays.
      shapes = [t.shape for t in self._tensor_array]
      ndims = [s.ndims for s in shapes]
      if 0 in ndims:
        idx = ndims.index(0)
        raise errors_impl.InvalidArgumentError(
            None, None, "Concat saw a scalar shape at index %d but requires "
            "at least vectors." % idx)
      else:
        raise

  def unstack(self, value, name=None):
    """See TensorArray."""
    tensors = array_ops.unstack(value, name=name)
    if len(tensors) > len(self._tensor_array) and not self._dynamic_size:
      raise ValueError(
          "Cannot unstack %d tensors into a TensorArray of static size %d" %
          (len(tensors), len(self._tensor_array)))
    self._tensor_array = tensors
    return self.parent()

  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    for index, val in zip(indices, array_ops.unstack(value)):
      self._write(index, val)  # pylint: disable=protected-access
    return self.parent()

  def split(self, value, lengths, name=None):
    """See TensorArray."""
    # TODO(b/129870929): Fix after all callers provide proper init dtype.
    value = ops.convert_to_tensor(
        value, preferred_dtype=self._dtype, name="value")
    _check_dtypes(value, self._dtype)
    lengths = ops.convert_to_tensor(lengths)
    sum_lengths = math_ops.reduce_sum(lengths)
    if lengths.shape.ndims != 1:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected lengths to be a vector, received shape: %s" %
          lengths.shape.as_list())
    elif value.shape.ndims == 0:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected value to be at least a vector, "
          "but received shape: %s" % value.shape.as_list())
    elif sum_lengths.numpy() != value.shape.as_list()[0]:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected sum of lengths to be equal to "
          "values.shape[0], but sum of lengths is %d and "
          "value's shape is: %s " % (sum_lengths.numpy(),
                                     value.shape.as_list()))
    elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array):
      raise errors_impl.InvalidArgumentError(
          None, None, "TensorArray's size is not equal to the size of "
          "lengths (%d vs. %d), and the TensorArray is not marked as "
          "dynamically resizeable" % (len(self._tensor_array),
                                      lengths.shape[0]))
    else:
      self._tensor_array = array_ops.split(value, lengths, name=name)
      return self.parent()

  def size(self, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    return constant_op.constant(len(self._tensor_array))

  def close(self, name=None):
    del name  # not meaningful when executing eagerly.
    del self._tensor_array[:]


# TensorArray is designed to hide an underlying implementation object
# and as such accesses many of that object's hidden fields.
# pylint: disable=protected-access
# pylint: disable=line-too-long
@tf_export("TensorArray")
class TensorArray(object):
  """Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.

  This class is meant to be used with dynamic iteration primitives such as
  `while_loop` and `map_fn`.  It supports gradient back-propagation via special
  "flow" control flow dependencies.

  Example 1: Plain reading and writing.

  >>> ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False)
  >>> ta = ta.write(0, 10)
  >>> ta = ta.write(1, 20)
  >>> ta = ta.write(2, 30)
  >>>
  >>> ta.read(0)
  <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
  >>> ta.read(1)
  <tf.Tensor: shape=(), dtype=float32, numpy=20.0>
  >>> ta.read(2)
  <tf.Tensor: shape=(), dtype=float32, numpy=30.0>
  >>> ta.stack()
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 20., 30.],
  dtype=float32)>

  Example 2: Fibonacci sequence algorithm that writes in a loop then returns.

  >>> @tf.function
  ... def fibonacci(n):
  ...   ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  ...   ta = ta.unstack([0., 1.])
  ...
  ...   for i in range(2, n):
  ...     ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2))
  ...
  ...   return ta.stack()
  >>>
  >>> fibonacci(7)
  <tf.Tensor: shape=(7,), dtype=float32,
  numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)>

  Example 3: A simple loop interacting with a `tf.Variable`.

  >>> v = tf.Variable(1)
  >>> @tf.function
  ... def f(x):
  ...   ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  ...   for i in tf.range(x):
  ...     v.assign_add(i)
  ...     ta = ta.write(i, v)
  ...   return ta.stack()
  >>> f(5)
  <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1,  2,  4,  7, 11],
  dtype=int32)>
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Construct a new TensorArray or wrap an existing TensorArray handle.

    A note about the parameter `name`:

    The name of the `TensorArray` (even if passed in) is uniquified: each time
    a new `TensorArray` is created at runtime it is assigned its own name for
    the duration of the run.  This avoids name collisions if a `TensorArray`
    is created within a `while_loop`.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference
        is enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray.
        Need not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the
        device context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if (context.executing_eagerly() and
        (flow is None or flow.dtype != dtypes.variant)):
      # It is possible to create a Variant-style TensorArray even in eager mode,
      # and this is fine but can have performance implications in eager.
      # An example of when this happens is if a tf.function returns a
      # TensorArray in its output; its flow variant object is returned to Eager.
      # This can be wrapped back up in a Variant-style TensorArray.
      implementation = _EagerTensorArray
    elif (flow is not None and flow.dtype == dtypes.variant or
          control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
      implementation = _GraphTensorArrayV2
    else:
      implementation = _GraphTensorArray
    self._implementation = implementation(
        dtype,
        size=size,
        dynamic_size=dynamic_size,
        clear_after_read=clear_after_read,
        tensor_array_name=tensor_array_name,
        handle=handle,
        flow=flow,
        infer_shape=infer_shape,
        element_shape=element_shape,
        colocate_with_first_write_call=colocate_with_first_write_call,
        name=name)

    self._implementation.parent = weakref.ref(self)

  @property
  def flow(self):
    """The flow `Tensor` forcing ops leading to this TensorArray state."""
    return self._implementation._flow

  @property
  def dtype(self):
    """The data type of this TensorArray."""
    return self._implementation._dtype

  @property
  def handle(self):
    """The reference to the TensorArray."""
    return self._implementation.handle

  @property
  def element_shape(self):
    """The `tf.TensorShape` of elements in this TensorArray."""
    return self._implementation.element_shape

  @property
  def dynamic_size(self):
    """Python bool; if `True` the TensorArray can grow dynamically."""
    return self._implementation._dynamic_size

  @property
  def _infer_shape(self):
    # TODO(slebedev): consider making public or changing TensorArrayStructure
    # to access _implementation directly. Note that dynamic_size is also
    # only used by TensorArrayStructure.
    return self._implementation._infer_shape

  def identity(self):
    """Returns a TensorArray with the same content and properties.

    Returns:
      A new TensorArray object with flow that ensures the control dependencies
      from the contexts will become control dependencies for writes, reads, etc.
      Use this object for all subsequent operations.
    """
    return self._implementation.identity()

  def grad(self, source, flow=None, name=None):
    return self._implementation.grad(source, flow=flow, name=name)

  def read(self, index, name=None):
    """Read the value at location `index` in the TensorArray.

    Args:
      index: 0-D.  int32 tensor with the index to read from.
      name: A name for the operation (optional).

    Returns:
      The tensor at index `index`.
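
    For example (an illustrative eager-mode sketch):

    >>> ta = tf.TensorArray(tf.float32, size=2, clear_after_read=False)
    >>> ta = ta.write(0, 10.0)
    >>> ta.read(0)
    <tf.Tensor: shape=(), dtype=float32, numpy=10.0>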
    """
    return self._implementation.read(index, name=name)

  @tf_should_use.should_use_result(warn_in_eager=True)
  def write(self, index, value, name=None):
    """Write `value` into index `index` of the TensorArray.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the write occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if there are more writers than specified.
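
    For example (an illustrative eager-mode sketch; note that each `write`
    returns a new `TensorArray`, which must be used for subsequent calls):

    >>> ta = tf.TensorArray(tf.int32, size=2)
    >>> ta = ta.write(0, 1)
    >>> ta = ta.write(1, 2)
    >>> ta.stack()
    <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>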
    """
    return self._implementation.write(index, value, name=name)

  def stack(self, name=None):
    """Return the values in the TensorArray as a stacked `Tensor`.

    All of the values must have been written and their shapes must all match.
    If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray stacked into one tensor.
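
    For example (an illustrative eager-mode sketch):

    >>> ta = tf.TensorArray(tf.float32, size=3)
    >>> ta = ta.unstack([10., 20., 30.])
    >>> ta.stack()
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 20., 30.],
    dtype=float32)>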
    """
    return self._implementation.stack(name=name)

  def gather(self, indices, name=None):
    """Return selected values in the TensorArray as a packed `Tensor`.

    All of the selected values must have been written and their shapes
    must all match.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If
        the `TensorArray` is not dynamic, `max_value=size()`.
      name: A name for the operation (optional).

    Returns:
      The tensors in the `TensorArray` selected by `indices`, packed into one
      tensor.
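
    For example (an illustrative eager-mode sketch):

    >>> ta = tf.TensorArray(tf.float32, size=3)
    >>> ta = ta.unstack([10., 20., 30.])
    >>> ta.gather([2, 0])
    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([30., 10.],
    dtype=float32)>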
    """
    return self._implementation.gather(indices, name=name)

  def concat(self, name=None):
    """Return the values in the TensorArray as a concatenated `Tensor`.

    All of the values must have been written, their ranks must match, and
    their shapes must all match for all dimensions except the first.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray concatenated into one tensor.
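
    For example (an illustrative eager-mode sketch; the rank-1 elements are
    joined along their shared first dimension):

    >>> ta = tf.TensorArray(tf.float32, size=2)
    >>> ta = ta.write(0, [1., 2.])
    >>> ta = ta.write(1, [3., 4.])
    >>> ta.concat()
    <tf.Tensor: shape=(4,), dtype=float32, numpy=array([1., 2., 3., 4.],
    dtype=float32)>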
    """
    return self._implementation.concat(name=name)

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """Unstack the values of a `Tensor` in the TensorArray.

    If input value shapes have rank-`R`, then the output TensorArray will
    contain elements whose shapes are rank-`(R-1)`.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unstack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the unstack occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
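
    For example (an illustrative eager-mode sketch; a rank-2 input becomes
    three rank-1 elements):

    >>> ta = tf.TensorArray(tf.int32, size=3)
    >>> ta = ta.unstack([[1, 2], [3, 4], [5, 6]])
    >>> ta.read(1)
    <tf.Tensor: shape=(2,), dtype=int32, numpy=array([3, 4], dtype=int32)>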
    """
    return self._implementation.unstack(value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """Scatter the values of a `Tensor` in specific indices of a `TensorArray`.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If
        the `TensorArray` is not dynamic, `max_value=size()`.
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unpack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the scatter occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
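
    For example (an illustrative eager-mode sketch; rows of `value` land at
    the given indices):

    >>> ta = tf.TensorArray(tf.float32, size=4)
    >>> ta = ta.scatter([1, 3], [[1., 2.], [3., 4.]])
    >>> ta.read(3)
    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([3., 4.],
    dtype=float32)>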
    """
    return self._implementation.scatter(indices, value, name=name)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """Split the values of a `Tensor` into the TensorArray.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to split.
      lengths: 1-D.  int32 vector with the lengths to use when splitting
        `value` along its first dimension.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the split occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
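
    For example (an illustrative eager-mode sketch; `lengths=[1, 2]` yields
    elements of shapes `(1,)` and `(2,)`):

    >>> ta = tf.TensorArray(tf.float32, size=2)
    >>> ta = ta.split([1., 2., 3.], lengths=[1, 2])
    >>> ta.read(1)
    <tf.Tensor: shape=(2,), dtype=float32, numpy=array([2., 3.],
    dtype=float32)>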
    """
    return self._implementation.split(value, lengths, name=name)

  def size(self, name=None):
    """Return the size of the TensorArray."""
    return self._implementation.size(name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """Close the current TensorArray."""
    return self._implementation.close(name=name)


def build_ta_with_new_flow(old_ta, flow):
  """Builds a TensorArray with a new `flow` tensor."""
  # Sometimes we get old_ta as the implementation, sometimes it's the
  # TensorArray wrapper object.
  impl = (old_ta._implementation if isinstance(old_ta, TensorArray)
          else old_ta)

  if not context.executing_eagerly():
    if (not isinstance(impl, _GraphTensorArrayV2) and
        control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
      raise NotImplementedError("Attempting to build a graph-mode TF2-style "
                                "TensorArray from either an eager-mode "
                                "TensorArray or a TF1-style TensorArray.  "
                                "This is not currently supported.  You may be "
                                "attempting to capture a TensorArray "
                                "inside a tf.function or tf.data map function. "
                                "Instead, construct a new TensorArray inside "
                                "the function.")
  new_ta = TensorArray(
      dtype=impl.dtype,
      handle=impl.handle,
      flow=flow,
      infer_shape=impl._infer_shape,
      colocate_with_first_write_call=impl._colocate_with_first_write_call)
  new_impl = new_ta._implementation
  new_impl._dynamic_size = impl._dynamic_size
  new_impl._size = impl._size
  new_impl._colocate_with = impl._colocate_with
  new_impl._element_shape = impl._element_shape  # Share _element_shape.
  return new_ta

# pylint: enable=protected-access


def _check_dtypes(value, dtype):
  if value.dtype != dtype:
    logging.error(
        "Error: Input value {} has dtype {}, but expected dtype {}.  "
        "This leads to undefined behavior and will be an error "
        "in future versions of TensorFlow.  Traceback:\n{}".format(
            value, str(value.dtype), str(dtype),
            "".join(traceback.format_stack())))


@tf_export("TensorArraySpec")
@type_spec.register("tf.TensorArraySpec")
class TensorArraySpec(type_spec.TypeSpec):
  """Type specification for a `tf.TensorArray`."""

  __slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"]

  value_type = property(lambda self: TensorArray)

  def __init__(self, element_shape=None, dtype=dtypes.float32,
               dynamic_size=False, infer_shape=True):
    """Constructs a type specification for a `tf.TensorArray`.

    Args:
      element_shape: The shape of each element in the `TensorArray`.
      dtype: Data type of the `TensorArray`.
      dynamic_size: Whether the `TensorArray` can grow past its initial size.
      infer_shape: Whether shape inference is enabled.
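
    For example (an illustrative sketch; a spec matches a `TensorArray` value
    when dtype, element shape, and dynamic_size are all compatible):

    >>> spec = tf.TensorArraySpec(element_shape=[2], dtype=tf.float32)
    >>> spec.is_compatible_with(tf.TensorArray(tf.float32, size=3))
    True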
    """
    self._element_shape = tensor_shape.as_shape(element_shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._dynamic_size = dynamic_size
    self._infer_shape = infer_shape

  def is_compatible_with(self, other):
    # pylint: disable=protected-access
    if not isinstance(other, type_spec.TypeSpec):
      other = type_spec.type_spec_from_value(other)

    # Note: we intentionally exclude infer_shape in this check.
    return (isinstance(other, TensorArraySpec) and
            self._dtype.is_compatible_with(other._dtype) and
            self._element_shape.is_compatible_with(other._element_shape) and
            self._dynamic_size == other._dynamic_size)

  def most_specific_compatible_type(self, other):
    # pylint: disable=protected-access
    if not self.is_compatible_with(other):
      raise ValueError("Types are not compatible")
    infer_shape = self._infer_shape and other._infer_shape
    return TensorArraySpec(
        self._element_shape.most_specific_compatible_shape(
            other._element_shape),
        self._dtype, self._dynamic_size, infer_shape)

  def _serialize(self):
    return (self._element_shape, self._dtype, self._dynamic_size,
            self._infer_shape)

  @property
  def _component_specs(self):
    return [tensor_spec.TensorSpec([], dtypes.variant)]

  def _to_components(self, value):
    if not isinstance(value, TensorArray):
      raise TypeError("value must be a TensorArray, but saw: {}"
                      .format(type(value)))
    if value.flow is not None and value.flow.dtype == dtypes.variant:
      return [value.flow]
    else:
      # Convert to a TF2-style TensorArray.
      # TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
      # "implementation / as_variant" arg to TensorArray constructor.
      with ops.name_scope("convert_tensor_array"):
        flow = list_ops.tensor_list_from_tensor(
            tensor=value.stack(), element_shape=value.element_shape)
      return [flow]

  def _from_components(self, tensor_list):
    # This will return a TF2 Graph-style TensorArray because tensor_list[0] is
    # a variant object.  size == -1 implies unknown size.
    ret = TensorArray(
        dtype=self._dtype,
        flow=tensor_list[0],
        dynamic_size=self._dynamic_size,
        infer_shape=self._infer_shape)
    ret._implementation._element_shape = [self._element_shape]  # pylint: disable=protected-access
    return ret

  @staticmethod
  def from_value(value):
    if not isinstance(value, TensorArray):
      raise TypeError("Expected value to be a TensorArray, but saw: {}".
                      format(type(value)))

    return TensorArraySpec(
        dtype=value.dtype,
        element_shape=value.element_shape,
        dynamic_size=value.dynamic_size,
        infer_shape=value._infer_shape)  # pylint: disable=protected-access

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    # Sneak the dynamic_size and infer_shape values into the legacy shape.
    return (tensor_shape.TensorShape([self._dynamic_size, self._infer_shape
                                     ]).concatenate(self._element_shape))

  def _to_legacy_output_classes(self):
    return TensorArray


# Register the TypeSpec for TensorArray.  If TensorArray is updated to be a
# CompositeTensor, then this registration can be deleted.
type_spec.register_type_spec_from_value_converter(
    TensorArray, TensorArraySpec.from_value, allow_subclass=True)