• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
16 #define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
17 
18 #define EIGEN_USE_THREADS
19 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
20 #define EIGEN_USE_GPU
21 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/framework/variant.h"
29 #include "tensorflow/core/framework/variant_op_registry.h"
30 #include "tensorflow/core/kernels/concat_lib.h"
31 #include "tensorflow/core/kernels/fill_functor.h"
32 #include "tensorflow/core/kernels/tensor_list.h"
33 #include "tensorflow/core/lib/core/coding.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/refcount.h"
36 #include "tensorflow/core/lib/gtl/array_slice.h"
37 #include "tensorflow/core/platform/platform.h"
38 #include "tensorflow/core/util/tensor_ops_util.h"
39 #include "tensorflow/core/util/util.h"
40 
41 namespace tensorflow {
42 
43 typedef Eigen::ThreadPoolDevice CPUDevice;
44 
45 Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);
46 
47 Status GetElementShapeFromInput(OpKernelContext* c,
48                                 const TensorList& tensor_list, int index,
49                                 PartialTensorShape* element_shape);
50 
51 Status GetInputList(OpKernelContext* c, int index, const TensorList** list);
52 
53 Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
54                                    int32 output_index,
55                                    const TensorList& input_list,
56                                    TensorList** output_list);
57 
58 template <typename Device, typename T>
59 class TensorListStack : public OpKernel {
60  public:
61   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
62       ConstMatrixVector;
TensorListStack(OpKernelConstruction * c)63   explicit TensorListStack(OpKernelConstruction* c) : OpKernel(c) {
64     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
65     OP_REQUIRES_OK(c, c->GetAttr("num_elements", &num_elements_));
66   }
67 
Compute(OpKernelContext * c)68   void Compute(OpKernelContext* c) override {
69     const TensorList* tensor_list = nullptr;
70     OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
71     OP_REQUIRES(
72         c, element_dtype_ == tensor_list->element_dtype,
73         errors::InvalidArgument(
74             "Invalid data types; op elements ", DataTypeString(element_dtype_),
75             " but list elements ", DataTypeString(tensor_list->element_dtype)));
76     if (num_elements_ != -1) {
77       OP_REQUIRES(c, tensor_list->tensors().size() == num_elements_,
78                   errors::InvalidArgument(
79                       "Operation expected a list with ", num_elements_,
80                       " elements but got a list with ",
81                       tensor_list->tensors().size(), " elements."));
82     }
83     PartialTensorShape partial_element_shape;
84     OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 1,
85                                                &partial_element_shape));
86     OP_REQUIRES(
87         c,
88         partial_element_shape.IsFullyDefined() ||
89             !tensor_list->tensors().empty(),
90         errors::InvalidArgument("Tried to stack elements of an empty ",
91                                 "list with non-fully-defined element_shape: ",
92                                 partial_element_shape.DebugString()));
93 
94     // Check that `element_shape` input tensor is compatible with the shapes of
95     // element tensors.
96     if (!tensor_list->element_shape.IsFullyDefined()) {
97       for (int i = 0; i < tensor_list->tensors().size(); ++i) {
98         const Tensor& t = tensor_list->tensors()[i];
99         if (t.dtype() != DT_INVALID) {
100           PartialTensorShape tmp = partial_element_shape;
101           OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
102         }
103       }
104     }
105 
106     // Compute the shape of the output tensor by pre-pending the leading dim to
107     // the element_shape.
108     TensorShape element_shape;
109     OP_REQUIRES(c, partial_element_shape.AsTensorShape(&element_shape),
110                 errors::InvalidArgument(
111                     "Tried to stack list which only contains uninitialized ",
112                     "tensors and has a non-fully-defined element_shape: ",
113                     partial_element_shape.DebugString()));
114     TensorShape output_shape = element_shape;
115     output_shape.InsertDim(0, tensor_list->tensors().size());
116     Tensor* output;
117     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
118     if (output->NumElements() == 0) {
119       return;
120     }
121 
122     ConstMatrixVector inputs_flat;
123     inputs_flat.reserve(tensor_list->tensors().size());
124     Tensor zeros;
125     for (const auto& t : tensor_list->tensors()) {
126       if (t.dtype() != DT_INVALID) {
127         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
128             t.shaped<T, 2>({1, t.NumElements()})));
129       } else {
130         if (!zeros.NumElements()) {
131           AllocatorAttributes attr;
132           if (element_dtype_ == DT_VARIANT) {
133             attr.set_on_host(true);
134           }
135           OP_REQUIRES_OK(
136               c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
137           functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
138                                                zeros.flat<T>());
139         }
140         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
141             const_cast<const Tensor&>(zeros).shaped<T, 2>(
142                 {1, zeros.NumElements()})));
143       }
144     }
145     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
146 
147 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
148     if (std::is_same<Device, Eigen::GpuDevice>::value) {
149       ConcatGPU<T>(c, inputs_flat, output, &output_flat);
150       return;
151     }
152 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
153     ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
154   }
155 
156  private:
157   int num_elements_;
158   DataType element_dtype_;
159 };
160 
161 template <typename Device, typename T>
162 class TensorListGetItem : public OpKernel {
163  public:
TensorListGetItem(OpKernelConstruction * c)164   explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) {
165     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
166   }
167 
Compute(OpKernelContext * c)168   void Compute(OpKernelContext* c) override {
169     const TensorList* l = nullptr;
170     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
171     OP_REQUIRES(c, element_dtype_ == l->element_dtype,
172                 errors::InvalidArgument("Invalid data types; op elements ",
173                                         DataTypeString(element_dtype_),
174                                         " but list elements ",
175                                         DataTypeString(l->element_dtype)));
176     int32 index = c->input(1).scalar<int32>()();
177     OP_REQUIRES(c, index < l->tensors().size(),
178                 errors::InvalidArgument("Trying to access element ", index,
179                                         " in a list with ", l->tensors().size(),
180                                         " elements."));
181     if (l->tensors()[index].dtype() != DT_INVALID) {
182       c->set_output(0, l->tensors()[index]);
183     } else {
184       PartialTensorShape partial_element_shape;
185       OP_REQUIRES_OK(
186           c, GetElementShapeFromInput(c, *l, 2, &partial_element_shape));
187       TensorShape element_shape;
188       // If l->element_shape and the element_shape input are both not fully
189       // defined, try to infer the shape from other list elements. This requires
190       // that all initialized list elements have the same shape.
191       // NOTE(srbs): This might be a performance bottleneck since we are
192       // iterating over the entire list here. This is necessary for feature
193       // parity with TensorArray.read. TensorArray has a mode in which all
194       // elements are required to be of the same shape, TensorList does not.
195       // In that mode TensorArray sets the array's element_shape on the first
196       // write call. We could do something similar here if needed.
197       if (!partial_element_shape.IsFullyDefined()) {
198         for (const Tensor& t : l->tensors()) {
199           if (t.dtype() != DT_INVALID) {
200             PartialTensorShape tmp = partial_element_shape;
201             OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
202           }
203         }
204       }
205       OP_REQUIRES(
206           c, partial_element_shape.AsTensorShape(&element_shape),
207           errors::InvalidArgument("Trying to read an uninitialized tensor but ",
208                                   "element_shape is not fully defined: ",
209                                   partial_element_shape.DebugString(),
210                                   " and no list element is set."));
211       Tensor* result;
212       AllocatorAttributes attr;
213       if (element_dtype_ == DT_VARIANT) {
214         attr.set_on_host(true);
215       }
216       OP_REQUIRES_OK(c, c->allocate_output(0, element_shape, &result, attr));
217       functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
218                                            result->flat<T>());
219     }
220   }
221 
222  private:
223   DataType element_dtype_;
224 };
225 
226 template <typename Device, typename T>
227 class TensorListPopBack : public OpKernel {
228  public:
TensorListPopBack(OpKernelConstruction * c)229   explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) {
230     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
231   }
232 
Compute(OpKernelContext * c)233   void Compute(OpKernelContext* c) override {
234     const TensorList* l = nullptr;
235     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
236     OP_REQUIRES(c, element_dtype_ == l->element_dtype,
237                 errors::InvalidArgument("Invalid data types; op elements ",
238                                         DataTypeString(element_dtype_),
239                                         " but list elements ",
240                                         DataTypeString(l->element_dtype)));
241 
242     OP_REQUIRES(c, !l->tensors().empty(),
243                 errors::InvalidArgument("Trying to pop from an empty list."));
244 
245     const Tensor& t = l->tensors().back();
246     if (t.dtype() != DT_INVALID) {
247       c->set_output(1, t);
248     } else {
249       PartialTensorShape partial_element_shape;
250       OP_REQUIRES_OK(
251           c, GetElementShapeFromInput(c, *l, 1, &partial_element_shape));
252       TensorShape element_shape;
253       OP_REQUIRES(
254           c, partial_element_shape.AsTensorShape(&element_shape),
255           errors::InvalidArgument("Trying to read an uninitialized tensor but ",
256                                   "element_shape is not fully defined.",
257                                   partial_element_shape.DebugString()));
258       Tensor* result;
259       AllocatorAttributes attr;
260       if (element_dtype_ == DT_VARIANT) {
261         attr.set_on_host(true);
262       }
263       OP_REQUIRES_OK(c, c->allocate_output(1, element_shape, &result, attr));
264       functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
265                                            result->flat<T>());
266     }
267 
268     TensorList* output_list = nullptr;
269     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
270     output_list->tensors().pop_back();
271   }
272 
273  private:
274   DataType element_dtype_;
275 };
276 
277 template <typename Device, typename T>
278 class TensorListConcat : public OpKernel {
279  public:
280   using ConstMatrixVector =
281       std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>;
TensorListConcat(OpKernelConstruction * c)282   explicit TensorListConcat(OpKernelConstruction* c) : OpKernel(c) {
283     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
284     if (c->HasAttr("element_shape")) {
285       OP_REQUIRES_OK(c, c->GetAttr("element_shape", &element_shape_));
286     }
287   }
288 
Compute(OpKernelContext * c)289   void Compute(OpKernelContext* c) override {
290     PartialTensorShape element_shape_except_first_dim;
291     if (!element_shape_.unknown_rank()) {
292       element_shape_except_first_dim = PartialTensorShape(
293           gtl::ArraySlice<int64>(element_shape_.dim_sizes()).subspan(1));
294     }
295     // Check that the input Variant tensor is indeed a TensorList and has the
296     // correct element type.
297     const TensorList* tensor_list = nullptr;
298     OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
299     OP_REQUIRES(
300         c, element_dtype_ == tensor_list->element_dtype,
301         errors::InvalidArgument(
302             "Invalid data types; op elements ", DataTypeString(element_dtype_),
303             " but list elements ", DataTypeString(tensor_list->element_dtype)));
304     // The leading dimension of all list elements if they are all the same.
305     // This is used as the leading dim of uninitialized tensors in the list
306     // if leading_dims is not provided.
307     int64 first_dim = -1;
308     if (c->num_inputs() > 1) {
309       // TensorListConcatV2
310       PartialTensorShape element_shape;
311       OP_REQUIRES_OK(
312           c, GetElementShapeFromInput(c, *tensor_list, 1, &element_shape));
313       OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
314                   errors::InvalidArgument(
315                       "Concat requires elements to be at least vectors, ",
316                       "found scalars instead."));
317       // Split `element_shape` into `first_dim` and
318       // `element_shape_except_first_dim`.
319       first_dim = element_shape.dim_size(0);
320       element_shape_except_first_dim = element_shape;
321       element_shape_except_first_dim.RemoveDim(0);
322     }
323     // If the TensorList is empty, element_shape_except_first_dim must be fully
324     // defined.
325     OP_REQUIRES(c,
326                 !tensor_list->tensors().empty() ||
327                     element_shape_except_first_dim.IsFullyDefined(),
328                 errors::InvalidArgument(
329                     "All except the first dimension must be fully defined ",
330                     "when concating an empty tensor list. element_shape: ",
331                     element_shape_except_first_dim.DebugString()));
332     // 1. Check that `element_shape_except_first_dim` input tensor is
333     //    compatible with the shapes of element tensors.
334     // 2. Check that the elements have the same shape except the first dim.
335     // 3. If `first_dim` is known, check that it is compatible with the leading
336     //    dims of all elements.
337     // 4. If `first_dim` is unknown (-1), check whether all initialized
338     //    elements have the same leading dim and if so set `first_dim` to that
339     //    value.
340     if (!tensor_list->element_shape.IsFullyDefined()) {
341       bool check_dim = (first_dim == -1);
342       int64 inferred_first_dim = first_dim;
343       for (int i = 0; i < tensor_list->tensors().size(); ++i) {
344         const Tensor& t = tensor_list->tensors()[i];
345         if (t.dtype() != DT_INVALID) {
346           PartialTensorShape tmp = element_shape_except_first_dim;
347           OP_REQUIRES(
348               c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
349               errors::InvalidArgument("Concat saw a scalar shape at index ", i,
350                                       " but requires at least vectors."));
351           TensorShape shape_except_first_dim = TensorShape(
352               gtl::ArraySlice<int64>(t.shape().dim_sizes()).subspan(1));
353           OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim,
354                                           &element_shape_except_first_dim));
355           OP_REQUIRES(c, first_dim == -1 || first_dim == t.shape().dim_size(0),
356                       errors::InvalidArgument(
357                           "First entry of element_shape input does not match ",
358                           "the first dim of list element at index: ", i,
359                           " Expected: ", first_dim,
360                           " Actual: ", t.shape().dim_size(0)));
361           if (check_dim) {
362             if (inferred_first_dim == -1) {
363               inferred_first_dim = t.shape().dim_size(0);
364             } else if (inferred_first_dim != t.shape().dim_size(0)) {
365               inferred_first_dim = -1;
366               check_dim = false;
367             }
368           }
369         }
370       }
371       first_dim = inferred_first_dim;
372     }
373     TensorShape output_shape;
374     OP_REQUIRES(c, element_shape_except_first_dim.AsTensorShape(&output_shape),
375                 errors::InvalidArgument(
376                     "Trying to concat list with only uninitialized tensors ",
377                     "but element_shape_except_first_dim is not fully defined: ",
378                     element_shape_except_first_dim.DebugString()));
379     // Build the lengths_tensor and leading dim of the output tensor by
380     // iterating over all element tensors.
381     Tensor* lengths_tensor = nullptr;
382     OP_REQUIRES_OK(
383         c,
384         c->allocate_output(
385             1, TensorShape({static_cast<int64>(tensor_list->tensors().size())}),
386             &lengths_tensor));
387     auto lengths_tensor_vec = lengths_tensor->vec<int64>();
388     int64 leading_dim = 0;
389     for (size_t i = 0; i < tensor_list->tensors().size(); i++) {
390       int64 dim;
391       if (tensor_list->tensors()[i].dtype() != DT_INVALID) {
392         dim = tensor_list->tensors()[i].shape().dim_size(0);
393       } else {
394         // If leading_dims is not provided or does not contain an entry for
395         // index i use the inferred `first_dim` if set.
396         if ((c->num_inputs() <= 2 || i >= c->input(2).NumElements()) &&
397             first_dim != -1) {
398           dim = first_dim;
399         } else {
400           OP_REQUIRES(c, c->num_inputs() > 2,
401                       errors::InvalidArgument(
402                           "Concating lists with uninitialized tensors is not ",
403                           "supported in this version of TensorListConcat. ",
404                           "Consider updating your GraphDef to run the newer ",
405                           "version."));
406           OP_REQUIRES(c, i < c->input(2).NumElements(),
407                       errors::InvalidArgument(
408                           "List contains uninitialized tensor at index ", i,
409                           " but leading_dims has only ",
410                           c->input(2).NumElements(), " elements."));
411           dim = c->input(2).vec<int64>()(i);
412         }
413       }
414       leading_dim += dim;
415       lengths_tensor_vec(i) = dim;
416     }
417     output_shape.InsertDim(0, leading_dim);
418     Tensor* output;
419     // Allocate the output tensor and fill it up with the concated element
420     // tensors.
421     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
422     if (output->NumElements() == 0) {
423       return;
424     }
425 
426     ConstMatrixVector inputs_flat;
427     inputs_flat.reserve(tensor_list->tensors().size());
428     // Store the zeros tensors in a vector to prevent them from being GC'ed till
429     // concat is complete.
430     std::vector<Tensor> zeros_vec;
431     for (int i = 0; i < tensor_list->tensors().size(); i++) {
432       const Tensor& element_tensor = tensor_list->tensors()[i];
433       if (element_tensor.dtype() != DT_INVALID) {
434         if (element_tensor.NumElements() > 0) {
435           inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
436               element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
437         }
438       } else {
439         AllocatorAttributes attr;
440         if (element_dtype_ == DT_VARIANT) {
441           attr.set_on_host(true);
442         }
443         TensorShape element_shape = output_shape;
444         element_shape.set_dim(0, lengths_tensor_vec(i));
445         zeros_vec.emplace_back();
446         Tensor& zeros = zeros_vec.back();
447         OP_REQUIRES_OK(
448             c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
449         functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
450                                              zeros.flat<T>());
451         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
452             const_cast<const Tensor&>(zeros).shaped<T, 2>(
453                 {1, zeros.NumElements()})));
454       }
455     }
456     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
457 
458 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
459     if (std::is_same<Device, Eigen::GpuDevice>::value) {
460       ConcatGPU<T>(c, inputs_flat, output, &output_flat);
461       return;
462     }
463 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
464     ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
465   }
466 
467  private:
468   DataType element_dtype_;
469   PartialTensorShape element_shape_;
470 };
471 
472 template <typename Device, typename T>
473 class TensorListSplit : public OpKernel {
474  public:
TensorListSplit(OpKernelConstruction * c)475   TensorListSplit(OpKernelConstruction* c) : OpKernel(c) {}
476 
Compute(OpKernelContext * c)477   void Compute(OpKernelContext* c) override {
478     Tensor* output_tensor;
479     AllocatorAttributes attr;
480     attr.set_on_host(true);
481     OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
482     PartialTensorShape element_shape;
483     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
484     OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
485                 errors::InvalidArgument(
486                     "TensorListSplit requires element_shape to be at least of ",
487                     "rank 1, but saw: ", element_shape.DebugString()));
488     TensorList output_list;
489     const Tensor& input_tensor = c->input(0);
490     output_list.element_dtype = input_tensor.dtype();
491     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
492                 errors::InvalidArgument(
493                     "Tensor must be at least a vector, but saw shape: ",
494                     input_tensor.shape().DebugString()));
495     TensorShape tensor_shape_without_first_dim(input_tensor.shape());
496     tensor_shape_without_first_dim.RemoveDim(0);
497     PartialTensorShape element_shape_without_first_dim;
498     if (!element_shape.unknown_rank()) {
499       element_shape_without_first_dim =
500           PartialTensorShape(element_shape.dim_sizes());
501       element_shape_without_first_dim.RemoveDim(0);
502     }
503     OP_REQUIRES(c,
504                 element_shape_without_first_dim.IsCompatibleWith(
505                     tensor_shape_without_first_dim),
506                 errors::InvalidArgument(
507                     "tensor shape ", input_tensor.shape().DebugString(),
508                     " is not compatible with element_shape ",
509                     element_shape.DebugString()));
510     output_list.element_shape = element_shape;
511     const Tensor& lengths = c->input(2);
512     OP_REQUIRES(c, TensorShapeUtils::IsVector(lengths.shape()),
513                 errors::InvalidArgument(
514                     "Expected lengths to be a vector, received shape: ",
515                     lengths.shape().DebugString()));
516     output_list.tensors().reserve(lengths.shape().dim_size(0));
517     int64 start = 0;
518     int64 end = 0;
519     for (int i = 0; i < lengths.shape().dim_size(0); ++i) {
520       int64 length = lengths.vec<int64>()(i);
521       OP_REQUIRES(
522           c, length >= 0,
523           errors::InvalidArgument("Invalid value in lengths: ", length));
524       end = start + length;
525       OP_REQUIRES(c, end <= input_tensor.shape().dim_size(0),
526                   errors::InvalidArgument("Attempting to slice [", start, ", ",
527                                           end, "] from tensor with length ",
528                                           input_tensor.shape().dim_size(0)));
529       Tensor tmp = input_tensor.Slice(start, end);
530       start = end;
531       // TODO(apassos) maybe not always align; but weird compiler bugs seem to
532       // prevent this.
533       Tensor aligned;
534       OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
535       aligned.flat<T>().device(c->eigen_device<Device>()) =
536           tmp.unaligned_flat<T>();
537       output_list.tensors().emplace_back(aligned);
538     }
539     OP_REQUIRES(c, end == input_tensor.shape().dim_size(0),
540                 errors::InvalidArgument(
541                     "Unused values in tensor. Length of tensor: ",
542                     input_tensor.shape().dim_size(0), " Values used: ", end));
543     output_tensor->scalar<Variant>()() = std::move(output_list);
544   }
545 };
546 
547 template <typename Device, typename T>
548 class TensorListGather : public OpKernel {
549  public:
550   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
551       ConstMatrixVector;
TensorListGather(OpKernelConstruction * c)552   explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
553     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
554   }
555 
Compute(OpKernelContext * c)556   void Compute(OpKernelContext* c) override {
557     const TensorList* tensor_list = nullptr;
558     OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
559     OP_REQUIRES(
560         c, element_dtype_ == tensor_list->element_dtype,
561         errors::InvalidArgument(
562             "Invalid data types; op elements ", DataTypeString(element_dtype_),
563             " but list elements ", DataTypeString(tensor_list->element_dtype)));
564     const Tensor& indices = c->input(1);
565     PartialTensorShape partial_element_shape;
566     OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 2,
567                                                &partial_element_shape));
568     OP_REQUIRES(
569         c, partial_element_shape.IsFullyDefined() || indices.NumElements() > 0,
570         errors::InvalidArgument("Tried to gather 0-elements from "
571                                 "a list with non-fully-defined shape: ",
572                                 partial_element_shape.DebugString()));
573 
574     // Check that `element_shape` input tensor is compatible with the shapes of
575     // element tensors.
576     if (!tensor_list->element_shape.IsFullyDefined()) {
577       for (int index = 0; index < indices.NumElements(); ++index) {
578         const int i = indices.flat<int32>()(index);
579         const Tensor& t = tensor_list->tensors()[i];
580         if (t.dtype() != DT_INVALID) {
581           PartialTensorShape tmp = partial_element_shape;
582           OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
583         }
584       }
585     }
586 
587     // Compute the shape of the output tensor by pre-pending the leading dim to
588     // the element_shape.
589     TensorShape element_shape;
590     OP_REQUIRES(
591         c, partial_element_shape.AsTensorShape(&element_shape),
592         errors::InvalidArgument("Tried to gather uninitialized tensors from a ",
593                                 "list with non-fully-defined element_shape: ",
594                                 partial_element_shape.DebugString()));
595     TensorShape output_shape = element_shape;
596     output_shape.InsertDim(0, indices.NumElements());
597     Tensor* output;
598     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
599     if (output->NumElements() == 0) {
600       return;
601     }
602 
603     ConstMatrixVector inputs_flat;
604     inputs_flat.reserve(indices.NumElements());
605     Tensor zeros;
606     for (int index = 0; index < indices.NumElements(); ++index) {
607       const int i = indices.flat<int32>()(index);
608       OP_REQUIRES(
609           c, i < tensor_list->tensors().size(),
610           errors::InvalidArgument("Index ", i, " out o range; list only has ",
611                                   tensor_list->tensors().size(), " elements."));
612       const Tensor& t = tensor_list->tensors()[i];
613       if (t.dtype() != DT_INVALID) {
614         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
615             t.shaped<T, 2>({1, t.NumElements()})));
616       } else {
617         if (!zeros.NumElements()) {
618           AllocatorAttributes attr;
619           if (element_dtype_ == DT_VARIANT) {
620             attr.set_on_host(true);
621           }
622           OP_REQUIRES_OK(
623               c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
624           functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
625                                                zeros.flat<T>());
626         }
627         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
628             const_cast<const Tensor&>(zeros).shaped<T, 2>(
629                 {1, zeros.NumElements()})));
630       }
631     }
632     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
633 
634 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
635     if (std::is_same<Device, Eigen::GpuDevice>::value) {
636       ConcatGPU<T>(c, inputs_flat, output, &output_flat);
637       return;
638     }
639 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
640     ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
641   }
642 
643  private:
644   DataType element_dtype_;
645 };
646 
647 template <typename Device, typename T>
648 class TensorListFromTensor : public OpKernel {
649  public:
TensorListFromTensor(OpKernelConstruction * c)650   TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}
651 
Compute(OpKernelContext * c)652   void Compute(OpKernelContext* c) override {
653     Tensor* output_tensor;
654     AllocatorAttributes attr;
655     attr.set_on_host(true);
656     OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
657     PartialTensorShape element_shape;
658     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
659     TensorList output_list;
660     const Tensor& t = c->input(0);
661     output_list.element_dtype = t.dtype();
662     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
663                 errors::InvalidArgument(
664                     "Tensor must be at least a vector, but saw shape: ",
665                     t.shape().DebugString()));
666     TensorShape output_shape(t.shape());
667     output_shape.RemoveDim(0);
668     OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
669                 errors::InvalidArgument(
670                     "Specified a list with shape ", element_shape.DebugString(),
671                     " from a tensor with shape ", output_shape.DebugString()));
672     output_list.element_shape = element_shape;
673     output_list.tensors().reserve(t.shape().dim_size(0));
674     for (int i = 0; i < t.shape().dim_size(0); ++i) {
675       Tensor tmp = t.Slice(i, i + 1);
676       TensorShape tmp_shape = tmp.shape();
677       tmp_shape.RemoveDim(0);
678       OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
679                   errors::Unknown("Unexpected shape error."));
680       // TODO(apassos) maybe not always align; but weird compiler bugs seem to
681       // prevent this.
682       Tensor aligned;
683       OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
684       aligned.flat<T>().device(c->eigen_device<Device>()) =
685           tmp.unaligned_flat<T>();
686       output_list.tensors().push_back(aligned);
687     }
688     output_tensor->scalar<Variant>()() = std::move(output_list);
689   }
690 };
691 
692 // Scatters values in `value` into `list`. Assumes that `indices` are valid.
693 template <typename Device, typename T>
Scatter(OpKernelContext * c,const Tensor & value,const Tensor & indices,TensorList * list)694 Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices,
695                TensorList* list) {
696   for (int index = 0; index < indices.NumElements(); ++index) {
697     const int i = indices.flat<int32>()(index);
698     Tensor tmp = value.Slice(index, index + 1);
699     TensorShape tmp_shape = tmp.shape();
700     tmp_shape.RemoveDim(0);
701     if (!tmp.CopyFrom(tmp, tmp_shape)) {
702       return errors::Unknown("Unexpected shape error.");
703     }
704     // TODO(apassos) maybe not always align; but weird compiler bugs seem to
705     // prevent this.
706     Tensor aligned;
707     TF_RETURN_IF_ERROR(c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
708     // TODO(apassos) do all slices in a single kernel invocation instead of
709     // many small ones.
710     aligned.flat<T>().device(c->eigen_device<Device>()) =
711         tmp.unaligned_flat<T>();
712     std::swap(list->tensors()[i], aligned);
713   }
714   return Status::OK();
715 }
716 
717 template <typename Device, typename T>
718 class TensorListScatterIntoExistingList : public OpKernel {
719  public:
TensorListScatterIntoExistingList(OpKernelConstruction * c)720   TensorListScatterIntoExistingList(OpKernelConstruction* c) : OpKernel(c) {}
721 
Compute(OpKernelContext * c)722   void Compute(OpKernelContext* c) override {
723     const TensorList* l = nullptr;
724     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
725     const Tensor& input_tensor = c->input(1);
726     const Tensor& indices = c->input(2);
727 
728     // Check that inputs are valid.
729     OP_REQUIRES(c, input_tensor.dtype() == l->element_dtype,
730                 errors::InvalidArgument(
731                     "Invalid data types; input tensor type: ",
732                     DataTypeString(input_tensor.dtype()),
733                     " list element_type: ", DataTypeString(l->element_dtype)));
734     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
735                 errors::InvalidArgument(
736                     "Tensor must be at least a vector, but saw shape: ",
737                     input_tensor.shape().DebugString()));
738     OP_REQUIRES(c, TensorShapeUtils::IsVector(indices.shape()),
739                 errors::InvalidArgument(
740                     "Expected indices to be a vector, but received shape: ",
741                     indices.shape().DebugString()));
742     OP_REQUIRES(
743         c, indices.NumElements() == input_tensor.shape().dim_size(0),
744         errors::InvalidArgument(
745             "Expected len(indices) == tensor.shape[0], but saw: ",
746             indices.NumElements(), " vs. ", input_tensor.shape().dim_size(0)));
747 
748     // Resize the list if needed to accommodate all indices.
749     TensorList* output_list = nullptr;
750     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
751     const auto indices_vec = indices.vec<int32>();
752     int32 max_index =
753         (indices.NumElements() == 0)
754             ? -1
755             : *std::max_element(indices_vec.data(),
756                                 indices_vec.data() + indices.NumElements());
757     if (max_index + 1 > output_list->tensors().size()) {
758       output_list->tensors().resize(max_index + 1);
759     }
760 
761     // Scatter the values.
762     OP_REQUIRES_OK(c,
763                    Scatter<Device, T>(c, input_tensor, indices, output_list));
764   }
765 };
766 
767 template <typename Device, typename T>
768 class TensorListScatter : public OpKernel {
769  public:
TensorListScatter(OpKernelConstruction * c)770   TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}
771 
Compute(OpKernelContext * c)772   void Compute(OpKernelContext* c) override {
773     Tensor* output_tensor;
774     AllocatorAttributes attr;
775     attr.set_on_host(true);
776     OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
777     Tensor indices = c->input(1);
778     PartialTensorShape element_shape;
779     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
780     // TensorListScatterV2 passes the num_elements input, TensorListScatter does
781     // not.
782     int num_elements = c->num_inputs() >= 4 ? c->input(3).scalar<int>()() : -1;
783     OP_REQUIRES(c, num_elements >= -1,
784                 errors::InvalidArgument(
785                     "TensorListScatter expects num_elements >= -1, found: ",
786                     num_elements));
787     TensorList output_list;
788     const Tensor& input_tensor = c->input(0);
789     output_list.element_dtype = input_tensor.dtype();
790     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
791                 errors::InvalidArgument(
792                     "Tensor must be at least a vector, but saw shape: ",
793                     input_tensor.shape().DebugString()));
794     TensorShape output_shape(input_tensor.shape());
795     output_shape.RemoveDim(0);
796     OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
797                 errors::InvalidArgument(
798                     "Specified a list with shape ", element_shape.DebugString(),
799                     " from a tensor with shape ", output_shape.DebugString()));
800     output_list.element_shape = element_shape;
801 
802     OP_REQUIRES(c, indices.NumElements() == input_tensor.shape().dim_size(0),
803                 errors::InvalidArgument(
804                     "Invalid number of rows in input tensor. Expected: ",
805                     indices.NumElements(),
806                     " Actual: ", input_tensor.shape().dim_size(0)));
807 
808     // Validate indices and resize output_list.tensors to fit the highest index.
809     {
810       int highest_index = -1;
811       for (int index = 0; index < indices.NumElements(); ++index) {
812         const int i = indices.flat<int32>()(index);
813         OP_REQUIRES(
814             c, i >= 0,
815             errors::InvalidArgument(
816                 "Indices in TensorListScatter must all be non-negative."));
817         OP_REQUIRES(c, num_elements == -1 || i < num_elements,
818                     errors::InvalidArgument(
819                         "TensorListScatter: Trying to scatter at index ", i,
820                         " in list with size ", num_elements));
821         if (i > highest_index) {
822           highest_index = i;
823         }
824       }
825       output_list.tensors().resize(std::max(highest_index + 1, num_elements),
826                                    Tensor(DT_INVALID));
827     }
828 
829     OP_REQUIRES_OK(c,
830                    Scatter<Device, T>(c, input_tensor, indices, &output_list));
831     output_tensor->scalar<Variant>()() = std::move(output_list);
832   }
833 };
834 
835 template <typename Device>
TensorListBinaryAdd(OpKernelContext * c,const TensorList & a,const TensorList & b,TensorList * out)836 Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
837                            const TensorList& b, TensorList* out) {
838   if (a.element_dtype != b.element_dtype) {
839     return errors::InvalidArgument(
840         "Trying to add two lists of tensors of different dtypes. One is ",
841         DataTypeString(a.element_dtype), " and the other is ",
842         DataTypeString(b.element_dtype));
843   }
844   out->element_dtype = a.element_dtype;
845   if (!a.element_shape.IsCompatibleWith(b.element_shape)) {
846     return errors::InvalidArgument(
847         "Trying to add two lists of tensors with incompatible element shapes. "
848         "One is ",
849         a.element_shape.DebugString(), " and the other is ",
850         b.element_shape.DebugString());
851   }
852 
853   TF_RETURN_IF_ERROR(
854       a.element_shape.MergeWith(b.element_shape, &out->element_shape));
855   if (a.tensors().size() != b.tensors().size()) {
856     return errors::InvalidArgument(
857         "Trying to add two lists of tensors with different lengths. One is ",
858         a.tensors().size(), " and the other is ", b.tensors().size());
859   }
860   out->tensors().reserve(a.tensors().size());
861   for (int i = 0; i < a.tensors().size(); ++i) {
862     const Tensor& a_tensor = a.tensors()[i];
863     const Tensor& b_tensor = b.tensors()[i];
864     Tensor out_tensor;
865     TF_RETURN_IF_ERROR(
866         BinaryAddTensors<Device>(c, a_tensor, b_tensor, &out_tensor));
867     out->tensors().push_back(out_tensor);
868   }
869   return Status::OK();
870 }
871 
872 template <typename Device>
TensorListZerosLike(OpKernelContext * c,const TensorList & x,TensorList * y)873 Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
874                            TensorList* y) {
875   y->element_dtype = x.element_dtype;
876   y->element_shape = x.element_shape;
877   y->tensors().reserve(x.tensors().size());
878   for (const Tensor& t : x.tensors()) {
879     Tensor out_tensor;
880     TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, t, &out_tensor));
881     y->tensors().emplace_back(out_tensor);
882   }
883   return Status::OK();
884 }
885 
886 template <typename Device, typename T>
887 class TensorListPushBackBatch : public OpKernel {
888  public:
TensorListPushBackBatch(OpKernelConstruction * c)889   explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
890     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
891   }
892 
Compute(OpKernelContext * c)893   void Compute(OpKernelContext* c) override {
894     const Tensor& input = c->input(1);
895     OP_REQUIRES(c, element_dtype_ == input.dtype(),
896                 errors::InvalidArgument("Invalid data types; list elements ",
897                                         DataTypeString(element_dtype_),
898                                         " but tried to append ",
899                                         DataTypeString(input.dtype())));
900     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
901                 errors::InvalidArgument(
902                     "Expected tensor to be at least a vector, but saw shape: ",
903                     input.shape().DebugString()));
904 
905     const TensorShape& tls_shape = c->input(0).shape();
906 
907     // For purposes of input forwarding, we want the least restrictive
908     // AllocatorAttributes possible.  If we need to allocate later,
909     // we'll request the DT_VARIANT be allocated on host.
910     AllocatorAttributes attr;
911 
912     std::unique_ptr<Tensor> tls_alias = c->forward_input(
913         0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
914         DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
915 
916     bool ok_to_alias = tls_alias != nullptr;
917     if (tls_alias && tls_alias->dtype() == DT_VARIANT &&
918         tls_alias->NumElements() > 0) {
919       auto alias_t = tls_alias->flat<Variant>();
920       for (int i = 0; i < tls_alias->NumElements(); ++i) {
921         TensorList* tl_i = alias_t(i).get<TensorList>();
922         if (tl_i == nullptr || !tl_i->RefCountIsOne()) {
923           ok_to_alias = false;
924           break;
925         }
926       }
927     }
928     const Tensor& tls = ok_to_alias ? *tls_alias : c->input(0);
929 
930     OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
931                 errors::InvalidArgument(
932                     "Expected input_handles dtype to be Variant, but saw: ",
933                     DataTypeString(tls.dtype())));
934     OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
935                 errors::InvalidArgument(
936                     "Expected input_handles to be a vector, but saw shape: ",
937                     tls_shape.DebugString()));
938     const int64 batch_size = tls.NumElements();
939     OP_REQUIRES(c, input.dim_size(0) == batch_size,
940                 errors::InvalidArgument(
941                     "Expected tensor.shape[0] == input_handles.size, but saw ",
942                     input.dim_size(0), " vs. ", batch_size));
943     auto tls_t = tls.vec<Variant>();
944 
945     TensorShape input_element_shape = input.shape();
946     input_element_shape.RemoveDim(0);
947     std::vector<const TensorList*> tl_batch;
948     for (int64 b = 0; b < batch_size; ++b) {
949       const TensorList* l = tls_t(b).get<TensorList>();
950       OP_REQUIRES(c, l != nullptr,
951                   errors::InvalidArgument("Input handle at index ", b,
952                                           " is not a list. Saw: '",
953                                           tls_t(b).DebugString(), "'"));
954       OP_REQUIRES(
955           c, l->element_shape.IsCompatibleWith(input_element_shape),
956           errors::InvalidArgument(
957               "Tried to append a tensor with incompatible shape to a "
958               "list at index ",
959               b, ". Op element shape: ", input_element_shape.DebugString(),
960               " list shape: ", l->element_shape.DebugString()));
961       OP_REQUIRES(c, element_dtype_ == l->element_dtype,
962                   errors::InvalidArgument(
963                       "Invalid data type at index ", b, "; op elements ",
964                       DataTypeString(element_dtype_), " but list elements ",
965                       DataTypeString(l->element_dtype)));
966       tl_batch.push_back(l);
967     }
968 
969     Tensor* result;
970 
971     if (ok_to_alias) {
972       result = tls_alias.get();
973       c->set_output(0, *result);
974     } else {
975       // DT_VARIANT tensors always allocated on host.
976       AllocatorAttributes attr;
977       attr.set_on_host(true);
978       OP_REQUIRES_OK(
979           c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
980     }
981 
982     if (batch_size == 0) {
983       return;
984     }
985 
986     auto input_t = input.flat_outer_dims<T, 2>();
987     auto result_t = result->vec<Variant>();
988 
989     for (int64 b = 0; b < batch_size; ++b) {
990       if (!ok_to_alias) {
991         result_t(b) = tl_batch[b]->Copy();
992       }
993       TensorList* output = result_t(b).get<TensorList>();
994       DCHECK(output != nullptr);
995       Tensor* frame;
996       PersistentTensor tmp;
997       OP_REQUIRES_OK(c, c->allocate_persistent(
998                             element_dtype_, input_element_shape, &tmp, &frame));
999       if (input_element_shape.num_elements() > 0) {
1000         auto frame_t = frame->flat<T>();
1001         frame_t.device(c->eigen_device<Device>()) = input_t.template chip<0>(b);
1002       }
1003       output->tensors().push_back(std::move(*frame));
1004     }
1005   }
1006 
1007  private:
1008   DataType element_dtype_;
1009 };
1010 
1011 }  // namespace tensorflow
1012 
1013 #endif  // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
1014