/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
#define TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_

#include <limits.h>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/aggregate_ops.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace tensor_array {

// Full implementations are in tensor_array.cc
template <typename Device, typename T>
Status AddToTensor(OpKernelContext* ctx, Tensor* sum, const Tensor* current,
                   const Tensor* add) {
  return errors::InvalidArgument(
      "tensor_array::AddToTensor type not supported: ",
      DataTypeString(DataTypeToEnum<T>::value));
}

#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T)                         \
  template <>                                                        \
  Status AddToTensor<Device, T>(OpKernelContext * ctx, Tensor * sum, \
                                const Tensor* current, const Tensor* add);

#define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)
#undef TENSOR_ARRAY_WRITE_OR_ADD_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
#undef TENSOR_ARRAY_WRITE_OR_ADD_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef TENSOR_ARRAY_WRITE_OR_ADD

template <typename Device, typename T>
Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
  return errors::InvalidArgument(
      "tensor_array::TensorSetZero type not supported: ",
      DataTypeString(DataTypeToEnum<T>::value));
}

#define TENSOR_ARRAY_SET_ZERO(Device, T) \
  template <>                            \
  Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);

#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
#undef TENSOR_ARRAY_SET_ZERO_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef TENSOR_ARRAY_SET_ZERO

}  // namespace tensor_array

// The TensorArray object keeps an array of PersistentTensors. It
// allows reading from the array and writing to the array.
//
// Important properties:
//   * Usually, writing to a particular index in the TensorArray is allowed at
//     most once per index. In a special case, writes with the flag
//     multiple_writes_aggregate allow multiple writes to the same
//     index. In this case, the writes are summed.
//   * Multiple reads are supported.
//   * Deep copies of PersistentTensors are rarely made. The only
//     time they are made is when WriteOrAggregate is called at least twice
//     on the same index with the flag multiple_writes_aggregate = true.
//   * Reading and writing to the array is protected by a mutex.
//     All operations on a TensorArray are thread-safe.
//   * A TensorArray may be preemptively closed, which releases all
//     memory associated with it.
//
// These properties together allow the TensorArray to work as a
// functional object and make gradient computation easy. For
// example:
//   * Write-Once semantics mean the gradient of a TensorArray Read never has
//     to worry about which of multiple writes to that index the gradient
//     value is meant for.
//   * Read-Many semantics (when using clear_after_read=false) allow the
//     TensorArray to be read, packed, or concatenated multiple times;
//     and the gradient operations use the multiple_writes_aggregate
//     flag to aggregate the backprop writes. Multiple backprop writes to
//     the same index are partial gradients corresponding to the
//     multiple reads of that index in the forward phase.
//
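// As a rough illustration of the write-once / read-many contract above, a
// kernel holding a `TensorArray* ta` might use it like this (hypothetical
// surrounding kernel; the CPUDevice/float parameters are placeholders):
//
//   // Each forward-pass iteration writes its own index exactly once.
//   PersistentTensor value = ...;
//   OP_REQUIRES_OK(ctx,
//                  (ta->WriteOrAggregate<CPUDevice, float>(ctx, i, &value)));
//
//   // With clear_after_read=false the same index can later be read several
//   // times, e.g. once by a Read op and once by a Pack/Concat op.
//   PersistentTensor read_value;
//   OP_REQUIRES_OK(ctx, (ta->Read<CPUDevice, float>(ctx, i, &read_value)));
//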
class TensorArray : public ResourceBase {
 public:
  static std::atomic<int64> tensor_array_counter;

  // Construct a TensorArray for holding Tensors of type 'dtype' with
  // 'N' elements. While the underlying storage is a std::vector and
  // can hold more than MAX_INT entries, in practice we do not expect
  // users to construct this many Tensors for storage in a TensorArray.
  TensorArray(const string& key, const DataType& dtype, const Tensor& handle,
              int32 N, const PartialTensorShape& element_shape,
              bool identical_element_shapes, bool dynamic_size,
              bool multiple_writes_aggregate, bool is_grad, int32 marked_size,
              bool clear_after_read)
      : key_(key),
        dtype_(dtype),
        handle_(handle),
        closed_(false),
        dynamic_size_(dynamic_size),
        multiple_writes_aggregate_(multiple_writes_aggregate),
        gradients_disallowed_(false),
        clear_after_read_(clear_after_read),
        is_grad_(is_grad),
        marked_size_(marked_size),
        element_shape_(element_shape),
        identical_element_shapes_(identical_element_shapes),
        tensors_(N) {}
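
  // As an illustration (not taken from a real kernel), a forward TensorArray
  // with dynamic size and strict write-once semantics could be constructed
  // roughly like this; `unique_key`, `handle_tensor` and `size` are
  // placeholder names:
  //
  //   TensorArray* ta = new TensorArray(
  //       /*key=*/unique_key, /*dtype=*/DT_FLOAT, /*handle=*/handle_tensor,
  //       /*N=*/size, /*element_shape=*/PartialTensorShape(),
  //       /*identical_element_shapes=*/false, /*dynamic_size=*/true,
  //       /*multiple_writes_aggregate=*/false, /*is_grad=*/false,
  //       /*marked_size=*/-1, /*clear_after_read=*/true);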

  // Write PersistentTensor 'value' to index 'index'.
  //
  // Preconditions:
  //  * The TensorArray is not closed
  //  * If the array has dynamic size:
  //      The index is >= 0
  //    Otherwise:
  //      The index is in [0, N) where N == Size()
  //  * The dtype of the Tensor in 'value' matches the TensorArray's dtype.
  //  * If multiple_writes_aggregate is false:
  //      The Tensor at 'index' has not yet been written to.
  //  * If multiple_writes_aggregate is true:
  //      The Tensor at 'index' has the same shape as value.
  //
  // Side effects:
  //  * On the first write to 'index':
  //    - The underlying Tensor in 'value' has a new reference to it.
  //    - The index 'index' is marked as written.
  //  * If multiple_writes_aggregate is false, subsequent writes to 'index'
  //    raise an InvalidArgument error.
  //  * If multiple_writes_aggregate is true, subsequent writes to 'index':
  //    - The underlying Tensors in 'value' and from the first write
  //      are released and a local PersistentTensor is created.
  //    - Index 'index' is also marked as local_copy.
  //    - The gradients_disallowed flag is set true (GradientsAllowed()
  //      will now return false).
  //
  // Note: 'value' is passed as a pointer because its underlying Tensor's
  // shape is accessed. It is not otherwise modified.
  template <typename Device, typename T>
  Status WriteOrAggregate(OpKernelContext* ctx, const int32 index,
                          PersistentTensor* value) {
    mutex_lock l(mu_);
    return LockedWriteOrAggregate<Device, T>(ctx, index, value);
  }
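
  // Sketch of the aggregation behavior described above (hypothetical caller;
  // device/type parameters are placeholders):
  //
  //   // With multiple_writes_aggregate=true, a second write to index 0 sums
  //   // into the first instead of failing:
  //   TF_RETURN_IF_ERROR((ta->WriteOrAggregate<CPUDevice, float>(ctx, 0, &a)));
  //   TF_RETURN_IF_ERROR((ta->WriteOrAggregate<CPUDevice, float>(ctx, 0, &b)));
  //   // Index 0 now holds a + b, and GradientsAllowed() returns false.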

  template <typename Device, typename T>
  Status WriteOrAggregateMany(OpKernelContext* ctx,
                              const std::vector<int32>& indices,
                              std::vector<PersistentTensor>* values) {
    mutex_lock l(mu_);
    int32 i = 0;
    for (const int32 ix : indices) {
      Status s = LockedWriteOrAggregate<Device, T>(ctx, ix, &(*values)[i]);
      ++i;
      TF_RETURN_IF_ERROR(s);
    }
    return Status::OK();
  }

  // Read from index 'index' into PersistentTensor 'value'.
  //
  // Preconditions:
  //  * The TensorArray is not closed
  //  * The index is in [0, N)
  //  * The Tensor at 'index' has been written to.
  //  * The Tensor at 'index' has not been read from with flag
  //    clear_after_read = true.
  //
  // Side effects:
  //  * If clear_after_read is true, the reference to the underlying
  //    Tensor is deleted.
  //  * The reference to the underlying Tensor at 'index' is copied to
  //    the returned '*value'.
  //  * The index is marked as read (it cannot be rewritten to).
  template <typename Device, typename T>
  Status Read(OpKernelContext* ctx, const int32 index,
              PersistentTensor* value) {
    mutex_lock l(mu_);
    return LockedRead<Device, T>(ctx, index, value);
  }
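
  // Sketch of the read-side contract (hypothetical caller; device/type
  // parameters are placeholders):
  //
  //   PersistentTensor out;
  //   TF_RETURN_IF_ERROR((ta->Read<CPUDevice, float>(ctx, 0, &out)));
  //   // With clear_after_read=true a second Read of index 0 now fails with
  //   // InvalidArgument; with clear_after_read=false it returns the same
  //   // tensor again.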

  template <typename Device, typename T>
  Status ReadMany(OpKernelContext* ctx, const std::vector<int32>& indices,
                  std::vector<PersistentTensor>* values) {
    mutex_lock l(mu_);
    values->clear();
    values->resize(indices.size());
    int32 i = 0;
    for (const int32 ix : indices) {
      Status s = LockedRead<Device, T>(ctx, ix, &(*values)[i]);
      ++i;
      if (!s.ok()) return s;
    }
    return Status::OK();
  }

  DataType ElemType() const { return dtype_; }

  PartialTensorShape ElemShape() {
    mutex_lock l(mu_);
    return element_shape_;
  }

  Status SetElemShape(const PartialTensorShape& candidate) {
    mutex_lock l(mu_);
    PartialTensorShape new_element_shape_;
    Status s = element_shape_.MergeWith(candidate, &new_element_shape_);
    if (!s.ok()) {
      return s;
    }
    element_shape_ = new_element_shape_;
    return Status::OK();
  }
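
  // For example, merging a stored element shape of [?, 3] with a candidate
  // [5, ?] refines the element shape to [5, 3], while merging [?, 3] with
  // [5, 4] fails and leaves element_shape_ unchanged (standard
  // PartialTensorShape::MergeWith semantics).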

  string DebugString() const override {
    mutex_lock l(mu_);
    CHECK(!closed_);
    return strings::StrCat("TensorArray[", tensors_.size(), "]");
  }

  bool IsClosed() {
    mutex_lock l(mu_);
    return closed_;
  }

  // Return the size of the TensorArray.
  Status Size(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = tensors_.size();
    return Status::OK();
  }

  // Record the size of the TensorArray after an unpack or split.
  Status SetMarkedSize(int32 size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    if (!is_grad_) {
      marked_size_ = size;
    }
    return Status::OK();
  }

  // Return the marked size of the TensorArray.
  Status MarkedSize(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = marked_size_;
    return Status::OK();
  }

  // Return the size that should be used by pack or concat op.
  Status PackOrConcatSize(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = is_grad_ ? marked_size_ : tensors_.size();
    return Status::OK();
  }

  // Once a TensorArray is being used for gradient calculations, it
  // should be marked as no longer resizeable.
  void DisableDynamicSize() {
    mutex_lock l(mu_);
    dynamic_size_ = false;
  }

  bool HasDynamicSize() {
    mutex_lock l(mu_);
    return dynamic_size_;
  }

  bool GradientsAllowed() {
    mutex_lock l(mu_);
    return !gradients_disallowed_;
  }

  bool HasIdenticalElementShapes() const { return identical_element_shapes_; }

  // Copy the TensorShapes from another TensorArray into this one.
  // If `shape_to_prepend` is set, expands the rank of the copied shape by
  // prepending the passed-in shape prefix to the shape values in `rhs`.
  // The sizes of the two TensorArrays must match and this one
  // may not have any entries filled in. This performs a "soft copy",
  // essentially filling the current TensorArray with virtual
  // zero-tensors, which will be replaced by future aggregate writes,
  // or instantiated by future reads. Requires a non-const pointer
  // to the rhs to access its mutex.
  Status CopyShapesFrom(TensorArray* rhs, const TensorShape* shape_to_prepend);
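
  // Illustrative use (placeholder names; gradient kernels may differ):
  //
  //   TF_RETURN_IF_ERROR(grad_ta->CopyShapesFrom(forward_ta, nullptr));
  //
  // after which each written index of `forward_ta` exists in `grad_ta` as a
  // shape-only ("virtual zero") entry that reads as zeros until gradients are
  // aggregated into it.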

  // Clear the TensorArray, including any Tensor references, and mark as
  // closed.
  void ClearAndMarkClosed() {
    mutex_lock l(mu_);
    tensors_.clear();
    closed_ = true;
  }

  mutex* mu() { return &mu_; }
  Tensor* handle() { return &handle_; }

  ResourceHandle resource_handle(OpKernelContext* ctx) {
    return ctx->step_container()->MakeResourceHandle<TensorArray>(
        key_, *ctx->device());
  }

 private:
  Status LockedWrite(OpKernelContext* ctx, const int32 index,
                     PersistentTensor* value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  template <typename Device, typename T>
  Status LockedWriteOrAggregate(OpKernelContext* ctx, const int32 index,
                                PersistentTensor* value)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  template <typename Device, typename T>
  Status LockedRead(OpKernelContext* ctx, const int32 index,
                    PersistentTensor* value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status LockedReturnIfClosed() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (closed_) {
      return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
                                     " has already been closed.");
    }
    return Status::OK();
  }

  const string key_;

  const DataType dtype_;
  Tensor handle_;

  mutable mutex mu_;

  // Marks that the tensor_array_ has been cleared.
  bool closed_ TF_GUARDED_BY(mu_);

  // Writes are allowed to grow the array.
  bool dynamic_size_;

  // Multiple writes to the same index will result in summation of the
  // values (used by backprop)
  const bool multiple_writes_aggregate_;

  // If multiple Writes were attempted (e.g. via attribute
  // multiple_writes_aggregate), then gradients are disallowed.
  bool gradients_disallowed_ TF_GUARDED_BY(mu_);

  // After a read at an index, clear away its PersistentTensor to
  // release memory.
  const bool clear_after_read_;

  // True iff this is a gradient tensor array.
  const bool is_grad_;

  // The size of the TensorArray after a (legacy) unpack or split is performed.
  // -1 if there has been no unpack or split performed on the TensorArray.
  int32 marked_size_;

  // The shape of each element in the TensorArray, may be partially known or
  // not known at all.
  PartialTensorShape element_shape_ TF_GUARDED_BY(mu_);

  // Whether all elements in the TensorArray have identical shapes.
  // This allows certain behaviors, like dynamically checking for
  // consistent shapes on write, and being able to fill in properly
  // shaped zero tensors on stack -- even if the initial element_shape
  // was not fully defined.
  const bool identical_element_shapes_;

  // TensorAndState is used to keep track of the PersistentTensors
  // stored in the TensorArray, along with their shapes, and a boolean
  // that determines whether they have already been read or not.
  struct TensorAndState {
    TensorAndState()
        : written(false), read(false), cleared(false), local_copy(false) {}
    PersistentTensor tensor;
    TensorShape shape;
    bool written;  // True if a Tensor has been written to the index.
    bool read;     // True if a Tensor has been written to and read from the
                   // index.
    bool cleared;  // True if a tensor has been read with
                   // clear_after_read = true;

    // Used by writes when multiple_writes_aggregate is true. In this
    // case, the first time a value is written, it is a shallow copy.
    // The second time a value is written, it is aggregated. However,
    // in this case a new Tensor must be constructed to hold the
    // aggregated value. This flag marks that such a Tensor is being
    // used. All future writes will aggregate to the existing local Tensor.
    bool local_copy;
  };
  // The list of underlying PersistentTensors and states.
  std::vector<TensorAndState> tensors_ TF_GUARDED_BY(mu_);
};

template <typename Device, typename T>
Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx,
                                           const int32 index,
                                           PersistentTensor* value) {
  TF_RETURN_IF_ERROR(LockedReturnIfClosed());
  size_t index_size = static_cast<size_t>(index);
  if (index < 0 || (!dynamic_size_ && index_size >= tensors_.size())) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<tstring>()(1), ": Tried to write to index ",
        index, " but array is not resizeable and size is: ", tensors_.size());
  }
  if (dynamic_size_) {
    // We must grow the internal TensorArray
    if (index_size >= tensors_.capacity()) {
      tensors_.reserve(2 * (index_size + 1));
    }
    if (index_size >= tensors_.size()) {
      tensors_.resize(index_size + 1);
    }
  }
  TensorAndState& t = tensors_[index];

  Tensor* value_t = value->AccessTensor(ctx);
  if (value_t->dtype() != dtype_) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<tstring>()(1),
        ": Could not write to TensorArray index ", index,
        " because the value dtype is ", DataTypeString(value_t->dtype()),
        " but TensorArray dtype is ", DataTypeString(dtype_), ".");
  }
  if (!element_shape_.IsCompatibleWith(value_t->shape())) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<tstring>()(1),
        ": Could not write to TensorArray index ", index,
        " because the value shape is ", value_t->shape().DebugString(),
        " which is incompatible with the TensorArray's inferred element "
        "shape: ",
        element_shape_.DebugString(), " (consider setting infer_shape=False).");
  } else if (identical_element_shapes_ && !element_shape_.IsFullyDefined()) {
    element_shape_ = PartialTensorShape(value_t->shape().dim_sizes());
  }

  if (t.read) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
                                   ": Could not write to TensorArray index ",
                                   index, " because it has already been read.");
  }

  if (!multiple_writes_aggregate_ && t.written) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
                                   ": Could not write to TensorArray index ",
                                   index,
                                   " because it has already been written to.");
  }

  if (t.written) {
    DCHECK(multiple_writes_aggregate_);

    // Check that value_t shape matches t.shape
    if (value_t->shape() != t.shape) {
      return errors::InvalidArgument(
          "TensorArray ", handle_.vec<tstring>()(1),
          ": Could not aggregate to TensorArray index ", index,
          " because the existing shape is ", t.shape.DebugString(),
          " but the new input shape is ", value_t->shape().DebugString(), ".");
    }

    if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
      // If the stored tensor is uninitialized (or empty) but written == true,
      // then what was stored was just a shape, which just means zeros. So all
      // we must do in this case is copy the reference over and return early.
      t.tensor = *value;
      return Status::OK();
    }

    Tensor* existing_t = t.tensor.AccessTensor(ctx);

    if (t.local_copy) {
      Status s = tensor_array::AddToTensor<Device, T>(ctx, existing_t,
                                                      existing_t, value_t);
      TF_RETURN_IF_ERROR(s);
    } else {
      PersistentTensor local_tensor;
      Tensor* local_tensor_t;
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
          dtype_, existing_t->shape(), &local_tensor, &local_tensor_t));
      Status s = tensor_array::AddToTensor<Device, T>(ctx, local_tensor_t,
                                                      existing_t, value_t);
      TF_RETURN_IF_ERROR(s);
      t.tensor = local_tensor;
      t.local_copy = true;
    }

    // We've aggregated the values, so disallow backprop on this
    // TensorArray.
    gradients_disallowed_ = true;
  } else {
    t.tensor = *value;
    t.shape = value_t->shape();
    t.written = true;
  }
  return Status::OK();
}

template <typename Device, typename T>
Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index,
                               PersistentTensor* value) {
  TF_RETURN_IF_ERROR(LockedReturnIfClosed());
  if ((index < 0) ||
      (!is_grad_ && (static_cast<size_t>(index) >= tensors_.size()))) {
    return errors::InvalidArgument("Tried to read from index ", index,
                                   " but array size is: ", tensors_.size());
  }
  size_t index_t = static_cast<size_t>(index);
  if ((is_grad_ && (index_t >= tensors_.size() || !tensors_[index].written)) ||
      (!is_grad_ && (index_t < tensors_.size() && !tensors_[index].written))) {
    // Special case returning zeros if this is a gradient read that happens
    // after a stop_gradients call with dynamic forward TensorArrays.
    // There is sometimes a race condition where the gradient is not
    // written due to stop_gradients, but is later read.
    TensorShape element_shape;
    if (is_grad_ && index_t < tensors_.size() &&
        tensors_[index].shape.dims() > 0) {
      // A gradient TensorArray has more specific gradient information
      // available for each entry. A forward TensorArray must rely on
      // the global element_shape_ to fill in zeros on read.
      element_shape = tensors_[index].shape;
    } else if (!element_shape_.IsFullyDefined()) {
      return errors::InvalidArgument(
          "TensorArray ", handle_.vec<tstring>()(1),
          ": Could not read from TensorArray index ", index,
          ". Furthermore, the element shape is not fully defined: ",
          element_shape_.DebugString(),
          ". It is possible you are working with a resizeable TensorArray and "
          "stop_gradients is not allowing the gradients to be written. If you "
          "set the full "
          "element_shape property on the forward TensorArray, the proper "
          "all-zeros tensor "
          "will be returned instead of incurring this error.");
    } else {
      element_shape_.AsTensorShape(&element_shape);  // Always succeeds.
    }
    if (index_t >= tensors_.size()) {
      // Fill in tensors_ up to index to have known shape.
      size_t old_tensors_size = tensors_.size();
      tensors_.resize(index + 1);
      for (size_t i = old_tensors_size; i < index + 1; ++i) {
        tensors_[i].shape = element_shape;
        tensors_[i].written = true;
      }
    } else {
      tensors_[index].shape = element_shape;
      tensors_[index].written = true;
    }
  }

  TensorAndState& t = tensors_[index];

  if (t.cleared) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<tstring>()(1),
                                   ": Could not read index ", index,
                                   " twice because it was cleared after a "
                                   "previous read (perhaps try setting "
                                   "clear_after_read = false?).");
  }

  if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
    // We stored just a shape, but no value. This means create and
    // return zeros of the appropriate shape.
    Tensor* tensor_t;
    TF_RETURN_IF_ERROR(
        ctx->allocate_persistent(dtype_, t.shape, &t.tensor, &tensor_t));
    if (t.shape.num_elements() > 0) {
      Status s = tensor_array::TensorSetZero<Device, T>(ctx, tensor_t);
      if (!s.ok()) return s;
    }
  }

  // Data is available inside the tensor, copy the reference over.
  *value = t.tensor;

  if (clear_after_read_) {
    t.tensor = PersistentTensor();
    t.cleared = true;
  }
  t.read = true;
  return Status::OK();
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_