# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorArray: a dynamically sized array of Tensors."""
# Mixture of pep8 and non-pep8 names, so disable pylint bad-name
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import weakref

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export


# _GraphTensorArray accesses many of the hidden generated ops, but is in
# fact built to wrap these methods.
# pylint: disable=protected-access
class _GraphTensorArray(object):
  """Graph-mode implementation of TensorArray."""

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference
        is enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray.
        Need not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the
        device context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if handle is not None and tensor_array_name:
      raise ValueError(
          "Cannot construct with both handle and tensor_array_name")
    if handle is not None and not isinstance(handle, ops.Tensor):
      raise TypeError("Handle must be a Tensor")
    if handle is None and size is None:
      raise ValueError("Size must be provided if handle is not provided")
    if handle is not None and size is not None:
      raise ValueError("Cannot provide both a handle and size "
                       "at the same time")
    if handle is not None and element_shape is not None:
      raise ValueError("Cannot provide both a handle and element_shape "
                       "at the same time")
    if handle is not None and dynamic_size is not None:
      raise ValueError("Cannot provide both a handle and dynamic_size "
                       "at the same time")
    if handle is not None and clear_after_read is not None:
      raise ValueError("Cannot provide both a handle and clear_after_read "
                       "at the same time")

    if clear_after_read is None:
      clear_after_read = True
    self._dynamic_size = None
    dynamic_size = dynamic_size or False

    self._dtype = dtype

    # Used to keep track of what tensors the TensorArray should be
    # colocated with.  We choose to colocate the TensorArray with the
    # first tensor written to it.
    self._colocate_with_first_write_call = colocate_with_first_write_call
    if colocate_with_first_write_call:
      self._colocate_with = []
    else:
      self._colocate_with = None
    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, every write is checked
    # for shape equality.
    if element_shape is None:
      self._infer_shape = infer_shape
      self._element_shape = []
    else:
      self._infer_shape = True
      self._element_shape = [tensor_shape.TensorShape(element_shape)]
    with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
      if handle is not None:
        self._handle = handle
        if flow is None:
          raise ValueError("flow must not be None if handle is not None.")
        self._flow = flow
      else:
        # Construct the TensorArray with an empty device.  The first
        # write into the TensorArray from a Tensor with a set device
        # will retroactively set the device value of this op.
        def create():
          """Create the TensorArray op."""
          return gen_data_flow_ops.tensor_array_v3(
              dtype=dtype,
              size=size,
              element_shape=element_shape,
              identical_element_shapes=infer_shape,
              dynamic_size=dynamic_size,
              clear_after_read=clear_after_read,
              tensor_array_name=tensor_array_name,
              name=scope)
        if colocate_with_first_write_call:
          with ops.device(None), ops.colocate_with(None, ignore_existing=True):
            self._handle, self._flow = create()
        else:
          self._handle, self._flow = create()

  @property
  def flow(self):
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    return self._handle

  def _merge_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
          element shape of the `TensorArray`.
    """

    if self._element_shape:
      if not shape.is_compatible_with(self._element_shape[0]):
        raise ValueError(
            "Inconsistent shapes: saw %s but expected %s "
            "(and infer_shape=True)" % (shape, self._element_shape[0]))
      self._element_shape[0] = self._element_shape[0].merge_with(shape)
    else:
      self._element_shape.append(shape)

  @contextlib.contextmanager
  def _maybe_colocate_with(self, value):
    """Colocate operations with an internal colocation group or `value`.

    If no internal colocation group is set yet, this colocates with `value`
    and records `value` as the internal colocation group.

    Args:
      value: `Tensor`, the tensor to try to colocate with.

    Yields:
      None, within the appropriate colocation context.
    """
    if not self._colocate_with_first_write_call:
      yield
    else:
      if not self._colocate_with:
        self._colocate_with.append(value)
      with ops.colocate_with(self._colocate_with[0]):
        yield

  def identity(self):
    """See TensorArray."""
    flow = array_ops.identity(self._flow)
    ta = TensorArray(
        dtype=self._dtype,
        handle=self._handle,
        flow=flow,
        infer_shape=self._infer_shape,
        colocate_with_first_write_call=self._colocate_with_first_write_call)
    ta._element_shape = self._element_shape
    ta._colocate_with = self._colocate_with
    return ta

  def grad(self, source, flow=None, name=None):
    """See TensorArray."""
    # tensor_array_grad requires a flow input when forward
    # TensorArrays are dynamically sized.  This forces the creation
    # of the grad TensorArray only once the final forward array's size
    # is fixed.
    if flow is None:
      flow = self.flow
    with ops.name_scope(name, "TensorArrayGrad", [self._handle]):
      with ops.colocate_with(self._handle):
        g_handle, unused_flow = gen_data_flow_ops.tensor_array_grad_v3(
            handle=self._handle, source=source, flow_in=flow, name=name)
        with ops.control_dependencies([g_handle]):
          flow = array_ops.identity(flow, name="gradient_flow")
        g = TensorArray(
            dtype=self._dtype,
            handle=g_handle,
            flow=flow,
            infer_shape=self._infer_shape,
            colocate_with_first_write_call=False)
        g._element_shape = self._element_shape
        return g

  def read(self, index, name=None):
    """See TensorArray."""
    value = gen_data_flow_ops.tensor_array_read_v3(
        handle=self._handle,
        index=index,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name)
    if self._element_shape:
      value.set_shape(self._element_shape[0].dims)
    return value

  @tf_should_use.should_use_result
  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
      value = ops.convert_to_tensor(value, name="value")
      if self._infer_shape:
        self._merge_element_shape(value.shape)
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops.tensor_array_write_v3(
            handle=self._handle,
            index=index,
            value=value,
            flow_in=self._flow,
            name=name)
      ta = TensorArray(
          dtype=self._dtype,
          handle=self._handle,
          flow=flow_out,
          colocate_with_first_write_call=self._colocate_with_first_write_call)
      ta._infer_shape = self._infer_shape
      ta._element_shape = self._element_shape
      ta._colocate_with = self._colocate_with
      return ta

  def stack(self, name=None):
    """See TensorArray."""
    with ops.colocate_with(self._handle):
      with ops.name_scope(name, "TensorArrayStack", [self._handle]):
        return self.gather(math_ops.range(0, self.size()), name=name)

  def gather(self, indices, name=None):
    """See TensorArray."""
    if self._element_shape:
      element_shape = self._element_shape[0]
    else:
      element_shape = tensor_shape.TensorShape(None)
    value = gen_data_flow_ops.tensor_array_gather_v3(
        handle=self._handle,
        indices=indices,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape=element_shape)
    if self._element_shape and self._element_shape[0].dims is not None:
      value.set_shape([None] + self._element_shape[0].dims)
    return value

  def concat(self, name=None):
    """See TensorArray."""
    if self._element_shape and self._element_shape[0].dims is not None:
      element_shape_except0 = (
          tensor_shape.TensorShape(self._element_shape[0].dims[1:]))
    else:
      element_shape_except0 = tensor_shape.TensorShape(None)
    value, _ = gen_data_flow_ops.tensor_array_concat_v3(
        handle=self._handle,
        flow_in=self._flow,
        dtype=self._dtype,
        name=name,
        element_shape_except0=element_shape_except0)
    if self._element_shape and self._element_shape[0].dims is not None:
      value.set_shape([None] + self._element_shape[0].dims[1:])
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
      num_elements = array_ops.shape(value)[0]
      return self.scatter(
          indices=math_ops.range(0, num_elements), value=value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._handle, value, indices]):
      value = ops.convert_to_tensor(value, name="value")
      if self._infer_shape and not context.executing_eagerly():
        self._merge_element_shape(value.shape[1:])
      with self._maybe_colocate_with(value):
        flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
            handle=self._handle,
            indices=indices,
            value=value,
            flow_in=self._flow,
            name=name)
      ta = TensorArray(
          dtype=self._dtype,
          handle=self._handle,
          flow=flow_out,
          colocate_with_first_write_call=self._colocate_with_first_write_call)
      ta._infer_shape = self._infer_shape
      ta._element_shape = self._element_shape
      ta._colocate_with = self._colocate_with
      return ta

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit",
                        [self._handle, value, lengths]):
      value = ops.convert_to_tensor(value, name="value")
      with self._maybe_colocate_with(value):
        lengths_64 = math_ops.cast(lengths, dtypes.int64)
        if self._infer_shape and not context.executing_eagerly():
          clengths = tensor_util.constant_value(lengths_64)
          if value.shape.dims is not None:
            if clengths is not None and clengths.max() == clengths.min():
              self._merge_element_shape(
                  tensor_shape.TensorShape([clengths[0]]).concatenate(
                      value.shape[1:]))
        flow_out = gen_data_flow_ops.tensor_array_split_v3(
            handle=self._handle,
            value=value,
            lengths=lengths_64,
            flow_in=self._flow,
            name=name)
      ta = TensorArray(
          dtype=self._dtype,
          handle=self._handle,
          flow=flow_out,
          colocate_with_first_write_call=self._colocate_with_first_write_call)
      ta._infer_shape = self._infer_shape
      ta._element_shape = self._element_shape
      ta._colocate_with = self._colocate_with
      return ta

  def size(self, name=None):
    """See TensorArray."""
    return gen_data_flow_ops.tensor_array_size_v3(
        handle=self._handle, flow_in=self.flow, name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """See TensorArray."""
    return gen_data_flow_ops.tensor_array_close_v3(
        handle=self._handle, name=name)


class _GraphTensorArrayV2(object):
  """Graph-mode implementation of TensorArray backed by TensorLists.

  The backing tensor of this TensorArray is a TensorList variant tensor
  which is stored in the `flow`. The `handle` is always `None` here. The
  reason we use the `flow` field and not the `handle` field is to ensure
  backwards compatibility with legacy control flow.
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a graph mode TensorArray.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if flow is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: (optional) unused. Not supported in TensorLists.
      tensor_array_name: (optional) unused.
      handle: (optional) Must always be None.
      flow: (optional) A variant `Tensor` scalar for a TensorList.
      infer_shape: (optional, default: True) If True, shape inference is
        enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray. Need
        not be fully defined.
      colocate_with_first_write_call: (optional) unused.
      name: (optional) A name for the operation.

    Raises:
      TypeError: if flow is provided but is not a variant Tensor.
      ValueError: if neither flow nor size is provided, or if both are
        provided, or if element_shape is provided together with flow.
    """
    assert handle is None
    del handle
    del clear_after_read
    del tensor_array_name
    del colocate_with_first_write_call

    self._dynamic_size = dynamic_size

    if (flow is not None and
        (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)):
      raise TypeError("flow must be a variant tensor")
    if flow is None and size is None:
      raise ValueError("Size must be provided if flow is not provided")
    if flow is not None and size is not None:
      raise ValueError("Cannot provide both a flow and size "
                       "at the same time")
    if flow is not None and element_shape is not None:
      raise ValueError("Cannot provide both a flow and element_shape "
                       "at the same time")

    self._dtype = dtype

    # Record the current static shape for the array elements. The element
    # shape is defined either by `element_shape` or the shape of the tensor
    # of the first write. If `infer_shape` is true, every write is checked
    # for shape equality.
    if element_shape is None:
      self._infer_shape = infer_shape
      self._element_shape = []
    else:
      self._infer_shape = True
      self._element_shape = [tensor_shape.TensorShape(element_shape)]
    with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
      if flow is None:
        self._flow = list_ops.tensor_list_reserve(
            element_shape=element_shape,
            num_elements=size,
            element_dtype=dtype,
            name=scope)
      else:
        self._flow = flow

    # For backwards compatibility.
    self._colocate_with_first_write_call = None
    self._colocate_with = None

  @property
  def flow(self):
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    # We intentionally do not raise an error so that legacy while_loop does not
    # complain.
    return None

  def _merge_element_shape(self, shape):
    """Changes the element shape of the array given a shape to merge with.

    Args:
      shape: A `TensorShape` object to merge with.

    Raises:
      ValueError: if the provided shape is incompatible with the current
          element shape of the `TensorArray`.
    """

    if self._element_shape:
      if not shape.is_compatible_with(self._element_shape[0]):
        raise ValueError(
            "Inconsistent shapes: saw %s but expected %s "
            "(and infer_shape=True)" % (shape, self._element_shape[0]))
      self._element_shape[0] = self._element_shape[0].merge_with(shape)
    else:
      self._element_shape.append(shape)

  def identity(self):
    """See TensorArray."""
    flow = array_ops.identity(self._flow)
    return build_ta_with_new_flow(self, flow)

  def grad(self, source, flow=None, name=None):
    """Not supported."""
    raise NotImplementedError()

  def read(self, index, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
      if self._element_shape:
        element_shape = self._element_shape[0]
      else:
        element_shape = tensor_shape.TensorShape(None)
      value = list_ops.tensor_list_get_item(
          input_handle=self._flow,
          index=index,
          element_dtype=self._dtype,
          element_shape=element_shape,
          name=name)
      if self._element_shape:
        value.set_shape(self._element_shape[0].dims)
      return value

  @tf_should_use.should_use_result
  def write(self, index, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]):
      value = ops.convert_to_tensor(value, name="value")
      if self._infer_shape:
        self._merge_element_shape(value.shape)
      flow_out = list_ops.tensor_list_set_item(
          input_handle=self._flow,
          index=index,
          item=value,
          resize_if_index_out_of_bounds=self._dynamic_size,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def stack(self, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
      if self._element_shape:
        element_shape = self._element_shape[0]
      else:
        element_shape = tensor_shape.TensorShape(None)
      value = list_ops.tensor_list_stack(
          input_handle=self._flow,
          element_dtype=self._dtype,
          element_shape=element_shape)
      if self._element_shape and self._element_shape[0].dims is not None:
        value.set_shape([None] + self._element_shape[0].dims)
      return value

  def gather(self, indices, name=None):
    """See TensorArray."""
    if self._element_shape:
      element_shape = self._element_shape[0]
    else:
      element_shape = tensor_shape.TensorShape(None)
    value = list_ops.tensor_list_gather(
        input_handle=self._flow,
        indices=indices,
        element_dtype=self._dtype,
        element_shape=element_shape,
        name=name)
    if self._element_shape and self._element_shape[0].dims is not None:
      value.set_shape([None] + self._element_shape[0].dims)
    return value

  def concat(self, name=None):
    """See TensorArray."""
    if self._element_shape and self._element_shape[0].dims is not None:
      element_shape = [None] + self._element_shape[0].dims[1:]
    else:
      element_shape = None

    value = list_ops.tensor_list_concat(
        input_handle=self._flow,
        element_dtype=self._dtype,
        element_shape=element_shape,
        name=name)
    return value

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]):
      value = ops.convert_to_tensor(value, name="value")
      if self._infer_shape and not context.executing_eagerly():
        self._merge_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_from_tensor(
          tensor=value, element_shape=value.shape[1:])
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArrayScatter",
                        [self._flow, value, indices]):
      value = ops.convert_to_tensor(value, name="value")
      if self._infer_shape and not context.executing_eagerly():
        self._merge_element_shape(value.shape[1:])
      flow_out = list_ops.tensor_list_scatter(
          tensor=value, indices=indices, input_handle=self._flow)
      return build_ta_with_new_flow(self, flow_out)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """See TensorArray."""
    with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]):
      value = ops.convert_to_tensor(value, name="value")
      lengths_64 = math_ops.cast(lengths, dtypes.int64)
      if self._infer_shape and not context.executing_eagerly():
        clengths = tensor_util.constant_value(lengths_64)
        if value.shape.dims is not None:
          if clengths is not None and clengths.max() == clengths.min():
            self._merge_element_shape(
                tensor_shape.TensorShape([clengths[0]]).concatenate(
                    value.shape[1:]))
      flow_out = list_ops.tensor_list_split(
          tensor=value,
          lengths=lengths_64,
          element_shape=self._element_shape[0] if self._element_shape else None,
          name=name)
      return build_ta_with_new_flow(self, flow_out)

  def size(self, name=None):
    """See TensorArray."""
    return list_ops.tensor_list_length(input_handle=self._flow, name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """See TensorArray."""
    return gen_control_flow_ops.no_op(name=name)

# pylint: enable=protected-access


class _EagerTensorArray(object):
  """Eager-compatible implementation of TensorArray."""

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Constructs a TensorArray compatible with eager execution.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: unused.
      handle: unsupported.
      flow: unsupported.
      infer_shape: used for error checking, same semantics as TensorArray.
      element_shape: used for error checking, same semantics as TensorArray.
      colocate_with_first_write_call: unsupported.
      name: unsupported.

    Raises:
      ValueError: if handle is supplied, or if size is not supplied.
    """

    del (flow, tensor_array_name, name)  # Unused.

    if handle is not None:
      raise ValueError("TensorArray handles are not supported when eager "
                       "execution is enabled.")
    if size is None:
      raise ValueError("Size must be declared for TensorArrays when eager "
                       "execution is enabled.")

    # These attributes are not meaningful when eager is enabled, but some
    # library functions (e.g., those in control_flow_ops.py) access them to
    # create new tensor arrays; as such, we define them for the sake of
    # compatibility.
    self._handle = None
    # We assign a dummy value to _flow in case other code assumes it to be
    # a Tensor.
    self._flow = constant_op.constant(0, dtype=dtypes.int32)
    self._infer_shape = infer_shape
    self._element_shape = element_shape
    self._colocate_with_first_write_call = colocate_with_first_write_call

    self._dtype = dtype
    self._dynamic_size = dynamic_size or False
    self._clear_after_read = (
        True if clear_after_read is None else clear_after_read)
    self._previously_read_indices = []

    if isinstance(size, ops.EagerTensor):
      size = size.numpy()
    self._tensor_array = [None for _ in range(size)]

  @property
  def flow(self):
    """For compatibility; flows are not meaningful when eager is enabled."""
    return self._flow

  @property
  def dtype(self):
    return self._dtype

  @property
  def handle(self):
    """For compatibility; handles are not meaningful when eager is enabled."""
    return self._handle

  def identity(self):
    """See TensorArray."""
    return self.parent()

  def grad(self, source, flow=None, name=None):
    raise NotImplementedError(
        "TensorArray.grad is not supported when executing eagerly; eager's "
        "gradient implementation does not use/need this function to compute "
        "gradients of operations that use TensorArrays.")

  def read(self, index, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.

    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Reading from negative indices (index %d) is not allowed." % index)

    if index >= len(self._tensor_array):
      raise errors_impl.OutOfRangeError(
          None, None, "Tried to read from index %d but array size is: %d" %
          (index, len(self._tensor_array)))

    tensor = self._tensor_array[index]
    if tensor is None:
      if index in self._previously_read_indices:
        raise errors_impl.InvalidArgumentError(
            None, None,
            "Could not read index %d twice because it was cleared after "
            "a previous read (perhaps try setting clear_after_read = false?)" %
            index)
      else:
        tensor = self._maybe_zero(index)

    if self._clear_after_read:
      self._tensor_array[index] = None
      self._previously_read_indices.append(index)
    return tensor

  def _write(self, index, value):
    """Writes `value` into index named by `index`.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The `Tensor` to write to `index`.

    Raises:
      errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
      errors_impl.OutOfRangeError: `index` is out of bounds.
      ValueError: shape of `value` is not consistent with inferred shape.
    """

    if isinstance(index, ops.EagerTensor):
      index = index.numpy()

    if index < 0:
      raise errors_impl.OutOfRangeError(
          None, None,
          "Writing to negative indices (index %d) is not allowed." % index)

    size = len(self._tensor_array)
    if index >= size:
      if not self._dynamic_size:
        raise errors_impl.OutOfRangeError(
            None, None,
            "Tried to write to index %d but array is not resizeable and size "
            "is: %d" % (index, size))
      self._tensor_array.extend([None for _ in range(index - size + 1)])

    if not isinstance(value, ops.EagerTensor):
      value = ops.convert_to_tensor(value)

    if self._infer_shape:
      if self._element_shape is None:
        self._element_shape = value.shape
      elif not self._element_shape.is_compatible_with(value.shape):
        raise ValueError("Incompatible shape for value (%s), expected (%s)" %
                         (value.shape.as_list(), self._element_shape.as_list()))

    if self._dtype != value.dtype:
      raise errors_impl.InvalidArgumentError(
          None, None,
          "TensorArray dtype is %s but Op is trying to write dtype %s" %
          (self._dtype.name, value.dtype.name))
    self._tensor_array[index] = value

  def write(self, index, value, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    self._write(index, value)
    return self.parent()

  def _maybe_zero(self, ix):
    val = self._tensor_array[ix]
    if val is None:
      val = self._tensor_array[ix] = array_ops.zeros(
          shape=self._element_shape, dtype=self._dtype)
    return val

  def stack(self, name=None):
    """See TensorArray."""
    if self._tensor_array:
      for ix in range(len(self._tensor_array)):
        self._maybe_zero(ix)
    return ops.convert_to_tensor(
        self._tensor_array, name=name, dtype=self._dtype)

  def gather(self, indices, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    return array_ops.stack([self._maybe_zero(i) for i in indices])

  def concat(self, name=None):
    """See TensorArray."""
    try:
      return array_ops.concat(
          [self._maybe_zero(ix) for ix in range(len(self._tensor_array))],
          0, name=name)
    except errors_impl.OpError:
      # Reproduce a subset of the error-handling for graph-mode TensorArrays.
      shapes = [t.shape for t in self._tensor_array]
      ndims = [s.ndims for s in shapes]
      if 0 in ndims:
        idx = ndims.index(0)
        raise errors_impl.InvalidArgumentError(
            None, None, "Concat saw a scalar shape at index %d but requires "
            "at least vectors." % idx)
      else:
        raise

  def unstack(self, value, name=None):
    """See TensorArray."""
    tensors = array_ops.unstack(value, name=name)
    if len(tensors) > len(self._tensor_array) and not self._dynamic_size:
      raise ValueError(
          "Cannot unstack %d tensors into a TensorArray of static size %d" %
          (len(tensors), len(self._tensor_array)))
    self._tensor_array = tensors
    return self.parent()

  def scatter(self, indices, value, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    if isinstance(indices, ops.EagerTensor):
      indices = indices.numpy()
    for index, val in zip(indices, array_ops.unstack(value)):
      self._write(index, val)  # pylint: disable=protected-access
    return self.parent()

  def split(self, value, lengths, name=None):
    """See TensorArray."""
    # error checking to match graph-mode errors
    value = ops.convert_to_tensor(value)
    lengths = ops.convert_to_tensor(lengths)
    sum_lengths = math_ops.reduce_sum(lengths)
    if lengths.shape.ndims != 1:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected lengths to be a vector, received shape: %s" %
          lengths.shape.as_list())
    elif value.shape.ndims == 0:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected value to be at least a vector, "
          "but received shape: %s" % value.shape.as_list())
    elif sum_lengths.numpy() != value.shape.as_list()[0]:
      raise errors_impl.InvalidArgumentError(
          None, None, "Expected sum of lengths to be equal to "
          "values.shape[0], but sum of lengths is %d and "
          "value's shape is: %s " % (sum_lengths.numpy(),
                                     value.shape.as_list()))
    elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array):
      raise errors_impl.InvalidArgumentError(
          None, None, "TensorArray's size is not equal to the size of "
          "lengths (%d vs. %d), and the TensorArray is not marked as "
          "dynamically resizeable" % (len(self._tensor_array),
                                      lengths.shape[0]))
    else:
      self._tensor_array = array_ops.split(value, lengths, name=name)
      return self.parent()

  def size(self, name=None):
    """See TensorArray."""
    del name  # not meaningful when executing eagerly.
    return constant_op.constant(len(self._tensor_array))

  def close(self, name=None):
    del name  # not meaningful when executing eagerly.
    del self._tensor_array[:]


# TensorArray is designed to hide an underlying implementation object
# and as such accesses many of that object's hidden fields.
# pylint: disable=protected-access
@tf_export("TensorArray")
class TensorArray(object):
  """Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.

  This class is meant to be used with dynamic iteration primitives such as
  `while_loop` and `map_fn`.  It supports gradient back-propagation via special
  "flow" control flow dependencies.
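
  For example, accumulating one value per loop step and stacking the result
  (an illustrative sketch):

    ta = tf.TensorArray(dtype=tf.float32, size=5)

    def body(i, ta):
      return i + 1, ta.write(i, tf.cast(i, tf.float32))

    _, ta = tf.while_loop(lambda i, ta: i < 5, body, [0, ta])
    values = ta.stack()  # => [0., 1., 2., 3., 4.]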
  """

  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Construct a new TensorArray or wrap an existing TensorArray handle.

    A note about the parameter `name`:

    The name of the `TensorArray` (even if passed in) is uniquified: each time
    a new `TensorArray` is created at runtime it is assigned its own name for
    the duration of the run.  This avoids name collisions if a `TensorArray`
    is created within a `while_loop`.
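
    For example, a float32 array that can grow past its initial size
    (illustrative):

      ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)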

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference
        is enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray.
        Need not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the
        device context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if context.executing_eagerly():
      implementation = _EagerTensorArray
    else:
      if control_flow_util.EnableControlFlowV2(ops.get_default_graph()):
        implementation = _GraphTensorArrayV2
      else:
        implementation = _GraphTensorArray
    self._implementation = implementation(
        dtype,
        size=size,
        dynamic_size=dynamic_size,
        clear_after_read=clear_after_read,
        tensor_array_name=tensor_array_name,
        handle=handle,
        flow=flow,
        infer_shape=infer_shape,
        element_shape=element_shape,
        colocate_with_first_write_call=colocate_with_first_write_call,
        name=name)

    self._implementation.parent = weakref.ref(self)

  @property
  def flow(self):
    """The flow `Tensor` forcing ops leading to this TensorArray state."""
    return self._implementation._flow

  @property
  def dtype(self):
    """The data type of this TensorArray."""
    return self._implementation._dtype

  @property
  def handle(self):
    """The reference to the TensorArray."""
    return self._implementation.handle

  @property
  def _dynamic_size(self):
    return self._implementation._dynamic_size

  @property
  def _infer_shape(self):
    return self._implementation._infer_shape

  @_infer_shape.setter
  def _infer_shape(self, infer_shape):
    self._implementation._infer_shape = infer_shape

  @property
  def _element_shape(self):
    return self._implementation._element_shape

  @_element_shape.setter
  def _element_shape(self, element_shape):
    self._implementation._element_shape = element_shape

  @property
  def _colocate_with_first_write_call(self):
    return self._implementation._colocate_with_first_write_call

  @property
  def _colocate_with(self):
    return self._implementation._colocate_with

  @_colocate_with.setter
  def _colocate_with(self, colocate_with):
    self._implementation._colocate_with = colocate_with

  def identity(self):
    """Returns a TensorArray with the same content and properties.

    Returns:
      A new TensorArray object with flow that ensures the control dependencies
      from the contexts will become control dependencies for writes, reads, etc.
      Use this object for all subsequent operations.
    """
    return self._implementation.identity()

  def grad(self, source, flow=None, name=None):
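    """Returns a `TensorArray` holding the gradients of this array's values.

    Only supported by the graph-mode `_GraphTensorArray` implementation; the
    eager and TensorList-backed implementations raise `NotImplementedError`.
    """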
    return self._implementation.grad(source, flow=flow, name=name)

  def read(self, index, name=None):
    """Read the value at location `index` in the TensorArray.

    Args:
      index: 0-D.  int32 tensor with the index to read from.
      name: A name for the operation (optional).

    Returns:
      The tensor at index `index`.
    """
    return self._implementation.read(index, name=name)

  @tf_should_use.should_use_result
  def write(self, index, value, name=None):
    """Write `value` into index `index` of the TensorArray.

    Args:
      index: 0-D.  int32 scalar with the index to write to.
      value: N-D.  Tensor of type `dtype`.  The Tensor to write to this index.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the write occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if there are more writers than specified.
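
    For example, chaining two writes (illustrative; `x0` and `x1` are tensors
    of type `dtype`):

      ta = ta.write(0, x0)  # Use the returned TensorArray for later ops.
      ta = ta.write(1, x1)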
    """
    return self._implementation.write(index, value, name=name)

  def stack(self, name=None):
    """Return the values in the TensorArray as a stacked `Tensor`.

    All of the values must have been written and their shapes must all match.
    If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray stacked into one tensor.
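
    For example, with two rank-1 elements (illustrative):

      ta = tf.TensorArray(tf.float32, size=2)
      ta = ta.write(0, [1., 2.])
      ta = ta.write(1, [3., 4.])
      ta.stack()  # => [[1., 2.], [3., 4.]], shape [2, 2]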
    """
    return self._implementation.stack(name=name)

  def gather(self, indices, name=None):
    """Return selected values in the TensorArray as a packed `Tensor`.

    All of the selected values must have been written and their shapes
    must all match.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If
        the `TensorArray` is not dynamic, `max_value=size()`.
      name: A name for the operation (optional).

    Returns:
      The tensors in the `TensorArray` selected by `indices`, packed into one
      tensor.
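
    For example, selecting two previously written rank-1 values in reverse
    order (illustrative):

      ta = ta.write(0, [1., 2.]).write(1, [3., 4.])
      ta.gather([1, 0])  # => [[3., 4.], [1., 2.]]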
    """
    return self._implementation.gather(indices, name=name)

  def concat(self, name=None):
    """Return the values in the TensorArray as a concatenated `Tensor`.

    All of the values must have been written, their ranks must match, and
    their shapes must all match for all dimensions except the first.

    Args:
      name: A name for the operation (optional).

    Returns:
      All the tensors in the TensorArray concatenated into one tensor.
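
    For example, with elements of shapes `[2, 3]` and `[1, 3]` (illustrative):

      ta = tf.TensorArray(tf.float32, size=2, infer_shape=False)
      ta = ta.write(0, tf.zeros([2, 3]))
      ta = ta.write(1, tf.ones([1, 3]))
      ta.concat()  # => a tensor of shape [3, 3]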
    """
    return self._implementation.concat(name=name)

  @tf_should_use.should_use_result
  def unstack(self, value, name=None):
    """Unstack the values of a `Tensor` in the TensorArray.

    If input value shapes have rank-`R`, then the output TensorArray will
    contain elements whose shapes are rank-`(R-1)`.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unstack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the unstack occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
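
    For example, unstacking a `[3, 2]` tensor into three elements of shape
    `[2]` (illustrative):

      ta = tf.TensorArray(tf.float32, size=3)
      ta = ta.unstack([[1., 2.], [3., 4.], [5., 6.]])
      ta.read(2)  # => [5., 6.]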
    """
    return self._implementation.unstack(value, name=name)

  @tf_should_use.should_use_result
  def scatter(self, indices, value, name=None):
    """Scatter the values of a `Tensor` in specific indices of a `TensorArray`.

    Args:
      indices: A `1-D` `Tensor` taking values in `[0, max_value)`.  If
        the `TensorArray` is not dynamic, `max_value=size()`.
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to unpack.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the scatter occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
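
    For example, writing `value[0]` to index 2 and `value[1]` to index 0
    (illustrative):

      ta = tf.TensorArray(tf.float32, size=3)
      ta = ta.scatter([2, 0], [[1., 2.], [3., 4.]])
      ta.read(0)  # => [3., 4.]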
    """
    return self._implementation.scatter(indices, value, name=name)

  @tf_should_use.should_use_result
  def split(self, value, lengths, name=None):
    """Split the values of a `Tensor` into the TensorArray.

    Args:
      value: (N+1)-D.  Tensor of type `dtype`.  The Tensor to split.
      lengths: 1-D.  int32 vector with the lengths to use when splitting
        `value` along its first dimension.
      name: A name for the operation (optional).

    Returns:
      A new TensorArray object with flow that ensures the split occurs.
      Use this object for all subsequent operations.

    Raises:
      ValueError: if the shape inference fails.
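
    For example, splitting a length-5 vector into pieces of lengths 2 and 3
    (illustrative):

      ta = tf.TensorArray(tf.float32, size=2, infer_shape=False)
      ta = ta.split([1., 2., 3., 4., 5.], lengths=[2, 3])
      ta.read(1)  # => [3., 4., 5.]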
    """
    return self._implementation.split(value, lengths, name=name)

  def size(self, name=None):
    """Return the size of the TensorArray."""
    return self._implementation.size(name=name)

  @tf_should_use.should_use_result
  def close(self, name=None):
    """Close the current TensorArray."""
    return self._implementation.close(name=name)


def build_ta_with_new_flow(old_ta, flow):
  """Builds a TensorArray with a new `flow` tensor."""
  ta = TensorArray(
      dtype=old_ta.dtype,
      dynamic_size=old_ta._dynamic_size,
      handle=old_ta.handle,
      flow=flow,
      infer_shape=old_ta._infer_shape,
      colocate_with_first_write_call=old_ta._colocate_with_first_write_call)
  ta._colocate_with = old_ta._colocate_with
  ta._element_shape = old_ta._element_shape
  return ta

# pylint: enable=protected-access