/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/list_kernels.h"

#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

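// Converts a shape tensor into a PartialTensorShape. A scalar -1 denotes a
// fully unknown shape; otherwise the tensor must be a rank-1 int32 or int64
// vector of dimension sizes.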
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
  if (t.shape() == TensorShape({})) {
    if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
        (t.dtype() == DT_INT64 && t.scalar<int64_t>()() == -1)) {
      *out = PartialTensorShape();
      return OkStatus();
    }
    return errors::InvalidArgument(
        "The only valid scalar shape tensor is the fully unknown shape "
        "specified as -1.");
  } else if (t.shape().dims() != 1) {
    return errors::InvalidArgument("Shape must be at most rank 1 but is rank ",
                                   t.shape().dims());
  }
  if (t.dtype() == DT_INT32) {
    return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
                                                t.NumElements(), out);
  } else if (t.dtype() == DT_INT64) {
    return PartialTensorShape::MakePartialShape(t.vec<int64_t>().data(),
                                                t.NumElements(), out);
  }
  return errors::InvalidArgument(
      "Expected an int32 or int64 shape tensor; found ",
      DataTypeString(t.dtype()));
}

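// Reads an element shape from input `index` and merges it with the element
// shape already stored in `tensor_list`, failing if they are incompatible.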
Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape) {
  TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(index), element_shape));
  // Check that `element_shape` and `tensor_list.element_shape` are
  // compatible and store the merged shape in `element_shape`.
  PartialTensorShape tmp = *element_shape;
  TF_RETURN_IF_ERROR(tmp.MergeWith(tensor_list.element_shape, element_shape));
  return OkStatus();
}

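// Extracts the TensorList held by the scalar variant tensor at input `index`.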
Status GetInputList(OpKernelContext* c, int index, const TensorList** list) {
  if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
    return errors::InvalidArgument("Input list must be a scalar. Saw: ",
                                   c->input(index).shape().DebugString());
  }
  const TensorList* l = c->input(index).scalar<Variant>()().get<TensorList>();
  if (l == nullptr) {
    return errors::InvalidArgument(
        "Input handle is not a list. Saw: '",
        c->input(index).scalar<Variant>()().DebugString(), "'");
  }
  *list = l;
  return OkStatus();
}

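// Produces the output list either by forwarding the input variant tensor
// (when it exclusively owns its TensorList) or by copying `input_list` into
// a newly allocated output tensor.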
Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index,
                                   int32_t output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list) {
  // Attempt to forward the input tensor to the output if possible.
  std::unique_ptr<Tensor> maybe_output = c->forward_input(
      input_index, output_index, DT_VARIANT, TensorShape{},
      c->input_memory_type(input_index), AllocatorAttributes());
  Tensor* output_tensor;
  if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
      maybe_output->NumElements() == 1) {
    output_tensor = maybe_output.get();
    TensorList* tmp_out = output_tensor->scalar<Variant>()().get<TensorList>();
    if (tmp_out == nullptr) {
      return errors::InvalidArgument(
          "Expected input ", input_index, " to be a TensorList but saw ",
          output_tensor->scalar<Variant>()().TypeName());
    }
    if (tmp_out->RefCountIsOne()) {
      // Woohoo, forwarding succeeded!
      c->set_output(output_index, *output_tensor);
      *output_list = tmp_out;
      return OkStatus();
    }
  }

  // If forwarding is not possible, allocate a new output tensor and copy
  // the `input_list` to it.
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(
      c->allocate_output(output_index, {}, &output_tensor, attr));
  output_tensor->scalar<Variant>()() = input_list.Copy();

  *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
  return OkStatus();
}

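// EmptyTensorList: creates a new, empty TensorList with the requested
// element dtype, element shape, and max_num_elements.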
class EmptyTensorList : public OpKernel {
 public:
  explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& max_num_elements_t = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(max_num_elements_t.shape()),
        errors::InvalidArgument(
            "max_num_elements expected to be a scalar ",
            "but got shape: ", max_num_elements_t.shape().DebugString()));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
    TensorList empty;
    empty.element_dtype = element_dtype_;
    empty.max_num_elements = max_num_elements_t.scalar<int32>()();
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape));
    empty.element_shape = element_shape;
    result->scalar<Variant>()() = std::move(empty);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
                        EmptyTensorList);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

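// TensorListPushBack: appends a tensor to the end of the input list after
// validating its dtype and shape against the list's metadata and capacity.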
class TensorListPushBack : public OpKernel {
 public:
  explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  ~TensorListPushBack() override {}

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));

    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()),
                errors::InvalidArgument(
                    "Tried to append a tensor with incompatible shape to a "
                    "list. Op element shape: ",
                    input.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    if (l->max_num_elements != -1) {
      OP_REQUIRES(
          c, l->tensors().size() < l->max_num_elements,
          errors::InvalidArgument("Tried to push item into a full list",
                                  " list size: ", l->tensors().size(),
                                  " max_num_elements: ", l->max_num_elements));
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors().push_back(input);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU),
                        TensorListPushBack);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU),
                        TensorListPushBack);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_DEFAULT),
                        TensorListPushBack);

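// TensorListLength: returns the number of elements in the list as an int32
// scalar.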
class TensorListLength : public OpKernel {
 public:
  explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {}
  ~TensorListLength() override {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
    result->scalar<int32>()() = l->tensors().size();
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU),
                        TensorListLength);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"),
    TensorListLength);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_DEFAULT).HostMemory("length"),
    TensorListLength);

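// TensorListElementShape: returns the list's element shape as an int32 or
// int64 vector, or a scalar -1 if the rank is unknown.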
class TensorListElementShape : public OpKernel {
 public:
  explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    if (l->element_shape.unknown_rank()) {
      OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &result));
      if (result->dtype() == DT_INT32) {
        result->scalar<int32>()() = -1;
      } else {
        result->scalar<int64_t>()() = -1;
      }
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(
                            0, TensorShape{l->element_shape.dims()}, &result));
      for (int i = 0; i < l->element_shape.dims(); ++i) {
        if (result->dtype() == DT_INT32) {
          result->flat<int32>()(i) = l->element_shape.dim_size(i);
        } else {
          result->flat<int64_t>()(i) = l->element_shape.dim_size(i);
        }
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
                        TensorListElementShape);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

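// TensorListReserve: creates a list with `num_elements` slots, each
// initialized to an uninitialized (DT_INVALID) tensor.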
class TensorListReserve : public OpKernel {
 public:
  explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
    OP_REQUIRES(
        c, TensorShapeUtils::IsScalar(c->input(1).shape()),
        errors::InvalidArgument(
            "The num_elements to reserve must be a scalar, but got ",
            c->input(1).shape()));
    int32_t num_elements = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, num_elements >= 0,
                errors::InvalidArgument("The num_elements to reserve must be a "
                                        "non-negative number, but got ",
                                        num_elements));
    TensorList output;
    output.element_shape = element_shape;
    output.element_dtype = element_dtype_;
    output.tensors().resize(num_elements, Tensor(DT_INVALID));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
                        TensorListReserve);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

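// TensorListResize: resizes the input list to `size` elements, truncating it
// or padding it with uninitialized tensors as needed.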
class TensorListResize : public OpKernel {
 public:
  explicit TensorListResize(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* input_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
    int32_t size = c->input(1).scalar<int32>()();
    OP_REQUIRES(
        c, size >= 0,
        errors::InvalidArgument(
            "TensorListResize expects size to be non-negative. Got: ", size));

    std::unique_ptr<Tensor> maybe_result =
        c->forward_input(0, 0, DT_VARIANT, TensorShape{},
                         c->input_memory_type(0), AllocatorAttributes());
    if (maybe_result != nullptr) {
      TensorList* out = maybe_result->scalar<Variant>()().get<TensorList>();
      if (out->RefCountIsOne()) {
        // We are able to forward the input.
        out->tensors().resize(size, Tensor(DT_INVALID));
        c->set_output(0, *maybe_result);
        return;
      }
    }

    // We were not able to forward the input.  Will have to resize from scratch.
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    TensorList output_list;
    output_list.element_shape = input_list->element_shape;
    output_list.element_dtype = input_list->element_dtype;
    output_list.max_num_elements = input_list->max_num_elements;
    if (size > input_list->tensors().size()) {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().end());
      // Add DT_INVALID tensors to the end of the list if the requested size
      // is larger than the list length.
      output_list.tensors().resize(size, Tensor(DT_INVALID));
    } else {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().begin() + size);
    }
    result->scalar<Variant>()() = std::move(output_list);
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListResize").Device(DEVICE_CPU),
                        TensorListResize);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_GPU).HostMemory("size"),
    TensorListResize);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_DEFAULT).HostMemory("size"),
    TensorListResize);

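// TensorListSetItem: replaces the element at `index` with the given value
// after checking the index bound and the value's dtype and shape.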
class TensorListSetItem : public OpKernel {
 public:
  explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32_t index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, index < l->tensors().size(),
                errors::InvalidArgument("Trying to modify element ", index,
                                        " in a list with ", l->tensors().size(),
                                        " elements."));
    const Tensor& value = c->input(2);
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
                errors::InvalidArgument(
                    "Tried to set a tensor with incompatible shape at a "
                    "list index. Item element shape: ",
                    value.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors()[index] = value;
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
                        TensorListSetItem);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_SET_ITEM_GPU(T)                      \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem")               \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_GPU)                 \
                              .HostMemory("index"),               \
                          TensorListSetItem);

TF_CALL_GPU_ALL_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
REGISTER_TENSOR_LIST_SET_ITEM_GPU(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT(T)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem")               \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT)             \
                              .HostMemory("index"),               \
                          TensorListSetItem);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT

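// TensorListConcatLists: given two equally shaped tensors of lists, appends
// the elements of each list in input_b onto the corresponding list in
// input_a, reusing input_a's buffer when it can be aliased safely.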
class TensorListConcatLists : public OpKernel {
 public:
  explicit TensorListConcatLists(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorShape& tl_a_shape = c->input(0).shape();
    const TensorShape& tl_b_shape = c->input(1).shape();
    OP_REQUIRES(
        c, tl_a_shape == tl_b_shape,
        errors::InvalidArgument("Incompatible input TensorList tensor shapes: ",
                                tl_a_shape.DebugString(), " vs. ",
                                tl_b_shape.DebugString()));
    AllocatorAttributes attr;
    std::unique_ptr<Tensor> tl_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tl_a_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    // tl_a may be aliased by tl_alias.
    const Tensor& tl_a = c->input(0);
    const Tensor& tl_b = c->input(1);

    Tensor* output = nullptr;
    bool ok_to_alias = tl_alias != nullptr;
    if (tl_alias && tl_alias->dtype() == DT_VARIANT &&
        tl_alias->NumElements() > 0) {
      auto tl_a_t = tl_alias->flat<Variant>();
      for (int64_t i = 0; i < tl_alias->NumElements(); ++i) {
        TensorList* aliased = tl_a_t(i).get<TensorList>();
        if (aliased == nullptr || !aliased->RefCountIsOne()) {
          ok_to_alias = false;
          break;
        }
      }
      if (ok_to_alias) {
        c->set_output(0, *tl_alias);
        output = tl_alias.get();
      }
    }
    if (!ok_to_alias) {
      // Couldn't alias the entire Tensor.  We'll be conservative and not try
      // to alias individual batch entries.
      attr.set_on_host(true);
      OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
    }

    auto output_t = output->flat<Variant>();
    auto tl_a_t = tl_a.flat<Variant>();
    auto tl_b_t = tl_b.flat<Variant>();

    for (int64_t i = 0; i < tl_a.NumElements(); ++i) {
      const TensorList* l_a = tl_a_t(i).get<TensorList>();
      const TensorList* l_b = tl_b_t(i).get<TensorList>();
      OP_REQUIRES(
          c, l_a != nullptr,
          errors::InvalidArgument("input_a is not a TensorList at index ", i,
                                  ".  Saw: '", tl_a_t(i).DebugString(), "'"));
      OP_REQUIRES(
          c, l_b != nullptr,
          errors::InvalidArgument("input_b is not a TensorList at index ", i,
                                  ".  Saw: '", tl_b_t(i).DebugString(), "'"));
      OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_a[", i, "].dtype != element_dtype.  Saw: ",
                      DataTypeString(l_a->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_b[", i, "].dtype != element_dtype.  Saw: ",
                      DataTypeString(l_b->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
                  errors::InvalidArgument(
                      "input_a and input_b TensorList element shapes are not "
                      "identical at index ",
                      i, ".  Saw ", l_a->element_shape.DebugString(), " vs. ",
                      l_b->element_shape.DebugString()));
      if (ok_to_alias) {
        TensorList* out = output_t(i).get<TensorList>();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out->tensors()));
      } else {
        TensorList out = l_a->Copy();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out.tensors()));
        output_t(i) = std::move(out);
      }
    }
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_CPU),
                        TensorListConcatLists);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_GPU),
                        TensorListConcatLists);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_DEFAULT),
                        TensorListConcatLists);

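// Registers the templated per-element-dtype list kernels (TensorListStack,
// TensorListGather, TensorListConcat, TensorListGetItem, TensorListPopBack,
// TensorListFromTensor, TensorListScatter, TensorListSplit and
// TensorListPushBackBatch) for DEVICE_CPU.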
#define REGISTER_TENSOR_LIST_OPS_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListStack<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListGather<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2")                       \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListGetItem<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListPopBack<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")                     \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListFromTensor<CPUDevice, T>)              \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2")                      \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList")        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListSplit<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")                  \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListPushBackBatch<CPUDevice, T>)

TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_OPS_CPU);
REGISTER_TENSOR_LIST_OPS_CPU(quint8);
REGISTER_TENSOR_LIST_OPS_CPU(qint8);
REGISTER_TENSOR_LIST_OPS_CPU(quint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint32);
REGISTER_TENSOR_LIST_OPS_CPU(Variant);

#undef REGISTER_TENSOR_LIST_OPS_CPU

#define REGISTER_TENSOR_LIST_OPS_CPU(T)

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          TensorList,
                                          TensorListBinaryAdd<CPUDevice>);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, TensorList,
                                         TensorListZerosLike<CPUDevice>);

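// Copies a TensorList between devices: the metadata is copied directly and
// each initialized element tensor is copied with the provided copy function.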
static Status TensorListDeviceCopy(
    const TensorList& from, TensorList* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  to->element_shape = from.element_shape;
  to->element_dtype = from.element_dtype;
  to->max_num_elements = from.max_num_elements;
  to->tensors().reserve(from.tensors().size());
  for (const Tensor& t : from.tensors()) {
    to->tensors().emplace_back(t.dtype());
    if (t.dtype() != DT_INVALID) {
      TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
    }
  }
  return OkStatus();
}

#define REGISTER_LIST_COPY(DIRECTION)                                         \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
                                                       TensorListDeviceCopy)

REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);

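// Registers the same templated list kernels for DEVICE_DEFAULT, with the
// shape, index, and length arguments pinned to host memory.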
#define REGISTER_TENSOR_LIST_OPS_DEFAULT(T)                                \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("element_shape")                 \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListStack<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("indices")                       \
                              .HostMemory("element_shape")                 \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListGather<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("lengths")                       \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2")                       \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("leading_dims")                  \
                              .HostMemory("element_shape")                 \
                              .HostMemory("lengths")                       \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("index")                         \
                              .HostMemory("element_shape"),                \
                          TensorListGetItem<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape"),                \
                          TensorListPopBack<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")                  \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListPushBackBatch<CPUDevice, T>)           \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")                     \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape"),                \
                          TensorListFromTensor<CPUDevice, T>)              \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape")                 \
                              .HostMemory("indices"),                      \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2")                      \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape")                 \
                              .HostMemory("num_elements")                  \
                              .HostMemory("indices"),                      \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList")        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("indices"),                      \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape")                 \
                              .HostMemory("lengths"),                      \
                          TensorListSplit<CPUDevice, T>)

TF_CALL_int32(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_int64(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_bfloat16(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_OPS_DEFAULT);

#undef REGISTER_TENSOR_LIST_OPS_DEFAULT
}  // namespace tensorflow