1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
17 #define TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
18
19 #include <limits.h>
20 #include <vector>
21
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/partial_tensor_shape.h"
24 #include "tensorflow/core/framework/register_types.h"
25 #include "tensorflow/core/framework/resource_mgr.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/kernels/aggregate_ops.h"
30 #include "tensorflow/core/kernels/fill_functor.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/types.h"
34
35 namespace tensorflow {
36
37 typedef Eigen::ThreadPoolDevice CPUDevice;
38 typedef Eigen::GpuDevice GPUDevice;
39
40 namespace tensor_array {
41
42 // Full implementations are in tensor_array.cc
43 template <typename Device, typename T>
AddToTensor(OpKernelContext * ctx,Tensor * sum,const Tensor * current,const Tensor * add)44 Status AddToTensor(OpKernelContext* ctx, Tensor* sum, const Tensor* current,
45 const Tensor* add) {
46 return errors::InvalidArgument(
47 "tensor_array::AddToTensor type not supported: ",
48 DataTypeString(DataTypeToEnum<T>::value));
49 }
50
51 #define TENSOR_ARRAY_WRITE_OR_ADD(Device, T) \
52 template <> \
53 Status AddToTensor<Device, T>(OpKernelContext * ctx, Tensor * sum, \
54 const Tensor* current, const Tensor* add);
55
56 #define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T)
57 TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)
58 #undef TENSOR_ARRAY_WRITE_OR_ADD_CPU
59
60 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
61
62 #define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
63 TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
64 TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
65 #undef TENSOR_ARRAY_WRITE_OR_ADD_GPU
66
67 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
68
69 #undef TENSOR_ARRAY_WRITE_OR_ADD
70
71 template <typename Device, typename T>
TensorSetZero(OpKernelContext * ctx,Tensor * value)72 Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
73 return errors::InvalidArgument(
74 "tensor_array::TensorSetZero type not supported: ",
75 DataTypeString(DataTypeToEnum<T>::value));
76 }
77
78 #define TENSOR_ARRAY_SET_ZERO(Device, T) \
79 template <> \
80 Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);
81
82 #define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
83 TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
84 TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
85 #undef TENSOR_ARRAY_SET_ZERO_CPU
86
87 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
88
89 #define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T)
90 TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
91 TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
92 #undef TENSOR_ARRAY_SET_ZERO_GPU
93
94 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
95
96 #undef TENSOR_ARRAY_SET_ZERO
97
98 } // namespace tensor_array
99
100 // The TensorArray object keeps an array of Tensors. It allows reading from the
101 // array and writing to the array.
102 //
103 // Important properties:
104 // * Usually, writing to a particular index in the TensorArray is allowed at
105 // most once per index. In a special case, writes with the flag
106 // multiple_writes_aggregate allow multiple writes to the same
107 // index. In this case, the writes are summed.
108 // * Multiple reads are supported.
109 // * Deep copies of Tensors are rarely made. The only time they are made is
110 // when WriteOrAggregate is called at least twice on the same index with the
111 // flag multiple_writes_aggregate = True.
112 // * Reading and Writing to the array is protected by a mutex.
113 // All operations on a TensorArray are thread-safe.
114 // * A TensorArray may be preemptively closed, which releases all
115 // memory associated with it.
116 //
117 // These properties together allow the TensorArray to work as a
118 // functional object and makes gradient computation easy. For
119 // example:
120 // * Write-Once semantics mean the gradient of a TensorArray Read never has to
121 // worry which of multiple writes to that index the gradient value
122 // is meant for.
123 // * Read-Many semantics (when using clear_after_read=false) allow the
124 // TensorArray to be read, packed, or concatenated multiple times;
125 // and the gradient operations use the multiple_writes_aggregate
126 // flag to aggregate the backprop writes. Multiple backprop writes to
127 // the same index are partial gradients corresponding to the
128 // multiple reads of that index in the forward phase.
129 //
130 class TensorArray : public ResourceBase {
131 public:
132 static std::atomic<int64_t> tensor_array_counter;
133
134 // Construct a TensorArray for holding Tensors of type 'dtype' with
135 // 'N' elements. While the underlying storage is a std::vector and
136 // can hold more than MAX_INT entries, in practice we do not expect
137 // users to construct this many Tensors for storage in a TensorArray.
TensorArray(const string & key,const DataType & dtype,const Tensor & handle,int32_t N,const PartialTensorShape & element_shape,bool identical_element_shapes,bool dynamic_size,bool multiple_writes_aggregate,bool is_grad,int32_t marked_size,bool clear_after_read)138 TensorArray(const string& key, const DataType& dtype, const Tensor& handle,
139 int32_t N, const PartialTensorShape& element_shape,
140 bool identical_element_shapes, bool dynamic_size,
141 bool multiple_writes_aggregate, bool is_grad, int32_t marked_size,
142 bool clear_after_read)
143 : key_(key),
144 dtype_(dtype),
145 handle_(handle),
146 closed_(false),
147 dynamic_size_(dynamic_size),
148 multiple_writes_aggregate_(multiple_writes_aggregate),
149 gradients_disallowed_(false),
150 clear_after_read_(clear_after_read),
151 is_grad_(is_grad),
152 marked_size_(marked_size),
153 element_shape_(element_shape),
154 identical_element_shapes_(identical_element_shapes),
155 tensors_(N) {}
156
157 // Write Tensor 'value' to index 'index'.
158 //
159 // Preconditions:
160 // * The TensorArray is not closed
161 // * If the array has dynamic size:
162 // The index is >= 0
163 // Otherwise:
164 // The index is in [0, N) where N == Size()
165 // * The dtype of the Tensor in 'value' matches the TensorArray's dtype.
166 // * If multiple_writes_aggregate is false:
167 // The Tensor at 'index' has not yet been written to.
168 // * If multiple_writes_aggregate is true:
169 // The Tensor at 'index' has the same shape as value.
170 //
171 // Side effects:
172 // * On the first write to 'index':
173 // - The underlying Tensor in 'value' has a new reference to it.
174 // - The index 'index' is marked as written.
175 // * If multiple_writes_aggregate is false, subsequent writes to 'index'
176 // raise an InvalidArgument error.
177 // * If multiple_writes_aggregate is true, subsequent writes to 'index':
178 // - The underlying Tensors in 'value' and from the first write
179 // are released and a local Tensor is created.
180 // - Index 'index' is also marked as local_copy.
181 // - The gradients_disallowed flag is set true (GradientsAllowed()
182 // will now return false).
183 //
184 // Note, value is passed as a pointer because we its underlying
185 // Tensor's shape is accessed. Otherwise it is not modified.
186 template <typename Device, typename T>
WriteOrAggregate(OpKernelContext * ctx,const int32_t index,const Tensor * value)187 Status WriteOrAggregate(OpKernelContext* ctx, const int32_t index,
188 const Tensor* value) {
189 mutex_lock l(mu_);
190 return LockedWriteOrAggregate<Device, T>(ctx, index, value);
191 }
192
193 template <typename Device, typename T>
WriteOrAggregateMany(OpKernelContext * ctx,const std::vector<int32> & indices,std::vector<Tensor> * values)194 Status WriteOrAggregateMany(OpKernelContext* ctx,
195 const std::vector<int32>& indices,
196 std::vector<Tensor>* values) {
197 mutex_lock l(mu_);
198 int32_t i = 0;
199 for (const int32_t ix : indices) {
200 Status s = LockedWriteOrAggregate<Device, T>(ctx, ix, &(*values)[i]);
201 ++i;
202 TF_RETURN_IF_ERROR(s);
203 }
204 return OkStatus();
205 }
206
207 // Read from index 'index' into Tensor 'value'.
208 //
209 // Preconditions:
210 // * The TensorArray is not closed
211 // * The index is in [0, N)
212 // * The Tensor at 'index' has been written to.
213 // * The Tensor at 'index' has not been read from with flag
214 // clear_after_read = true.
215 //
216 // Side effects:
217 // * If clear_after_read is true, the reference to the underlying
218 // Tensor is deleted.
219 // * The reference to the underlying Tensor at 'index' is copied to
220 // the returned '*value'.
221 // * The index is marked as read (it cannot be rewritten to).
222 template <typename Device, typename T>
Read(OpKernelContext * ctx,const int32_t index,Tensor * value)223 Status Read(OpKernelContext* ctx, const int32_t index, Tensor* value) {
224 mutex_lock l(mu_);
225 return LockedRead<Device, T>(ctx, index, value);
226 }
227
228 template <typename Device, typename T>
ReadMany(OpKernelContext * ctx,const std::vector<int32> & indices,std::vector<Tensor> * values)229 Status ReadMany(OpKernelContext* ctx, const std::vector<int32>& indices,
230 std::vector<Tensor>* values) {
231 mutex_lock l(mu_);
232 values->clear();
233 values->resize(indices.size());
234 int32_t i = 0;
235 for (const int32_t ix : indices) {
236 Status s = LockedRead<Device, T>(ctx, ix, &(*values)[i]);
237 ++i;
238 if (!s.ok()) return s;
239 }
240 return OkStatus();
241 }
242
ElemType()243 DataType ElemType() const { return dtype_; }
244
ElemShape()245 PartialTensorShape ElemShape() {
246 mutex_lock l(mu_);
247 return element_shape_;
248 }
249
SetElemShape(const PartialTensorShape & candidate)250 Status SetElemShape(const PartialTensorShape& candidate) {
251 mutex_lock l(mu_);
252 PartialTensorShape new_element_shape_;
253 Status s = element_shape_.MergeWith(candidate, &new_element_shape_);
254 if (!s.ok()) {
255 return s;
256 }
257 element_shape_ = new_element_shape_;
258 return OkStatus();
259 }
260
DebugString()261 string DebugString() const override {
262 mutex_lock l(mu_);
263 CHECK(!closed_);
264 return strings::StrCat("TensorArray[", tensors_.size(), "]");
265 }
266
IsClosed()267 bool IsClosed() {
268 mutex_lock l(mu_);
269 return closed_;
270 }
271
272 // Return the size of the TensorArray.
Size(int32 * size)273 Status Size(int32* size) {
274 mutex_lock l(mu_);
275 TF_RETURN_IF_ERROR(LockedReturnIfClosed());
276 *size = tensors_.size();
277 return OkStatus();
278 }
279
280 // Record the size of the TensorArray after an unpack or split.
SetMarkedSize(int32_t size)281 Status SetMarkedSize(int32_t size) {
282 mutex_lock l(mu_);
283 TF_RETURN_IF_ERROR(LockedReturnIfClosed());
284 if (!is_grad_) {
285 marked_size_ = size;
286 }
287 return OkStatus();
288 }
289
290 // Return the marked size of the TensorArray.
MarkedSize(int32 * size)291 Status MarkedSize(int32* size) {
292 mutex_lock l(mu_);
293 TF_RETURN_IF_ERROR(LockedReturnIfClosed());
294 *size = marked_size_;
295 return OkStatus();
296 }
297
298 // Return the size that should be used by pack or concat op.
PackOrConcatSize(int32 * size)299 Status PackOrConcatSize(int32* size) {
300 mutex_lock l(mu_);
301 TF_RETURN_IF_ERROR(LockedReturnIfClosed());
302 *size = is_grad_ ? marked_size_ : tensors_.size();
303 return OkStatus();
304 }
305
306 // Once a TensorArray is being used for gradient calculations, it
307 // should be marked as no longer resizeable.
DisableDynamicSize()308 void DisableDynamicSize() {
309 mutex_lock l(mu_);
310 dynamic_size_ = false;
311 }
312
HasDynamicSize()313 bool HasDynamicSize() {
314 mutex_lock l(mu_);
315 return dynamic_size_;
316 }
317
GradientsAllowed()318 bool GradientsAllowed() {
319 mutex_lock l(mu_);
320 return !gradients_disallowed_;
321 }
322
HasIdenticalElementShapes()323 bool HasIdenticalElementShapes() const { return identical_element_shapes_; }
324
325 // Copy the TensorShapes from another TensorArray into this one.
326 // If `shapes_to_prepend` is set, expands the rank of the copied shape by
327 // prepending the passed in shape prefix to the shape values in `rhs`.
328 // The sizes of the two TensorArrays must match and this one
329 // may not have any entries filled in. This performs a "soft copy",
330 // essentially filling the current TensorArray with virtual
331 // zero-tensors, which will be replaced by future aggregate writes,
332 // or instantiated by future reads. Requires a non-const pointer
333 // to the rhs to access its mutex.
334 Status CopyShapesFrom(TensorArray* rhs, const TensorShape* shape_to_prepend);
335
336 // Clear the TensorArray, including any Tensor references, and mark as closed.
ClearAndMarkClosed()337 void ClearAndMarkClosed() {
338 mutex_lock l(mu_);
339 tensors_.clear();
340 closed_ = true;
341 }
342
mu()343 mutex* mu() { return &mu_; }
handle()344 Tensor* handle() { return &handle_; }
345
resource_handle(OpKernelContext * ctx)346 ResourceHandle resource_handle(OpKernelContext* ctx) {
347 return ctx->step_container()->MakeResourceHandle<TensorArray>(
348 key_, *ctx->device());
349 }
350
351 private:
352 Status LockedWrite(OpKernelContext* ctx, const int32_t index, Tensor* value)
353 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
354
355 template <typename Device, typename T>
356 Status LockedWriteOrAggregate(OpKernelContext* ctx, const int32_t index,
357 const Tensor* value)
358 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
359
360 template <typename Device, typename T>
361 Status LockedRead(OpKernelContext* ctx, const int32_t index, Tensor* value)
362 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
363
LockedReturnIfClosed()364 Status LockedReturnIfClosed() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
365 if (closed_) {
366 return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
367 " has already been closed.");
368 }
369 return OkStatus();
370 }
371
372 const string key_;
373
374 const DataType dtype_;
375 Tensor handle_;
376
377 mutable mutex mu_;
378
379 // Marks that the tensor_array_ has been cleared.
380 bool closed_ TF_GUARDED_BY(mu_);
381
382 // Writes are allowed to grow the array.
383 bool dynamic_size_;
384
385 // Multiple writes to the same index will result in summation of the
386 // values (used by backprop)
387 const bool multiple_writes_aggregate_;
388
389 // If multiple Writes were attempted (e.g. via attribute
390 // multiple_writes_aggregate), then gradients are disallowed.
391 bool gradients_disallowed_ TF_GUARDED_BY(mu_);
392
393 // After a read at an index, clear away its Tensor to release memory.
394 const bool clear_after_read_;
395
396 // True iff this is a gradient tensor array.
397 const bool is_grad_;
398
399 // The size of the TensorArray after a (legacy) unpack or split is performed.
400 // -1 if there has been no unpack or split performed on the TensorArray.
401 int32 marked_size_;
402
403 // The shape of each element in the TensorArray, may be partially known or not
404 // known at all.
405 PartialTensorShape element_shape_ TF_GUARDED_BY(mu_);
406
407 // Whether all elements in the TensorArray have identical shapes.
408 // This allows certain behaviors, like dynamically checking for
409 // consistent shapes on write, and being able to fill in properly
410 // shaped zero tensors on stack -- even if the initial element_shape
411 // was not fully defined.
412 const bool identical_element_shapes_;
413
414 // TensorAndState is used to keep track of the Tensors stored in the
415 // TensorArray, along with their shapes, and a boolean that determines whether
416 // they have already been read or not.
417 struct TensorAndState {
TensorAndStateTensorAndState418 TensorAndState()
419 : written(false), read(false), cleared(false), local_copy(false) {}
420 Tensor tensor;
421 TensorShape shape;
422 bool written; // True if a Tensor has been written to the index.
423 bool read; // True if a Tensor has been written to and read from the index.
424 bool cleared; // True if a tensor has been read with
425 // clear_after_read = true;
426
427 // Used by writes when multiple_writes_aggregate is true. In this
428 // case, the first time a value is written, it is a shallow copy.
429 // The second time a value is written, it is aggregated. However,
430 // in this case a new Tensor must be constructed to hold the
431 // aggregated value. This flag marks that such a Tensor is being
432 // used. All future writes will aggregate to the existing local Tensor.
433 bool local_copy;
434 };
435 // The list of underlying Tensors and states.
436 std::vector<TensorAndState> tensors_ TF_GUARDED_BY(mu_);
437 };
438
439 template <typename Device, typename T>
LockedWriteOrAggregate(OpKernelContext * ctx,const int32_t index,const Tensor * value)440 Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx,
441 const int32_t index,
442 const Tensor* value) {
443 TF_RETURN_IF_ERROR(LockedReturnIfClosed());
444 size_t index_size = static_cast<size_t>(index);
445 if (index < 0 || (!dynamic_size_ && index_size >= tensors_.size())) {
446 return errors::InvalidArgument(
447 "TensorArray ", handle_.vec<tstring>()(1), ": Tried to write to index ",
448 index, " but array is not resizeable and size is: ", tensors_.size());
449 }
450 if (dynamic_size_) {
451 // We must grow the internal TensorArray
452 if (index_size >= tensors_.capacity()) {
453 tensors_.reserve(2 * (index_size + 1));
454 }
455 if (index_size >= tensors_.size()) {
456 tensors_.resize(index_size + 1);
457 }
458 }
459 TensorAndState& t = tensors_[index];
460
461 if (value->dtype() != dtype_) {
462 return errors::InvalidArgument(
463 "TensorArray ", handle_.vec<tstring>()(1),
464 ": Could not write to TensorArray index ", index,
465 " because the value dtype is ", DataTypeString(value->dtype()),
466 " but TensorArray dtype is ", DataTypeString(dtype_), ".");
467 }
468 if (!element_shape_.IsCompatibleWith(value->shape())) {
469 return errors::InvalidArgument(
470 "TensorArray ", handle_.vec<tstring>()(1),
471 ": Could not write to TensorArray index ", index,
472 " because the value shape is ", value->shape().DebugString(),
473 " which is incompatible with the TensorArray's inferred element "
474 "shape: ",
475 element_shape_.DebugString(), " (consider setting infer_shape=False).");
476 } else if (identical_element_shapes_ && !element_shape_.IsFullyDefined()) {
477 element_shape_ = PartialTensorShape(value->shape().dim_sizes());
478 }
479
480 if (t.read) {
481 return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
482 ": Could not write to TensorArray index ",
483 index, " because it has already been read.");
484 }
485
486 if (!multiple_writes_aggregate_ && t.written) {
487 return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
488 ": Could not write to TensorArray index ",
489 index,
490 " because it has already been written to.");
491 }
492
493 if (t.written) {
494 DCHECK(multiple_writes_aggregate_);
495
496 // Check that value shape matches t.shape
497 if (value->shape() != t.shape) {
498 return errors::InvalidArgument(
499 "TensorArray ", handle_.vec<tstring>()(1),
500 ": Could not aggregate to TensorArray index ", index,
501 " because the existing shape is ", t.shape.DebugString(),
502 " but the new input shape is ", value->shape().DebugString(), ".");
503 }
504
505 if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
506 // If existing_t == nullptr but written == true, then what was stored
507 // was just a shape, which just means zeros. So all we must do in this
508 // case is copy the reference over and return early.
509 t.tensor = *value;
510 return OkStatus();
511 }
512
513 Tensor* existing_t = &t.tensor;
514
515 if (t.local_copy) {
516 Status s = tensor_array::AddToTensor<Device, T>(ctx, existing_t,
517 existing_t, value);
518 TF_RETURN_IF_ERROR(s);
519 } else {
520 Tensor local_tensor;
521 TF_RETURN_IF_ERROR(
522 ctx->allocate_temp(dtype_, existing_t->shape(), &local_tensor));
523 Status s = tensor_array::AddToTensor<Device, T>(ctx, &local_tensor,
524 existing_t, value);
525 TF_RETURN_IF_ERROR(s);
526 t.tensor = local_tensor;
527 t.local_copy = true;
528 }
529
530 // We've aggregated the values, so disallow backprop on this
531 // TensorArray.
532 gradients_disallowed_ = true;
533 } else {
534 t.tensor = *value;
535 t.shape = value->shape();
536 t.written = true;
537 }
538 return OkStatus();
539 }
540
541 template <typename Device, typename T>
LockedRead(OpKernelContext * ctx,const int32_t index,Tensor * value)542 Status TensorArray::LockedRead(OpKernelContext* ctx, const int32_t index,
543 Tensor* value) {
544 TF_RETURN_IF_ERROR(LockedReturnIfClosed());
545 if ((index < 0) ||
546 (!is_grad_ && (static_cast<size_t>(index) >= tensors_.size()))) {
547 return errors::InvalidArgument("Tried to read from index ", index,
548 " but array size is: ", tensors_.size());
549 }
550 size_t index_t = static_cast<size_t>(index);
551 if ((is_grad_ && (index_t >= tensors_.size() || !tensors_[index].written)) ||
552 (!is_grad_ && (index_t < tensors_.size() && !tensors_[index].written))) {
553 // Special case returning zeros if this is a gradient read that happens
554 // after a stop_gradients call with dynamic forward TensorArrays.
555 // There is sometimes a race condition where the gradient is not
556 // written due to stop_gradients, but is later read.
557 TensorShape element_shape;
558 if (is_grad_ && index_t < tensors_.size() &&
559 tensors_[index].shape.dims() > 0) {
560 // A gradient TensorArray has more specific gradient information
561 // available for each entry. A forward TensorArray must rely on
562 // the global element_shape_ to fill in zeros on read.
563 element_shape = tensors_[index].shape;
564 } else if (!element_shape_.IsFullyDefined()) {
565 return errors::InvalidArgument(
566 "TensorArray ", handle_.vec<tstring>()(1),
567 ": Could not read from TensorArray index ", index,
568 ". Furthermore, the element shape is not fully defined: ",
569 element_shape_.DebugString(),
570 ". It is possible you are working with a resizeable TensorArray and "
571 "stop_gradients is not allowing the gradients to be written. If you "
572 "set the full "
573 "element_shape property on the forward TensorArray, the proper "
574 "all-zeros tensor "
575 "will be returned instead of incurring this error.");
576 } else {
577 element_shape_.AsTensorShape(&element_shape); // Always succeeds.
578 }
579 if (index_t >= tensors_.size()) {
580 // Fill in tensors_ up to index to have known shape.
581 size_t old_tensors_size = tensors_.size();
582 tensors_.resize(index + 1);
583 for (size_t i = old_tensors_size; i < index + 1; ++i) {
584 tensors_[i].shape = element_shape;
585 tensors_[i].written = true;
586 }
587 } else {
588 tensors_[index].shape = element_shape;
589 tensors_[index].written = true;
590 }
591 }
592
593 TensorAndState& t = tensors_[index];
594
595 if (t.cleared) {
596 return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
597 ": Could not read index ", index,
598 " twice because it was cleared after a "
599 "previous read (perhaps try setting "
600 "clear_after_read = false?).");
601 }
602
603 if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
604 // We stored just a shape, but no value. This means create and
605 // return zeros of the appropriate shape.
606 TF_RETURN_IF_ERROR(ctx->allocate_temp(dtype_, t.shape, &t.tensor));
607 if (t.shape.num_elements() > 0) {
608 Status s = tensor_array::TensorSetZero<Device, T>(ctx, &t.tensor);
609 if (!s.ok()) return s;
610 }
611 }
612
613 // Data is available inside the tensor, copy the reference over.
614 *value = t.tensor;
615
616 if (clear_after_read_) {
617 t.tensor = Tensor();
618 t.cleared = true;
619 }
620 t.read = true;
621 return OkStatus();
622 }
623
624 } // namespace tensorflow
625
626 #endif // TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
627