1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
16 #define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
17 
18 #define EIGEN_USE_THREADS
19 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
20 #define EIGEN_USE_GPU
21 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/framework/variant.h"
29 #include "tensorflow/core/framework/variant_op_registry.h"
30 #include "tensorflow/core/kernels/concat_lib.h"
31 #include "tensorflow/core/kernels/fill_functor.h"
32 #include "tensorflow/core/kernels/tensor_list.h"
33 #include "tensorflow/core/kernels/tensor_list_util.h"
34 #include "tensorflow/core/lib/core/coding.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/core/refcount.h"
37 #include "tensorflow/core/lib/gtl/array_slice.h"
38 #include "tensorflow/core/platform/platform.h"
39 #include "tensorflow/core/util/tensor_ops_util.h"
40 #include "tensorflow/core/util/util.h"
41 
42 // stream.h isn't available on some platforms such as Android, iOS, and
43 // ChromiumOS. Only include it for platforms that PluggableDevice is tested on.
44 #if !defined(PLUGGABLE_DEVICE_SUPPORTED) &&                              \
45     (__x86_64__ || __i386__ || defined(__APPLE__) || defined(_WIN32)) && \
46     !defined(ANDROID) && !defined(__ANDROID__) && !TARGET_OS_IOS &&      \
47     !defined(PLATFORM_CHROMIUMOS)
48 #define PLUGGABLE_DEVICE_SUPPORTED
49 #endif
50 
51 #ifdef PLUGGABLE_DEVICE_SUPPORTED
52 #include "tensorflow/stream_executor/stream.h"
53 #endif
54 
55 namespace tensorflow {
56 
57 typedef Eigen::ThreadPoolDevice CPUDevice;
58 
59 Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);
60 
61 Status GetElementShapeFromInput(OpKernelContext* c,
62                                 const TensorList& tensor_list, int index,
63                                 PartialTensorShape* element_shape);
64 
65 Status GetInputList(OpKernelContext* c, int index, const TensorList** list);
66 
67 Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index,
68                                    int32_t output_index,
69                                    const TensorList& input_list,
70                                    TensorList** output_list);
71 
72 // TODO(penporn): Move this to a proper place.
73 inline bool IsPluggableDevice(OpKernelContext* c) {
74   return c->op_device_context() && c->op_device_context()->IsPluggableDevice();
75 }
76 
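// Fills `tensor` with zeros: on a PluggableDevice via a stream MemZero,
// otherwise via the Eigen SetZeroFunctor for `Device`.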
77 template <typename Device, typename T>
78 inline void SetZero(OpKernelContext* ctx, Tensor& tensor) {
79 #ifdef PLUGGABLE_DEVICE_SUPPORTED
80   if (IsPluggableDevice(ctx)) {
81     auto ptr =
82         se::DeviceMemoryBase(tensor.flat<T>().data(), tensor.TotalBytes());
83     auto stream = ctx->op_device_context()->stream();
84     auto result = stream->ThenMemZero(&ptr, tensor.TotalBytes()).ok();
85     DCHECK_EQ(true, result);
86   } else {
87 #endif  // PLUGGABLE_DEVICE_SUPPORTED
88     functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
89                                          tensor.flat<T>());
90 #ifdef PLUGGABLE_DEVICE_SUPPORTED
91   }
92 #endif  // PLUGGABLE_DEVICE_SUPPORTED
93 }
94 
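// Copies `src` into `dst` with a raw device-to-device memcpy on the
// PluggableDevice stream; only valid for element types that can use memcpy.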
95 template <typename T>
96 inline void CopyTensorPluggableDevice(OpKernelContext* ctx, Tensor& src,
97                                       Tensor& dst) {
98 #ifdef PLUGGABLE_DEVICE_SUPPORTED
99   auto src_t = src.unaligned_flat<T>();
100   auto dst_t = dst.flat<T>();
101   DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum<T>::v()));
102   auto src_ptr = se::DeviceMemoryBase(src_t.data(), src.TotalBytes());
103   auto dst_ptr = se::DeviceMemoryBase(dst_t.data(), dst.TotalBytes());
104   auto stream = ctx->op_device_context()->stream();
105   auto result = stream->ThenMemcpy(&dst_ptr, src_ptr, src.TotalBytes()).ok();
106   DCHECK_EQ(true, result);
107 #else
108   LOG(FATAL)  // Crash OK.
109       << "PluggableDevice is not supported on this platform.";
110 #endif  // PLUGGABLE_DEVICE_SUPPORTED
111 }
112 
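// Copies `src` into `dst` with an Eigen device assignment (CPU/GPU path).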
113 template <typename Device, typename T>
114 inline void CopyTensor(OpKernelContext* ctx, Tensor& src, Tensor& dst) {
115   auto src_t = src.unaligned_flat<T>();
116   auto dst_t = dst.flat<T>();
117   dst_t.device(ctx->eigen_device<Device>()) = src_t;
118 }
119 
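// Concatenates the flattened input matrices into `output` on a PluggableDevice
// by issuing one stream memcpy per input per output row.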
120 template <typename T>
121 void ConcatPluggableDevice(
122     OpKernelContext* context,
123     const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&
124         inputs,
125     typename TTypes<T, 2>::Matrix* output) {
126 #ifdef PLUGGABLE_DEVICE_SUPPORTED
127   DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum<T>::v()));
128 
129   se::Stream* stream = context->op_device_context()->stream();
130 
131   size_t num_inputs = inputs.size();
132   std::vector<ptrdiff_t> sizes;
133   sizes.reserve(num_inputs);
134   int64 row_size = 0;
135   for (const auto& input : inputs) {
136     sizes.push_back(input->dimension(1));
137     row_size += sizes.back();
138   }
139 
140   T* out = &(*output)(0, 0);
141   std::vector<const T*> inp;
142   inp.reserve(num_inputs);
143   for (const auto& input : inputs) {
144     inp.push_back(&(*input)(0, 0));
145   }
146   const int64 dim0 = output->dimension(0);
147   for (int64 i = 0; i < dim0; ++i) {
148     for (int64 j = 0; j < num_inputs; ++j) {
149       auto size = sizes[j];
150       se::DeviceMemoryBase out_base{out, size * sizeof(T)};
151       se::DeviceMemoryBase inp_base{const_cast<T*>(inp[j]), size * sizeof(T)};
152       stream->ThenMemcpy(&out_base, inp_base, size * sizeof(T));
153       out += size;
154       inp[j] += size;
155     }
156   }
157 #else
158   LOG(FATAL)  // Crash OK.
159       << "PluggableDevice is not supported on this platform.";
160 #endif  // PLUGGABLE_DEVICE_SUPPORTED
161 }
162 
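// Stacks every element of a TensorList into one tensor whose leading dimension
// is the list length; uninitialized elements are replaced with zeros of the
// (fully defined) element shape. For example, a list of three [2, 3] elements
// stacks into a [3, 2, 3] tensor.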
163 template <typename Device, typename T>
164 class TensorListStack : public OpKernel {
165  public:
166   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
167       ConstMatrixVector;
168   explicit TensorListStack(OpKernelConstruction* c) : OpKernel(c) {
169     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
170     OP_REQUIRES_OK(c, c->GetAttr("num_elements", &num_elements_));
171   }
172 
173   void Compute(OpKernelContext* c) override {
174     const TensorList* tensor_list = nullptr;
175     OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
176     OP_REQUIRES(
177         c, element_dtype_ == tensor_list->element_dtype,
178         errors::InvalidArgument(
179             "Invalid data types; op elements ", DataTypeString(element_dtype_),
180             " but list elements ", DataTypeString(tensor_list->element_dtype)));
181     if (num_elements_ != -1) {
182       OP_REQUIRES(c, tensor_list->tensors().size() == num_elements_,
183                   errors::InvalidArgument(
184                       "Operation expected a list with ", num_elements_,
185                       " elements but got a list with ",
186                       tensor_list->tensors().size(), " elements."));
187     }
188     PartialTensorShape partial_element_shape;
189     OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 1,
190                                                &partial_element_shape));
191     OP_REQUIRES(
192         c,
193         partial_element_shape.IsFullyDefined() ||
194             !tensor_list->tensors().empty(),
195         errors::InvalidArgument("Tried to stack elements of an empty ",
196                                 "list with non-fully-defined element_shape: ",
197                                 partial_element_shape.DebugString()));
198 
199     // Check that `element_shape` input tensor is compatible with the shapes of
200     // element tensors.
201     if (!tensor_list->element_shape.IsFullyDefined()) {
202       for (int i = 0; i < tensor_list->tensors().size(); ++i) {
203         const Tensor& t = tensor_list->tensors()[i];
204         if (t.dtype() != DT_INVALID) {
205           PartialTensorShape tmp = partial_element_shape;
206           OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
207         }
208       }
209     }
210 
211     // Compute the shape of the output tensor by pre-pending the leading dim to
212     // the element_shape.
213     TensorShape element_shape;
214     OP_REQUIRES(c, partial_element_shape.AsTensorShape(&element_shape),
215                 errors::InvalidArgument(
216                     "Tried to stack list which only contains uninitialized ",
217                     "tensors and has a non-fully-defined element_shape: ",
218                     partial_element_shape.DebugString()));
219     TensorShape output_shape = element_shape;
220     output_shape.InsertDim(0, tensor_list->tensors().size());
221     Tensor* output;
222     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
223     if (output->NumElements() == 0) {
224       return;
225     }
226 
227     ConstMatrixVector inputs_flat;
228     inputs_flat.reserve(tensor_list->tensors().size());
229     Tensor zeros;
230     for (const auto& t : tensor_list->tensors()) {
231       if (t.dtype() != DT_INVALID) {
232         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
233             t.shaped<T, 2>({1, t.NumElements()})));
234       } else {
235         if (!zeros.NumElements()) {
236           AllocatorAttributes attr;
237           if (element_dtype_ == DT_VARIANT) {
238             attr.set_on_host(true);
239           }
240           OP_REQUIRES_OK(
241               c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
242           SetZero<Device, T>(c, zeros);
243         }
244         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
245             const_cast<const Tensor&>(zeros).shaped<T, 2>(
246                 {1, zeros.NumElements()})));
247       }
248     }
249     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
250 
251 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
252     if (std::is_same<Device, Eigen::GpuDevice>::value) {
253       ConcatGPU<T>(c, inputs_flat, output, &output_flat);
254       return;
255     }
256 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
257     if (IsPluggableDevice(c)) {
258       ConcatPluggableDevice<T>(c, inputs_flat, &output_flat);
259     } else {
260       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
261     }
262   }
263 
264  private:
265   int num_elements_;
266   DataType element_dtype_;
267 };
268 
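// Returns the list element at the index given by input 1. If that element is
// uninitialized, a zero tensor of the inferred element shape is returned.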
269 template <typename Device, typename T>
270 class TensorListGetItem : public OpKernel {
271  public:
272   explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) {
273     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
274   }
275 
276   void Compute(OpKernelContext* c) override {
277     const TensorList* l = nullptr;
278     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
279     OP_REQUIRES(c, element_dtype_ == l->element_dtype,
280                 errors::InvalidArgument("Invalid data types; op elements ",
281                                         DataTypeString(element_dtype_),
282                                         " but list elements ",
283                                         DataTypeString(l->element_dtype)));
284     int32_t index = c->input(1).scalar<int32>()();
285     OP_REQUIRES(c, index < l->tensors().size(),
286                 errors::InvalidArgument("Trying to access element ", index,
287                                         " in a list with ", l->tensors().size(),
288                                         " elements."));
289     if (l->tensors()[index].dtype() != DT_INVALID) {
290       c->set_output(0, l->tensors()[index]);
291     } else {
292       PartialTensorShape partial_element_shape;
293       OP_REQUIRES_OK(
294           c, GetElementShapeFromInput(c, *l, 2, &partial_element_shape));
295       TensorShape element_shape;
296       // If l->element_shape and the element_shape input are both not fully
297       // defined, try to infer the shape from other list elements. This requires
298       // that all initialized list elements have the same shape.
299       // NOTE(srbs): This might be a performance bottleneck since we are
300       // iterating over the entire list here. This is necessary for feature
301       // parity with TensorArray.read. TensorArray has a mode in which all
302       // elements are required to be of the same shape; TensorList does not.
303       // In that mode TensorArray sets the array's element_shape on the first
304       // write call. We could do something similar here if needed.
305       if (!partial_element_shape.IsFullyDefined()) {
306         for (const Tensor& t : l->tensors()) {
307           if (t.dtype() != DT_INVALID) {
308             PartialTensorShape tmp = partial_element_shape;
309             OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
310           }
311         }
312       }
313       OP_REQUIRES(
314           c, partial_element_shape.AsTensorShape(&element_shape),
315           errors::InvalidArgument("Trying to read an uninitialized tensor but ",
316                                   "element_shape is not fully defined: ",
317                                   partial_element_shape.DebugString(),
318                                   " and no list element is set."));
319       Tensor* result;
320       AllocatorAttributes attr;
321       if (element_dtype_ == DT_VARIANT) {
322         attr.set_on_host(true);
323       }
324       OP_REQUIRES_OK(c, c->allocate_output(0, element_shape, &result, attr));
325       SetZero<Device, T>(c, *result);
326     }
327   }
328 
329  private:
330   DataType element_dtype_;
331 };
332 
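// Removes the last element of the list: output 0 is the shortened list and
// output 1 is the popped element (zeros if it was uninitialized).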
333 template <typename Device, typename T>
334 class TensorListPopBack : public OpKernel {
335  public:
336   explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) {
337     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
338   }
339 
340   void Compute(OpKernelContext* c) override {
341     const TensorList* l = nullptr;
342     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
343     OP_REQUIRES(c, element_dtype_ == l->element_dtype,
344                 errors::InvalidArgument("Invalid data types; op elements ",
345                                         DataTypeString(element_dtype_),
346                                         " but list elements ",
347                                         DataTypeString(l->element_dtype)));
348 
349     OP_REQUIRES(c, !l->tensors().empty(),
350                 errors::InvalidArgument("Trying to pop from an empty list."));
351 
352     const Tensor& t = l->tensors().back();
353     if (t.dtype() != DT_INVALID) {
354       c->set_output(1, t);
355     } else {
356       PartialTensorShape partial_element_shape;
357       OP_REQUIRES_OK(
358           c, GetElementShapeFromInput(c, *l, 1, &partial_element_shape));
359       TensorShape element_shape;
360       OP_REQUIRES(
361           c, partial_element_shape.AsTensorShape(&element_shape),
362           errors::InvalidArgument("Trying to read an uninitialized tensor but ",
363                                   "element_shape is not fully defined.",
364                                   partial_element_shape.DebugString()));
365       Tensor* result;
366       AllocatorAttributes attr;
367       if (element_dtype_ == DT_VARIANT) {
368         attr.set_on_host(true);
369       }
370       OP_REQUIRES_OK(c, c->allocate_output(1, element_shape, &result, attr));
371       SetZero<Device, T>(c, *result);
372     }
373 
374     TensorList* output_list = nullptr;
375     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
376     output_list->tensors().pop_back();
377   }
378 
379  private:
380   DataType element_dtype_;
381 };
382 
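// Concatenates all list elements along their first dimension and also outputs
// the per-element leading-dim lengths. For example, elements of shapes [2, 3]
// and [4, 3] produce a [6, 3] tensor and lengths {2, 4}.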
383 template <typename Device, typename T>
384 class TensorListConcat : public OpKernel {
385  public:
386   using ConstMatrixVector =
387       std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>;
388   explicit TensorListConcat(OpKernelConstruction* c) : OpKernel(c) {
389     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
390     if (c->HasAttr("element_shape")) {
391       OP_REQUIRES_OK(c, c->GetAttr("element_shape", &element_shape_));
392     }
393   }
394 
395   void Compute(OpKernelContext* c) override {
396     PartialTensorShape element_shape_except_first_dim;
397     if (!element_shape_.unknown_rank()) {
398       element_shape_except_first_dim = PartialTensorShape(
399           gtl::ArraySlice<int64_t>(element_shape_.dim_sizes()).subspan(1));
400     }
401     // Check that the input Variant tensor is indeed a TensorList and has the
402     // correct element type.
403     const TensorList* tensor_list = nullptr;
404     OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
405     OP_REQUIRES(
406         c, element_dtype_ == tensor_list->element_dtype,
407         errors::InvalidArgument(
408             "Invalid data types; op elements ", DataTypeString(element_dtype_),
409             " but list elements ", DataTypeString(tensor_list->element_dtype)));
410     // The leading dimension shared by all list elements, if they all have the
411     // same one. This is used as the leading dim of uninitialized tensors in the
412     // list if leading_dims is not provided.
413     int64_t first_dim = -1;
414     if (c->num_inputs() > 1) {
415       // TensorListConcatV2
416       PartialTensorShape element_shape;
417       OP_REQUIRES_OK(
418           c, GetElementShapeFromInput(c, *tensor_list, 1, &element_shape));
419       OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
420                   errors::InvalidArgument(
421                       "Concat requires elements to be at least vectors, ",
422                       "found scalars instead."));
423       // Split `element_shape` into `first_dim` and
424       // `element_shape_except_first_dim`.
425       first_dim = element_shape.dim_size(0);
426       element_shape_except_first_dim = element_shape;
427       element_shape_except_first_dim.RemoveDim(0);
428     }
429     // If the TensorList is empty, element_shape_except_first_dim must be fully
430     // defined.
431     OP_REQUIRES(c,
432                 !tensor_list->tensors().empty() ||
433                     element_shape_except_first_dim.IsFullyDefined(),
434                 errors::InvalidArgument(
435                     "All except the first dimension must be fully defined ",
436                     "when concating an empty tensor list. element_shape: ",
437                     element_shape_except_first_dim.DebugString()));
438     // 1. Check that `element_shape_except_first_dim` input tensor is
439     //    compatible with the shapes of element tensors.
440     // 2. Check that the elements have the same shape except the first dim.
441     // 3. If `first_dim` is known, check that it is compatible with the leading
442     //    dims of all elements.
443     // 4. If `first_dim` is unknown (-1), check whether all initialized
444     //    elements have the same leading dim and if so set `first_dim` to that
445     //    value.
446     if (!tensor_list->element_shape.IsFullyDefined()) {
447       bool check_dim = (first_dim == -1);
448       int64_t inferred_first_dim = first_dim;
449       for (int i = 0; i < tensor_list->tensors().size(); ++i) {
450         const Tensor& t = tensor_list->tensors()[i];
451         if (t.dtype() != DT_INVALID) {
452           PartialTensorShape tmp = element_shape_except_first_dim;
453           OP_REQUIRES(
454               c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
455               errors::InvalidArgument("Concat saw a scalar shape at index ", i,
456                                       " but requires at least vectors."));
457           TensorShape shape_except_first_dim = TensorShape(
458               gtl::ArraySlice<int64_t>(t.shape().dim_sizes()).subspan(1));
459           OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim,
460                                           &element_shape_except_first_dim));
461           OP_REQUIRES(c, first_dim == -1 || first_dim == t.shape().dim_size(0),
462                       errors::InvalidArgument(
463                           "First entry of element_shape input does not match ",
464                           "the first dim of list element at index: ", i,
465                           " Expected: ", first_dim,
466                           " Actual: ", t.shape().dim_size(0)));
467           if (check_dim) {
468             if (inferred_first_dim == -1) {
469               inferred_first_dim = t.shape().dim_size(0);
470             } else if (inferred_first_dim != t.shape().dim_size(0)) {
471               inferred_first_dim = -1;
472               check_dim = false;
473             }
474           }
475         }
476       }
477       first_dim = inferred_first_dim;
478     }
479     TensorShape output_shape;
480     OP_REQUIRES(c, element_shape_except_first_dim.AsTensorShape(&output_shape),
481                 errors::InvalidArgument(
482                     "Trying to concat list with only uninitialized tensors ",
483                     "but element_shape_except_first_dim is not fully defined: ",
484                     element_shape_except_first_dim.DebugString()));
485     // Build the lengths_tensor and leading dim of the output tensor by
486     // iterating over all element tensors.
487     Tensor* lengths_tensor = nullptr;
488     OP_REQUIRES_OK(c, c->allocate_output(1,
489                                          TensorShape({static_cast<int64_t>(
490                                              tensor_list->tensors().size())}),
491                                          &lengths_tensor));
492     auto lengths_tensor_vec = lengths_tensor->vec<int64_t>();
493     int64_t leading_dim = 0;
494     for (size_t i = 0; i < tensor_list->tensors().size(); i++) {
495       int64_t dim;
496       if (tensor_list->tensors()[i].dtype() != DT_INVALID) {
497         dim = tensor_list->tensors()[i].shape().dim_size(0);
498       } else {
499         // If leading_dims is not provided or does not contain an entry for
500         // index i, use the inferred `first_dim` if set.
501         if ((c->num_inputs() <= 2 || i >= c->input(2).NumElements()) &&
502             first_dim != -1) {
503           dim = first_dim;
504         } else {
505           OP_REQUIRES(c, c->num_inputs() > 2,
506                       errors::InvalidArgument(
507                           "Concating lists with uninitialized tensors is not ",
508                           "supported in this version of TensorListConcat. ",
509                           "Consider updating your GraphDef to run the newer ",
510                           "version."));
511           OP_REQUIRES(c, i < c->input(2).NumElements(),
512                       errors::InvalidArgument(
513                           "List contains uninitialized tensor at index ", i,
514                           " but leading_dims has only ",
515                           c->input(2).NumElements(), " elements."));
516           dim = c->input(2).vec<int64_t>()(i);
517         }
518       }
519       leading_dim += dim;
520       lengths_tensor_vec(i) = dim;
521     }
522     output_shape.InsertDim(0, leading_dim);
523     Tensor* output;
524     // Allocate the output tensor and fill it with the concatenated element
525     // tensors.
526     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
527     if (output->NumElements() == 0) {
528       return;
529     }
530 
531     ConstMatrixVector inputs_flat;
532     inputs_flat.reserve(tensor_list->tensors().size());
533     // Store the zeros tensors in a vector to keep them alive until the concat
534     // is complete.
535     std::vector<Tensor> zeros_vec;
536     for (int i = 0; i < tensor_list->tensors().size(); i++) {
537       const Tensor& element_tensor = tensor_list->tensors()[i];
538       if (element_tensor.dtype() != DT_INVALID) {
539         if (element_tensor.NumElements() > 0) {
540           inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
541               element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
542         }
543       } else {
544         AllocatorAttributes attr;
545         if (element_dtype_ == DT_VARIANT) {
546           attr.set_on_host(true);
547         }
548         TensorShape element_shape = output_shape;
549         element_shape.set_dim(0, lengths_tensor_vec(i));
550         zeros_vec.emplace_back();
551         Tensor& zeros = zeros_vec.back();
552         OP_REQUIRES_OK(
553             c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
554         SetZero<Device, T>(c, zeros);
555         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
556             const_cast<const Tensor&>(zeros).shaped<T, 2>(
557                 {1, zeros.NumElements()})));
558       }
559     }
560     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
561 
562 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
563     if (std::is_same<Device, Eigen::GpuDevice>::value) {
564       ConcatGPU<T>(c, inputs_flat, output, &output_flat);
565       return;
566     }
567 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
568     if (IsPluggableDevice(c)) {
569       ConcatPluggableDevice<T>(c, inputs_flat, &output_flat);
570     } else {
571       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
572     }
573   }
574 
575  private:
576   DataType element_dtype_;
577   PartialTensorShape element_shape_;
578 };
579 
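// Inverse of TensorListConcat: splits the input tensor along its first
// dimension into list elements whose leading dims are given by the `lengths`
// input. For example, a [6, 3] tensor with lengths {2, 4} becomes a list
// holding a [2, 3] and a [4, 3] tensor.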
580 template <typename Device, typename T>
581 class TensorListSplit : public OpKernel {
582  public:
583   TensorListSplit(OpKernelConstruction* c) : OpKernel(c) {}
584 
585   void Compute(OpKernelContext* c) override {
586     Tensor* output_tensor;
587     AllocatorAttributes attr;
588     attr.set_on_host(true);
589     OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
590     PartialTensorShape element_shape;
591     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
592     OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
593                 errors::InvalidArgument(
594                     "TensorListSplit requires element_shape to be at least of ",
595                     "rank 1, but saw: ", element_shape.DebugString()));
596     TensorList output_list;
597     const Tensor& input_tensor = c->input(0);
598     output_list.element_dtype = input_tensor.dtype();
599     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
600                 errors::InvalidArgument(
601                     "Tensor must be at least a vector, but saw shape: ",
602                     input_tensor.shape().DebugString()));
603     TensorShape tensor_shape_without_first_dim(input_tensor.shape());
604     tensor_shape_without_first_dim.RemoveDim(0);
605     PartialTensorShape element_shape_without_first_dim;
606     if (!element_shape.unknown_rank()) {
607       element_shape_without_first_dim =
608           PartialTensorShape(element_shape.dim_sizes());
609       element_shape_without_first_dim.RemoveDim(0);
610     }
611     OP_REQUIRES(c,
612                 element_shape_without_first_dim.IsCompatibleWith(
613                     tensor_shape_without_first_dim),
614                 errors::InvalidArgument(
615                     "tensor shape ", input_tensor.shape().DebugString(),
616                     " is not compatible with element_shape ",
617                     element_shape.DebugString()));
618     output_list.element_shape = element_shape;
619     const Tensor& lengths = c->input(2);
620     OP_REQUIRES(c, TensorShapeUtils::IsVector(lengths.shape()),
621                 errors::InvalidArgument(
622                     "Expected lengths to be a vector, received shape: ",
623                     lengths.shape().DebugString()));
624     output_list.tensors().reserve(lengths.shape().dim_size(0));
625 
626     const auto copy_tensor = IsPluggableDevice(c)
627                                  ? &CopyTensorPluggableDevice<T>
628                                  : &CopyTensor<Device, T>;
629 
630     int64_t start = 0;
631     int64_t end = 0;
632     for (int i = 0; i < lengths.shape().dim_size(0); ++i) {
633       int64_t length = lengths.vec<int64_t>()(i);
634       OP_REQUIRES(
635           c, length >= 0,
636           errors::InvalidArgument("Invalid value in lengths: ", length));
637       end = start + length;
638       OP_REQUIRES(c, end <= input_tensor.shape().dim_size(0),
639                   errors::InvalidArgument("Attempting to slice [", start, ", ",
640                                           end, "] from tensor with length ",
641                                           input_tensor.shape().dim_size(0)));
642       Tensor tmp = input_tensor.Slice(start, end);
643       start = end;
644       // TODO(apassos) maybe not always align; but weird compiler bugs seem to
645       // prevent this.
646       Tensor aligned;
647       OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
648       copy_tensor(c, tmp, aligned);
649       output_list.tensors().emplace_back(aligned);
650     }
651     OP_REQUIRES(c, end == input_tensor.shape().dim_size(0),
652                 errors::InvalidArgument(
653                     "Unused values in tensor. Length of tensor: ",
654                     input_tensor.shape().dim_size(0), " Values used: ", end));
655     output_tensor->scalar<Variant>()() = std::move(output_list);
656   }
657 };
658 
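// Stacks the list elements selected by the `indices` input into one tensor
// with a new leading dimension; uninitialized elements are replaced by zeros.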
659 template <typename Device, typename T>
660 class TensorListGather : public OpKernel {
661  public:
662   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
663       ConstMatrixVector;
664   explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
665     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
666   }
667 
668   void Compute(OpKernelContext* c) override {
669     const TensorList* tensor_list = nullptr;
670     OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
671     OP_REQUIRES(
672         c, element_dtype_ == tensor_list->element_dtype,
673         errors::InvalidArgument(
674             "Invalid data types; op elements ", DataTypeString(element_dtype_),
675             " but list elements ", DataTypeString(tensor_list->element_dtype)));
676     const Tensor& indices = c->input(1);
677     PartialTensorShape partial_element_shape;
678     OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 2,
679                                                &partial_element_shape));
680     OP_REQUIRES(
681         c, partial_element_shape.IsFullyDefined() || indices.NumElements() > 0,
682         errors::InvalidArgument("Tried to gather 0-elements from "
683                                 "a list with non-fully-defined shape: ",
684                                 partial_element_shape.DebugString()));
685 
686     // Check that `element_shape` input tensor is compatible with the shapes of
687     // element tensors.
688     if (!tensor_list->element_shape.IsFullyDefined()) {
689       for (int index = 0; index < indices.NumElements(); ++index) {
690         const int i = indices.flat<int32>()(index);
691         const Tensor& t = tensor_list->tensors()[i];
692         if (t.dtype() != DT_INVALID) {
693           PartialTensorShape tmp = partial_element_shape;
694           OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
695         }
696       }
697     }
698 
699     // Compute the shape of the output tensor by pre-pending the leading dim to
700     // the element_shape.
701     TensorShape element_shape;
702     OP_REQUIRES(
703         c, partial_element_shape.AsTensorShape(&element_shape),
704         errors::InvalidArgument("Tried to gather uninitialized tensors from a ",
705                                 "list with non-fully-defined element_shape: ",
706                                 partial_element_shape.DebugString()));
707     TensorShape output_shape = element_shape;
708     output_shape.InsertDim(0, indices.NumElements());
709     Tensor* output;
710     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
711     if (output->NumElements() == 0) {
712       return;
713     }
714 
715     ConstMatrixVector inputs_flat;
716     inputs_flat.reserve(indices.NumElements());
717     Tensor zeros;
718     for (int index = 0; index < indices.NumElements(); ++index) {
719       const int i = indices.flat<int32>()(index);
720       OP_REQUIRES(
721           c, i < tensor_list->tensors().size(),
722           errors::InvalidArgument("Index ", i, " out of range; list only has ",
723                                   tensor_list->tensors().size(), " elements."));
724       const Tensor& t = tensor_list->tensors()[i];
725       if (t.dtype() != DT_INVALID) {
726         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
727             t.shaped<T, 2>({1, t.NumElements()})));
728       } else {
729         if (!zeros.NumElements()) {
730           AllocatorAttributes attr;
731           if (element_dtype_ == DT_VARIANT) {
732             attr.set_on_host(true);
733           }
734           OP_REQUIRES_OK(
735               c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
736           SetZero<Device, T>(c, zeros);
737         }
738         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
739             const_cast<const Tensor&>(zeros).shaped<T, 2>(
740                 {1, zeros.NumElements()})));
741       }
742     }
743     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
744 
745 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
746     if (std::is_same<Device, Eigen::GpuDevice>::value) {
747       ConcatGPU<T>(c, inputs_flat, output, &output_flat);
748       return;
749     }
750 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
751     if (IsPluggableDevice(c)) {
752       ConcatPluggableDevice<T>(c, inputs_flat, &output_flat);
753     } else {
754       ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
755     }
756   }
757 
758  private:
759   DataType element_dtype_;
760 };
761 
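// Builds a TensorList whose i-th element is the i-th slice of the input tensor
// along its first dimension, after checking it against the element_shape input.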
762 template <typename Device, typename T>
763 class TensorListFromTensor : public OpKernel {
764  public:
765   TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}
766 
767   void Compute(OpKernelContext* c) override {
768     Tensor* output_tensor;
769     AllocatorAttributes attr;
770     attr.set_on_host(true);
771     OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
772     PartialTensorShape element_shape;
773     OP_REQUIRES(
774         c, !TensorShapeUtils::IsMatrixOrHigher(c->input(1).shape()),
775         errors::InvalidArgument(
776             "TensorListFromTensor: element_shape must be at most rank 1 but ",
777             "has the shape of ", c->input(1).shape().DebugString()));
778     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
779     TensorList output_list;
780     const Tensor& t = c->input(0);
781     output_list.element_dtype = t.dtype();
782     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
783                 errors::InvalidArgument(
784                     "Tensor must be at least a vector, but saw shape: ",
785                     t.shape().DebugString()));
786     TensorShape output_shape(t.shape());
787     output_shape.RemoveDim(0);
788     OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
789                 errors::InvalidArgument(
790                     "Specified a list with shape ", element_shape.DebugString(),
791                     " from a tensor with shape ", output_shape.DebugString()));
792     output_list.element_shape = element_shape;
793     output_list.tensors().reserve(t.shape().dim_size(0));
794 
795     const auto copy_tensor = IsPluggableDevice(c)
796                                  ? &CopyTensorPluggableDevice<T>
797                                  : &CopyTensor<Device, T>;
798 
799     for (int i = 0; i < t.shape().dim_size(0); ++i) {
800       Tensor tmp = t.Slice(i, i + 1);
801       TensorShape tmp_shape = tmp.shape();
802       tmp_shape.RemoveDim(0);
803       OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
804                   errors::Unknown("Unexpected shape error."));
805       // TODO(apassos) maybe not always align; but weird compiler bugs seem to
806       // prevent this.
807       Tensor aligned;
808       OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
809       copy_tensor(c, tmp, aligned);
810       output_list.tensors().push_back(aligned);
811     }
812     output_tensor->scalar<Variant>()() = std::move(output_list);
813   }
814 };
815 
816 // Scatters values in `value` into `list`. Assumes that `indices` are valid.
817 template <typename Device, typename T>
818 Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices,
819                TensorList* list) {
820   const auto copy_tensor = IsPluggableDevice(c) ? &CopyTensorPluggableDevice<T>
821                                                 : &CopyTensor<Device, T>;
822   for (int index = 0; index < indices.NumElements(); ++index) {
823     const int i = indices.flat<int32>()(index);
824     Tensor tmp = value.Slice(index, index + 1);
825     TensorShape tmp_shape = tmp.shape();
826     tmp_shape.RemoveDim(0);
827     if (!tmp.CopyFrom(tmp, tmp_shape)) {
828       return errors::Unknown("Unexpected shape error.");
829     }
830     // TODO(apassos) maybe not always align; but weird compiler bugs seem to
831     // prevent this.
832     Tensor aligned;
833     TF_RETURN_IF_ERROR(c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
834     // TODO(apassos) do all slices in a single kernel invocation instead of
835     // many small ones.
836     copy_tensor(c, tmp, aligned);
837     std::swap(list->tensors()[i], aligned);
838   }
839   return OkStatus();
840 }
841 
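// Scatters the rows of the input tensor into an existing TensorList at the
// positions given by `indices`, growing the list when an index is past its end.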
842 template <typename Device, typename T>
843 class TensorListScatterIntoExistingList : public OpKernel {
844  public:
845   TensorListScatterIntoExistingList(OpKernelConstruction* c) : OpKernel(c) {}
846 
847   void Compute(OpKernelContext* c) override {
848     const TensorList* l = nullptr;
849     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
850     const Tensor& input_tensor = c->input(1);
851     const Tensor& indices = c->input(2);
852 
853     // Check that inputs are valid.
854     OP_REQUIRES(c, input_tensor.dtype() == l->element_dtype,
855                 errors::InvalidArgument(
856                     "Invalid data types; input tensor type: ",
857                     DataTypeString(input_tensor.dtype()),
858                     " list element_type: ", DataTypeString(l->element_dtype)));
859     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
860                 errors::InvalidArgument(
861                     "Tensor must be at least a vector, but saw shape: ",
862                     input_tensor.shape().DebugString()));
863     OP_REQUIRES(c, TensorShapeUtils::IsVector(indices.shape()),
864                 errors::InvalidArgument(
865                     "Expected indices to be a vector, but received shape: ",
866                     indices.shape().DebugString()));
867     OP_REQUIRES(
868         c, indices.NumElements() == input_tensor.shape().dim_size(0),
869         errors::InvalidArgument(
870             "Expected len(indices) == tensor.shape[0], but saw: ",
871             indices.NumElements(), " vs. ", input_tensor.shape().dim_size(0)));
872 
873     // Resize the list if needed to accommodate all indices.
874     TensorList* output_list = nullptr;
875     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
876     const auto indices_vec = indices.vec<int32>();
877     int32_t max_index =
878         (indices.NumElements() == 0)
879             ? -1
880             : *std::max_element(indices_vec.data(),
881                                 indices_vec.data() + indices.NumElements());
882     if (max_index + 1 > output_list->tensors().size()) {
883       output_list->tensors().resize(max_index + 1);
884     }
885 
886     // Scatter the values.
887     OP_REQUIRES_OK(c,
888                    Scatter<Device, T>(c, input_tensor, indices, output_list));
889   }
890 };
891 
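// Builds a new TensorList by scattering the rows of the input tensor to the
// positions given by `indices`; positions not covered by `indices` are left
// uninitialized.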
892 template <typename Device, typename T>
893 class TensorListScatter : public OpKernel {
894  public:
895   TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}
896 
897   void Compute(OpKernelContext* c) override {
898     Tensor* output_tensor;
899     AllocatorAttributes attr;
900     attr.set_on_host(true);
901     OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
902     Tensor indices = c->input(1);
903     PartialTensorShape element_shape;
904     OP_REQUIRES(
905         c, !TensorShapeUtils::IsMatrixOrHigher(c->input(2).shape()),
906         errors::InvalidArgument(
907             "TensorListScatter: element_shape must be at most rank 1 but has ",
908             "the shape of ", c->input(2).shape().DebugString()));
909     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
910     // TensorListScatterV2 passes the num_elements input, TensorListScatter does
911     // not.
912     int num_elements = c->num_inputs() >= 4 ? c->input(3).scalar<int>()() : -1;
913     OP_REQUIRES(c, num_elements >= -1,
914                 errors::InvalidArgument(
915                     "TensorListScatter expects num_elements >= -1, found: ",
916                     num_elements));
917     TensorList output_list;
918     const Tensor& input_tensor = c->input(0);
919     output_list.element_dtype = input_tensor.dtype();
920     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
921                 errors::InvalidArgument(
922                     "Tensor must be at least a vector, but saw shape: ",
923                     input_tensor.shape().DebugString()));
924     TensorShape output_shape(input_tensor.shape());
925     output_shape.RemoveDim(0);
926     OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
927                 errors::InvalidArgument(
928                     "Specified a list with shape ", element_shape.DebugString(),
929                     " from a tensor with shape ", output_shape.DebugString()));
930     output_list.element_shape = element_shape;
931 
932     OP_REQUIRES(c, indices.NumElements() == input_tensor.shape().dim_size(0),
933                 errors::InvalidArgument(
934                     "Invalid number of rows in input tensor. Expected: ",
935                     indices.NumElements(),
936                     " Actual: ", input_tensor.shape().dim_size(0)));
937 
938     // Validate indices and resize output_list.tensors to fit the highest index.
939     {
940       int highest_index = -1;
941       for (int index = 0; index < indices.NumElements(); ++index) {
942         const int i = indices.flat<int32>()(index);
943         OP_REQUIRES(
944             c, i >= 0,
945             errors::InvalidArgument(
946                 "Indices in TensorListScatter must all be non-negative."));
947         OP_REQUIRES(c, num_elements == -1 || i < num_elements,
948                     errors::InvalidArgument(
949                         "TensorListScatter: Trying to scatter at index ", i,
950                         " in list with size ", num_elements));
951         if (i > highest_index) {
952           highest_index = i;
953         }
954       }
955       output_list.tensors().resize(std::max(highest_index + 1, num_elements),
956                                    Tensor(DT_INVALID));
957     }
958 
959     OP_REQUIRES_OK(c,
960                    Scatter<Device, T>(c, input_tensor, indices, &output_list));
961     output_tensor->scalar<Variant>()() = std::move(output_list);
962   }
963 };
964 
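// Device-parameterized wrappers that bind the per-element add and zeros-like
// functors to the generic TensorList helpers.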
965 template <typename Device>
966 Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
967                            const TensorList& b, TensorList* out) {
968   return TensorListBinaryAdd(c, a, b, out, BinaryAddTensors<Device>);
969 }
970 
971 template <typename Device>
972 Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
973                            TensorList* y) {
974   return TensorListZerosLike(c, x, y, ZerosLikeTensor<Device>);
975 }
976 
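// Appends row b of the input tensor to the b-th TensorList in a batch of list
// handles. The lists are mutated in place when the input can be forwarded and
// every list's reference count is one; otherwise they are copied first.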
977 template <typename Device, typename T>
978 class TensorListPushBackBatch : public OpKernel {
979  public:
980   explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
981     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
982   }
983 
984   void Compute(OpKernelContext* c) override {
985     const Tensor& input = c->input(1);
986     OP_REQUIRES(c, element_dtype_ == input.dtype(),
987                 errors::InvalidArgument("Invalid data types; list elements ",
988                                         DataTypeString(element_dtype_),
989                                         " but tried to append ",
990                                         DataTypeString(input.dtype())));
991     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
992                 errors::InvalidArgument(
993                     "Expected tensor to be at least a vector, but saw shape: ",
994                     input.shape().DebugString()));
995 
996     const TensorShape& tls_shape = c->input(0).shape();
997 
998     // For purposes of input forwarding, we want the least restrictive
999     // AllocatorAttributes possible.  If we need to allocate later,
1000     // we'll request the DT_VARIANT be allocated on host.
1001     AllocatorAttributes attr;
1002 
1003     std::unique_ptr<Tensor> tls_alias = c->forward_input(
1004         0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
1005         DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
1006 
1007     bool ok_to_alias = tls_alias != nullptr;
1008     if (tls_alias && tls_alias->dtype() == DT_VARIANT &&
1009         tls_alias->NumElements() > 0) {
1010       auto alias_t = tls_alias->flat<Variant>();
1011       for (int i = 0; i < tls_alias->NumElements(); ++i) {
1012         TensorList* tl_i = alias_t(i).get<TensorList>();
1013         if (tl_i == nullptr || !tl_i->RefCountIsOne()) {
1014           ok_to_alias = false;
1015           break;
1016         }
1017       }
1018     }
1019     const Tensor& tls = ok_to_alias ? *tls_alias : c->input(0);
1020 
1021     OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
1022                 errors::InvalidArgument(
1023                     "Expected input_handles dtype to be Variant, but saw: ",
1024                     DataTypeString(tls.dtype())));
1025     OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
1026                 errors::InvalidArgument(
1027                     "Expected input_handles to be a vector, but saw shape: ",
1028                     tls_shape.DebugString()));
1029     const int64_t batch_size = tls.NumElements();
1030     OP_REQUIRES(c, input.dim_size(0) == batch_size,
1031                 errors::InvalidArgument(
1032                     "Expected tensor.shape[0] == input_handles.size, but saw ",
1033                     input.dim_size(0), " vs. ", batch_size));
1034     auto tls_t = tls.vec<Variant>();
1035 
1036     TensorShape input_element_shape = input.shape();
1037     input_element_shape.RemoveDim(0);
1038     std::vector<const TensorList*> tl_batch;
1039     for (int64_t b = 0; b < batch_size; ++b) {
1040       const TensorList* l = tls_t(b).get<TensorList>();
1041       OP_REQUIRES(c, l != nullptr,
1042                   errors::InvalidArgument("Input handle at index ", b,
1043                                           " is not a list. Saw: '",
1044                                           tls_t(b).DebugString(), "'"));
1045       OP_REQUIRES(
1046           c, l->element_shape.IsCompatibleWith(input_element_shape),
1047           errors::InvalidArgument(
1048               "Tried to append a tensor with incompatible shape to a "
1049               "list at index ",
1050               b, ". Op element shape: ", input_element_shape.DebugString(),
1051               " list shape: ", l->element_shape.DebugString()));
1052       OP_REQUIRES(c, element_dtype_ == l->element_dtype,
1053                   errors::InvalidArgument(
1054                       "Invalid data type at index ", b, "; op elements ",
1055                       DataTypeString(element_dtype_), " but list elements ",
1056                       DataTypeString(l->element_dtype)));
1057       tl_batch.push_back(l);
1058     }
1059 
1060     Tensor* result;
1061 
1062     if (ok_to_alias) {
1063       result = tls_alias.get();
1064       c->set_output(0, *result);
1065     } else {
1066       // DT_VARIANT tensors always allocated on host.
1067       AllocatorAttributes attr;
1068       attr.set_on_host(true);
1069       OP_REQUIRES_OK(
1070           c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
1071     }
1072 
1073     if (batch_size == 0) {
1074       return;
1075     }
1076 
1077     auto input_t = input.flat_outer_dims<T, 2>();
1078     auto result_t = result->vec<Variant>();
1079 
1080     for (int64_t b = 0; b < batch_size; ++b) {
1081       if (!ok_to_alias) {
1082         result_t(b) = tl_batch[b]->Copy();
1083       }
1084       TensorList* output = result_t(b).get<TensorList>();
1085       DCHECK(output != nullptr);
1086       Tensor frame;
1087       OP_REQUIRES_OK(
1088           c, c->allocate_temp(element_dtype_, input_element_shape, &frame));
1089       if (input_element_shape.num_elements() > 0) {
1090         auto frame_t = frame.flat<T>();
1091         // TODO(penporn): Get this if out of the batch loop.
1092         if (IsPluggableDevice(c)) {
1093           // The chip method needs an Eigen device, so use Tensor::Slice instead
1094           // of chip for pluggable devices. The input is reshaped to 2-D so it can
1095           // be sliced along the batch dim.
1096           auto input_t_shape =
1097               TensorShape({input_t.dimension(0), input_t.dimension(1)});
1098           auto input_reshaped = Tensor();
1099           OP_REQUIRES(c, input_reshaped.CopyFrom(input, input_t_shape),
1100                       errors::Unknown("Unexpected shape error."));
1101 
1102           auto input_batch = input_reshaped.Slice(b, b + 1);
1103           CopyTensorPluggableDevice<T>(c, input_batch, frame);
1104         } else {
1105           frame_t.device(c->eigen_device<Device>()) =
1106               input_t.template chip<0>(b);
1107         }
1108       }
1109       output->tensors().push_back(std::move(frame));
1110     }
1111   }
1112 
1113  private:
1114   DataType element_dtype_;
1115 };
1116 
1117 }  // namespace tensorflow
1118 
1119 #undef PLUGGABLE_DEVICE_SUPPORTED
1120 #endif  // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
1121