• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
19 #define EIGEN_USE_GPU
20 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
21 
22 #include "tensorflow/core/kernels/list_kernels.h"
23 
24 #include <limits>
25 
26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
27 #include "tensorflow/core/framework/allocator.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/register_types.h"
30 #include "tensorflow/core/framework/tensor_types.h"
31 #include "tensorflow/core/framework/variant.h"
32 #include "tensorflow/core/framework/variant_op_registry.h"
33 #include "tensorflow/core/kernels/concat_lib.h"
34 #include "tensorflow/core/lib/core/coding.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/util/util.h"
37 
38 namespace tensorflow {
39 
40 typedef Eigen::ThreadPoolDevice CPUDevice;
41 
TensorShapeFromTensor(const Tensor & t,PartialTensorShape * out)42 Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
43   if (t.shape() == TensorShape({})) {
44     if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
45         (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1)) {
46       *out = PartialTensorShape();
47       return Status::OK();
48     }
49     return errors::InvalidArgument(
50         "The only valid scalar shape tensor is the fully unknown shape "
51         "specified as -1.");
52   }
53   if (t.dtype() == DT_INT32) {
54     return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
55                                                 t.NumElements(), out);
56   } else if (t.dtype() == DT_INT64) {
57     return PartialTensorShape::MakePartialShape(t.vec<int64>().data(),
58                                                 t.NumElements(), out);
59   }
60   return errors::InvalidArgument(
61       "Expected an int32 or int64 shape tensor; found ",
62       DataTypeString(t.dtype()));
63 }
64 
GetElementShapeFromInput(OpKernelContext * c,const TensorList & tensor_list,int index,PartialTensorShape * element_shape)65 Status GetElementShapeFromInput(OpKernelContext* c,
66                                 const TensorList& tensor_list, int index,
67                                 PartialTensorShape* element_shape) {
68   TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(index), element_shape));
69   // Check that `element_shape` and `tensor_list.element_shape` are
70   // compatible and store the merged shape in `element_shape`.
71   PartialTensorShape tmp = *element_shape;
72   TF_RETURN_IF_ERROR(tmp.MergeWith(tensor_list.element_shape, element_shape));
73   return Status::OK();
74 }
75 
GetInputList(OpKernelContext * c,int index,const TensorList ** list)76 Status GetInputList(OpKernelContext* c, int index, const TensorList** list) {
77   if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
78     return errors::InvalidArgument("Input list must be a scalar saw: ",
79                                    c->input(index).shape().DebugString());
80   }
81   const TensorList* l = c->input(index).scalar<Variant>()().get<TensorList>();
82   if (l == nullptr) {
83     return errors::InvalidArgument(
84         "Input handle is not a list. Saw: '",
85         c->input(index).scalar<Variant>()().DebugString(), "'");
86   }
87   *list = l;
88   return Status::OK();
89 }
90 
ForwardInputOrCreateNewList(OpKernelContext * c,int32 input_index,int32 output_index,const TensorList & input_list,TensorList ** output_list)91 Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
92                                    int32 output_index,
93                                    const TensorList& input_list,
94                                    TensorList** output_list) {
95   // Attempt to forward the input tensor to the output if possible.
96   std::unique_ptr<Tensor> maybe_output = c->forward_input(
97       input_index, output_index, DT_VARIANT, TensorShape{},
98       c->input_memory_type(input_index), AllocatorAttributes());
99   Tensor* output_tensor;
100   if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
101       maybe_output->NumElements() == 1) {
102     output_tensor = maybe_output.get();
103     TensorList* tmp_out = output_tensor->scalar<Variant>()().get<TensorList>();
104     if (tmp_out == nullptr) {
105       return errors::InvalidArgument(
106           "Expected input ", input_index, " to be a TensorList but saw ",
107           output_tensor->scalar<Variant>()().TypeName());
108     }
109     if (tmp_out->RefCountIsOne()) {
110       // Woohoo, forwarding succeeded!
111       c->set_output(output_index, *output_tensor);
112       *output_list = tmp_out;
113       return Status::OK();
114     }
115   }
116 
117   // If forwarding is not possible allocate a new output tensor and copy
118   // the `input_list` to it.
119   AllocatorAttributes attr;
120   attr.set_on_host(true);
121   TF_RETURN_IF_ERROR(
122       c->allocate_output(output_index, {}, &output_tensor, attr));
123   output_tensor->scalar<Variant>()() = input_list.Copy();
124 
125   *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
126   return Status::OK();
127 }
128 
129 class EmptyTensorList : public OpKernel {
130  public:
EmptyTensorList(OpKernelConstruction * ctx)131   explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) {
132     OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_));
133   }
134 
Compute(OpKernelContext * ctx)135   void Compute(OpKernelContext* ctx) override {
136     const Tensor& max_num_elements_t = ctx->input(1);
137     OP_REQUIRES(
138         ctx, TensorShapeUtils::IsScalar(max_num_elements_t.shape()),
139         errors::InvalidArgument(
140             "max_num_elements expected to be a scalar ",
141             "but got shape: ", max_num_elements_t.shape().DebugString()));
142     Tensor* result;
143     AllocatorAttributes attr;
144     attr.set_on_host(true);
145     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
146     TensorList empty;
147     empty.element_dtype = element_dtype_;
148     empty.max_num_elements = max_num_elements_t.scalar<int32>()();
149     PartialTensorShape element_shape;
150     OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape));
151     empty.element_shape = element_shape;
152     result->scalar<Variant>()() = std::move(empty);
153   }
154 
155  private:
156   DataType element_dtype_;
157 };
158 
159 REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
160                         EmptyTensorList);
161 
162 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
163 
164 REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
165                             .Device(DEVICE_GPU)
166                             .HostMemory("element_shape")
167                             .HostMemory("max_num_elements"),
168                         EmptyTensorList);
169 
170 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
171 
172 class TensorListPushBack : public OpKernel {
173  public:
TensorListPushBack(OpKernelConstruction * c)174   explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) {
175     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
176   }
177 
~TensorListPushBack()178   ~TensorListPushBack() override {}
179 
Compute(OpKernelContext * c)180   void Compute(OpKernelContext* c) override {
181     const Tensor& input = c->input(1);
182     OP_REQUIRES(c, element_dtype_ == input.dtype(),
183                 errors::InvalidArgument("Invalid data types; list elements ",
184                                         DataTypeString(element_dtype_),
185                                         " but tried to append ",
186                                         DataTypeString(input.dtype())));
187 
188     const TensorList* l = nullptr;
189     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
190     OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()),
191                 errors::InvalidArgument(
192                     "Tried to append a tensor with incompatible shape to a "
193                     "list. Op element shape: ",
194                     input.shape().DebugString(),
195                     " list shape: ", l->element_shape.DebugString()));
196     OP_REQUIRES(c, element_dtype_ == l->element_dtype,
197                 errors::InvalidArgument("Invalid data types; op elements ",
198                                         DataTypeString(element_dtype_),
199                                         " but list elements ",
200                                         DataTypeString(l->element_dtype)));
201 
202     if (l->max_num_elements != -1) {
203       OP_REQUIRES(
204           c, l->tensors().size() < l->max_num_elements,
205           errors::InvalidArgument("Tried to push item into a full list",
206                                   " list size: ", l->tensors().size(),
207                                   " max_num_elements: ", l->max_num_elements));
208     }
209 
210     TensorList* output_list = nullptr;
211     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
212     output_list->tensors().push_back(input);
213   }
214 
215  private:
216   DataType element_dtype_;
217 };
218 
219 REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU),
220                         TensorListPushBack);
221 
222 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
223 
224 REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU),
225                         TensorListPushBack);
226 
227 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
228 
229 class TensorListLength : public OpKernel {
230  public:
TensorListLength(OpKernelConstruction * c)231   explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {}
~TensorListLength()232   ~TensorListLength() override {}
233 
Compute(OpKernelContext * c)234   void Compute(OpKernelContext* c) override {
235     const TensorList* l = nullptr;
236     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
237     Tensor* result;
238     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
239     result->scalar<int32>()() = l->tensors().size();
240   }
241 };
242 
243 REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU),
244                         TensorListLength);
245 
246 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
247 
248 REGISTER_KERNEL_BUILDER(
249     Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"),
250     TensorListLength);
251 
252 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
253 
254 class TensorListElementShape : public OpKernel {
255  public:
TensorListElementShape(OpKernelConstruction * c)256   explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}
257 
Compute(OpKernelContext * c)258   void Compute(OpKernelContext* c) override {
259     const TensorList* l = nullptr;
260     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
261     Tensor* result;
262     if (l->element_shape.unknown_rank()) {
263       OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &result));
264       if (result->dtype() == DT_INT32) {
265         result->scalar<int32>()() = -1;
266       } else {
267         result->scalar<int64>()() = -1;
268       }
269     } else {
270       OP_REQUIRES_OK(c, c->allocate_output(
271                             0, TensorShape{l->element_shape.dims()}, &result));
272       for (int i = 0; i < l->element_shape.dims(); ++i) {
273         if (result->dtype() == DT_INT32) {
274           result->flat<int32>()(i) = l->element_shape.dim_size(i);
275         } else {
276           result->flat<int64>()(i) = l->element_shape.dim_size(i);
277         }
278       }
279     }
280   }
281 };
282 
283 REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
284                         TensorListElementShape);
285 
286 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
287 
288 REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
289                             .Device(DEVICE_GPU)
290                             .HostMemory("element_shape"),
291                         TensorListElementShape);
292 
293 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
294 
295 class TensorListReserve : public OpKernel {
296  public:
TensorListReserve(OpKernelConstruction * c)297   explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
298     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
299   }
300 
Compute(OpKernelContext * c)301   void Compute(OpKernelContext* c) override {
302     PartialTensorShape element_shape;
303     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
304     int32 num_elements = c->input(1).scalar<int32>()();
305     TensorList output;
306     output.element_shape = element_shape;
307     output.element_dtype = element_dtype_;
308     output.tensors().resize(num_elements, Tensor(DT_INVALID));
309     Tensor* result;
310     AllocatorAttributes attr;
311     attr.set_on_host(true);
312     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
313     result->scalar<Variant>()() = std::move(output);
314   }
315 
316  private:
317   DataType element_dtype_;
318 };
319 
320 REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
321                         TensorListReserve);
322 
323 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
324 
325 REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
326                             .Device(DEVICE_GPU)
327                             .HostMemory("element_shape")
328                             .HostMemory("num_elements"),
329                         TensorListReserve);
330 
331 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
332 class TensorListResize : public OpKernel {
333  public:
TensorListResize(OpKernelConstruction * c)334   explicit TensorListResize(OpKernelConstruction* c) : OpKernel(c) {}
335 
Compute(OpKernelContext * c)336   void Compute(OpKernelContext* c) override {
337     const TensorList* input_list = nullptr;
338     OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
339     int32 size = c->input(1).scalar<int32>()();
340     OP_REQUIRES(
341         c, size >= 0,
342         errors::InvalidArgument(
343             "TensorListSlice expects size to be non-negative. Got: ", size));
344 
345     std::unique_ptr<Tensor> maybe_result =
346         c->forward_input(0, 0, DT_VARIANT, TensorShape{},
347                          c->input_memory_type(0), AllocatorAttributes());
348     if (maybe_result != nullptr) {
349       TensorList* out = maybe_result->scalar<Variant>()().get<TensorList>();
350       if (out->RefCountIsOne()) {
351         // We are able to forward the input.
352         out->tensors().resize(size, Tensor(DT_INVALID));
353         c->set_output(0, *maybe_result);
354         return;
355       }
356     }
357 
358     // We were not able to forward the input.  Will have to resize from scratch.
359     Tensor* result;
360     AllocatorAttributes attr;
361     attr.set_on_host(true);
362     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
363     TensorList output_list;
364     output_list.element_shape = input_list->element_shape;
365     output_list.element_dtype = input_list->element_dtype;
366     output_list.max_num_elements = input_list->max_num_elements;
367     if (size > input_list->tensors().size()) {
368       output_list.tensors().insert(output_list.tensors().begin(),
369                                    input_list->tensors().begin(),
370                                    input_list->tensors().end());
371       // Add DT_INVALID tensors to the end of the list if the requested size
372       // is larger than the list length.
373       output_list.tensors().resize(size, Tensor(DT_INVALID));
374     } else {
375       output_list.tensors().insert(output_list.tensors().begin(),
376                                    input_list->tensors().begin(),
377                                    input_list->tensors().begin() + size);
378     }
379     result->scalar<Variant>()() = std::move(output_list);
380   }
381 };
382 
383 REGISTER_KERNEL_BUILDER(Name("TensorListResize").Device(DEVICE_CPU),
384                         TensorListResize);
385 
386 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
387 
388 REGISTER_KERNEL_BUILDER(
389     Name("TensorListResize").Device(DEVICE_GPU).HostMemory("size"),
390     TensorListResize);
391 
392 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
393 
394 class TensorListSetItem : public OpKernel {
395  public:
TensorListSetItem(OpKernelConstruction * c)396   explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
397     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
398   }
399 
Compute(OpKernelContext * c)400   void Compute(OpKernelContext* c) override {
401     const TensorList* l = nullptr;
402     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
403     OP_REQUIRES(c, element_dtype_ == l->element_dtype,
404                 errors::InvalidArgument("Invalid data types; op elements ",
405                                         DataTypeString(element_dtype_),
406                                         " but list elements ",
407                                         DataTypeString(l->element_dtype)));
408     int32 index = c->input(1).scalar<int32>()();
409     OP_REQUIRES(c, index < l->tensors().size(),
410                 errors::InvalidArgument("Trying to modify element ", index,
411                                         " in a list with ", l->tensors().size(),
412                                         " elements."));
413     const Tensor& value = c->input(2);
414     OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
415                 errors::InvalidArgument(
416                     "Tried to set a tensor with incompatible shape at a "
417                     "list index. Item element shape: ",
418                     value.shape().DebugString(),
419                     " list shape: ", l->element_shape.DebugString()));
420     TensorList* output_list = nullptr;
421     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
422     output_list->tensors()[index] = value;
423   }
424 
425  private:
426   DataType element_dtype_;
427 };
428 
429 REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
430                         TensorListSetItem);
431 
432 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
433 
434 #define REGISTER_TENSOR_LIST_SET_ITEM_GPU(T)                      \
435   REGISTER_KERNEL_BUILDER(Name("TensorListSetItem")               \
436                               .TypeConstraint<T>("element_dtype") \
437                               .Device(DEVICE_GPU)                 \
438                               .HostMemory("index"),               \
439                           TensorListSetItem);
440 
441 TF_CALL_GPU_ALL_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
442 TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
443 TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
444 REGISTER_TENSOR_LIST_SET_ITEM_GPU(bfloat16)
445 #undef REGISTER_TENSOR_LIST_SET_ITEM_GPU
446 
447 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
448 
449 class TensorListConcatLists : public OpKernel {
450  public:
TensorListConcatLists(OpKernelConstruction * c)451   explicit TensorListConcatLists(OpKernelConstruction* c) : OpKernel(c) {
452     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
453   }
454 
Compute(OpKernelContext * c)455   void Compute(OpKernelContext* c) override {
456     const TensorShape& tl_a_shape = c->input(0).shape();
457     const TensorShape& tl_b_shape = c->input(1).shape();
458     OP_REQUIRES(
459         c, tl_a_shape == tl_b_shape,
460         errors::InvalidArgument("Incompatible input TensorList tensor shapes: ",
461                                 tl_a_shape.DebugString(), " vs. ",
462                                 tl_b_shape.DebugString()));
463     AllocatorAttributes attr;
464     std::unique_ptr<Tensor> tl_alias = c->forward_input(
465         0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tl_a_shape,
466         DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
467 
468     // tl_a may be aliased by tl_alias.
469     const Tensor& tl_a = c->input(0);
470     const Tensor& tl_b = c->input(1);
471 
472     Tensor* output = nullptr;
473     bool ok_to_alias = tl_alias != nullptr;
474     if (tl_alias && tl_alias->dtype() == DT_VARIANT &&
475         tl_alias->NumElements() > 0) {
476       auto tl_a_t = tl_alias->flat<Variant>();
477       for (int64 i = 0; i < tl_alias->NumElements(); ++i) {
478         TensorList* aliased = tl_a_t(i).get<TensorList>();
479         if (aliased == nullptr || !aliased->RefCountIsOne()) {
480           ok_to_alias = false;
481           break;
482         }
483       }
484       if (ok_to_alias) {
485         c->set_output(0, *tl_alias);
486         output = tl_alias.get();
487       }
488     }
489     if (!ok_to_alias) {
490       // Couldn't alias the entire Tensor.  We'll be conservative and not try
491       // to alias individual batch entries.
492       attr.set_on_host(true);
493       OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
494     }
495 
496     auto output_t = output->flat<Variant>();
497     auto tl_a_t = tl_a.flat<Variant>();
498     auto tl_b_t = tl_b.flat<Variant>();
499 
500     for (int64 i = 0; i < tl_a.NumElements(); ++i) {
501       const TensorList* l_a = tl_a_t(i).get<TensorList>();
502       const TensorList* l_b = tl_b_t(i).get<TensorList>();
503       OP_REQUIRES(
504           c, l_a != nullptr,
505           errors::InvalidArgument("input_a is not a TensorList at index ", i,
506                                   ".  Saw: '", tl_a_t(i).DebugString(), "'"));
507       OP_REQUIRES(
508           c, l_b != nullptr,
509           errors::InvalidArgument("input_b is not a TensorList at index ", i,
510                                   ".  Saw: '", tl_b_t(i).DebugString(), "'"));
511       OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
512                   errors::InvalidArgument(
513                       "input_a[", i, "].dtype != element_dtype.  Saw: ",
514                       DataTypeString(l_a->element_dtype), " vs. ",
515                       DataTypeString(element_dtype_)));
516       OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
517                   errors::InvalidArgument(
518                       "input_b[", i, "].dtype != element_dtype.  Saw: ",
519                       DataTypeString(l_b->element_dtype), " vs. ",
520                       DataTypeString(element_dtype_)));
521       OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
522                   errors::InvalidArgument(
523                       "input_a and input_b TensorList element shapes are not "
524                       "identical at index ",
525                       i, ".  Saw ", l_a->element_shape.DebugString(), " vs. ",
526                       l_b->element_shape.DebugString()));
527       if (ok_to_alias) {
528         TensorList* out = output_t(i).get<TensorList>();
529         std::copy(l_b->tensors().begin(), l_b->tensors().end(),
530                   std::back_inserter(out->tensors()));
531       } else {
532         TensorList out = l_a->Copy();
533         std::copy(l_b->tensors().begin(), l_b->tensors().end(),
534                   std::back_inserter(out.tensors()));
535         output_t(i) = std::move(out);
536       }
537     }
538   }
539 
540  private:
541   DataType element_dtype_;
542 };
543 
544 REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_CPU),
545                         TensorListConcatLists);
546 
547 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
548 
549 REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_GPU),
550                         TensorListConcatLists);
551 
552 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
553 
554 #define REGISTER_TENSOR_LIST_OPS_CPU(T)                                    \
555   REGISTER_KERNEL_BUILDER(Name("TensorListStack")                          \
556                               .TypeConstraint<T>("element_dtype")          \
557                               .Device(DEVICE_CPU),                         \
558                           TensorListStack<CPUDevice, T>)                   \
559   REGISTER_KERNEL_BUILDER(Name("TensorListGather")                         \
560                               .TypeConstraint<T>("element_dtype")          \
561                               .Device(DEVICE_CPU),                         \
562                           TensorListGather<CPUDevice, T>)                  \
563   REGISTER_KERNEL_BUILDER(Name("TensorListConcat")                         \
564                               .TypeConstraint<T>("element_dtype")          \
565                               .Device(DEVICE_CPU),                         \
566                           TensorListConcat<CPUDevice, T>)                  \
567   REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2")                       \
568                               .TypeConstraint<T>("element_dtype")          \
569                               .Device(DEVICE_CPU),                         \
570                           TensorListConcat<CPUDevice, T>)                  \
571   REGISTER_KERNEL_BUILDER(Name("TensorListGetItem")                        \
572                               .TypeConstraint<T>("element_dtype")          \
573                               .Device(DEVICE_CPU),                         \
574                           TensorListGetItem<CPUDevice, T>)                 \
575   REGISTER_KERNEL_BUILDER(Name("TensorListPopBack")                        \
576                               .TypeConstraint<T>("element_dtype")          \
577                               .Device(DEVICE_CPU),                         \
578                           TensorListPopBack<CPUDevice, T>)                 \
579   REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")                     \
580                               .TypeConstraint<T>("element_dtype")          \
581                               .Device(DEVICE_CPU),                         \
582                           TensorListFromTensor<CPUDevice, T>)              \
583   REGISTER_KERNEL_BUILDER(Name("TensorListScatter")                        \
584                               .TypeConstraint<T>("element_dtype")          \
585                               .Device(DEVICE_CPU),                         \
586                           TensorListScatter<CPUDevice, T>)                 \
587   REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2")                      \
588                               .TypeConstraint<T>("element_dtype")          \
589                               .Device(DEVICE_CPU),                         \
590                           TensorListScatter<CPUDevice, T>)                 \
591   REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList")        \
592                               .TypeConstraint<T>("element_dtype")          \
593                               .Device(DEVICE_CPU),                         \
594                           TensorListScatterIntoExistingList<CPUDevice, T>) \
595   REGISTER_KERNEL_BUILDER(Name("TensorListSplit")                          \
596                               .TypeConstraint<T>("element_dtype")          \
597                               .Device(DEVICE_CPU),                         \
598                           TensorListSplit<CPUDevice, T>)                   \
599   REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")                  \
600                               .TypeConstraint<T>("element_dtype")          \
601                               .Device(DEVICE_CPU),                         \
602                           TensorListPushBackBatch<CPUDevice, T>)
603 
604 TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_OPS_CPU);
605 REGISTER_TENSOR_LIST_OPS_CPU(quint8);
606 REGISTER_TENSOR_LIST_OPS_CPU(qint8);
607 REGISTER_TENSOR_LIST_OPS_CPU(quint16);
608 REGISTER_TENSOR_LIST_OPS_CPU(qint16);
609 REGISTER_TENSOR_LIST_OPS_CPU(qint32);
610 REGISTER_TENSOR_LIST_OPS_CPU(Variant);
611 
612 #undef REGISTER_TENSOR_LIST_OPS_CPU
613 
614 #define REGISTER_TENSOR_LIST_OPS_CPU(T)
615 
616 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
617                                           TensorList,
618                                           TensorListBinaryAdd<CPUDevice>);
619 
620 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
621                                          DEVICE_CPU, TensorList,
622                                          TensorListZerosLike<CPUDevice>);
623 
624 }  // namespace tensorflow
625