• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <limits>
17 
18 #define EIGEN_USE_THREADS
19 #if GOOGLE_CUDA
20 #define EIGEN_USE_GPU
21 #endif  // GOOGLE_CUDA
22 
23 #include "tensorflow/core/kernels/list_kernels.h"
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/register_types.h"
28 #include "tensorflow/core/framework/tensor_types.h"
29 #include "tensorflow/core/framework/variant.h"
30 #include "tensorflow/core/framework/variant_op_registry.h"
31 #include "tensorflow/core/kernels/concat_lib.h"
32 #include "tensorflow/core/lib/core/coding.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/util/util.h"
35 
36 namespace tensorflow {
37 
38 typedef Eigen::ThreadPoolDevice CPUDevice;
39 
40 // Variant compatible type for a list of tensors. This is mutable but instances
41 // should never be mutated after stored in a variant tensor.
TensorList(const TensorList & other)42 TensorList::TensorList(const TensorList& other)
43     : tensors(other.tensors),
44       element_shape(other.element_shape),
45       element_dtype(other.element_dtype),
46       max_num_elements(other.max_num_elements) {}
47 
Encode(VariantTensorData * data) const48 void TensorList::Encode(VariantTensorData* data) const {
49   data->set_type_name(TypeName());
50   std::vector<size_t> invalid_indices;
51   for (size_t i = 0; i < tensors.size(); i++) {
52     if (tensors.at(i).dtype() != DT_INVALID) {
53       *data->add_tensors() = tensors.at(i);
54     } else {
55       invalid_indices.push_back(i);
56     }
57   }
58   string metadata;
59   // TODO(b/118838800): Add a proto for storing the metadata.
60   // Metadata format:
61   // <num_invalid_tensors><invalid_indices><element_dtype><element_shape_proto>
62   core::PutVarint64(&metadata, static_cast<uint64>(invalid_indices.size()));
63   for (size_t i : invalid_indices) {
64     core::PutVarint64(&metadata, static_cast<uint64>(i));
65   }
66   core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
67   core::PutVarint64(&metadata, static_cast<uint64>(max_num_elements));
68   TensorShapeProto element_shape_proto;
69   element_shape.AsProto(&element_shape_proto);
70   element_shape_proto.AppendToString(&metadata);
71   data->set_metadata(metadata);
72 }
73 
TensorListDeviceCopy(const TensorList & from,TensorList * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy)74 static Status TensorListDeviceCopy(
75     const TensorList& from, TensorList* to,
76     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
77   to->element_shape = from.element_shape;
78   to->element_dtype = from.element_dtype;
79   to->max_num_elements = from.max_num_elements;
80   to->tensors.reserve(from.tensors.size());
81   for (const Tensor& t : from.tensors) {
82     to->tensors.emplace_back(t.dtype());
83     if (t.dtype() != DT_INVALID) {
84       TF_RETURN_IF_ERROR(copy(t, &to->tensors.back()));
85     }
86   }
87   return Status::OK();
88 }
89 
90 #define REGISTER_LIST_COPY(DIRECTION)                                         \
91   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
92                                                        TensorListDeviceCopy)
93 
94 REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
95 REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
96 REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
97 
98 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);
99 
Decode(const VariantTensorData & data)100 bool TensorList::Decode(const VariantTensorData& data) {
101   // TODO(srbs): Change the signature to Decode(VariantTensorData data) so
102   // that we do not have to copy each tensor individually below. This would
103   // require changing VariantTensorData::tensors() as well.
104   string metadata;
105   data.get_metadata(&metadata);
106   uint64 scratch;
107   StringPiece iter(metadata);
108   std::vector<size_t> invalid_indices;
109   core::GetVarint64(&iter, &scratch);
110   size_t num_invalid_tensors = static_cast<size_t>(scratch);
111   invalid_indices.resize(num_invalid_tensors);
112   for (size_t i = 0; i < num_invalid_tensors; i++) {
113     core::GetVarint64(&iter, &scratch);
114     invalid_indices[i] = static_cast<size_t>(scratch);
115   }
116 
117   size_t total_num_tensors = data.tensors().size() + num_invalid_tensors;
118   tensors.reserve(total_num_tensors);
119   std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin();
120   std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
121   for (size_t i = 0; i < total_num_tensors; i++) {
122     if (invalid_indices_it != invalid_indices.end() &&
123         *invalid_indices_it == i) {
124       tensors.emplace_back(Tensor(DT_INVALID));
125       invalid_indices_it++;
126     } else if (tensors_it != data.tensors().end()) {
127       tensors.emplace_back(*tensors_it);
128       tensors_it++;
129     } else {
130       // VariantTensorData is corrupted.
131       return false;
132     }
133   }
134 
135   core::GetVarint64(&iter, &scratch);
136   element_dtype = static_cast<DataType>(scratch);
137   core::GetVarint64(&iter, &scratch);
138   max_num_elements = static_cast<int>(scratch);
139   TensorShapeProto element_shape_proto;
140   element_shape_proto.ParseFromString(string(iter.data(), iter.size()));
141   element_shape = PartialTensorShape(element_shape_proto);
142   return true;
143 }
144 
TensorShapeFromTensor(const Tensor & t,PartialTensorShape * out)145 Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
146   if (t.shape() == TensorShape({})) {
147     if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
148         (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1)) {
149       *out = PartialTensorShape();
150       return Status::OK();
151     }
152     return errors::InvalidArgument(
153         "The only valid scalar shape tensor is the fully unknown shape "
154         "specified as -1.");
155   }
156   if (t.dtype() == DT_INT32) {
157     return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
158                                                 t.NumElements(), out);
159   } else if (t.dtype() == DT_INT64) {
160     return PartialTensorShape::MakePartialShape(t.vec<int64>().data(),
161                                                 t.NumElements(), out);
162   }
163   return errors::InvalidArgument(
164       "Expected an int32 or int64 shape tensor; found ",
165       DataTypeString(t.dtype()));
166 }
167 
GetElementShapeFromInput(OpKernelContext * c,const TensorList & tensor_list,int index,PartialTensorShape * element_shape)168 Status GetElementShapeFromInput(OpKernelContext* c,
169                                 const TensorList& tensor_list, int index,
170                                 PartialTensorShape* element_shape) {
171   TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(index), element_shape));
172   // Check that `element_shape` and `tensor_list.element_shape` are
173   // compatible and store the merged shape in `element_shape`.
174   PartialTensorShape tmp = *element_shape;
175   TF_RETURN_IF_ERROR(tmp.MergeWith(tensor_list.element_shape, element_shape));
176   return Status::OK();
177 }
178 
GetInputList(OpKernelContext * c,int index,const TensorList ** list)179 Status GetInputList(OpKernelContext* c, int index, const TensorList** list) {
180   if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
181     return errors::InvalidArgument("Input list must be a scalar saw: ",
182                                    c->input(index).shape().DebugString());
183   }
184   const TensorList* l = c->input(index).scalar<Variant>()().get<TensorList>();
185   if (l == nullptr) {
186     return errors::InvalidArgument(
187         "Input handle is not a list. Saw: '",
188         c->input(index).scalar<Variant>()().DebugString(), "'");
189   }
190   *list = l;
191   return Status::OK();
192 }
193 
ForwardInputOrCreateNewList(OpKernelContext * c,int32 input_index,int32 output_index,const TensorList & input_list,TensorList ** output_list)194 Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
195                                    int32 output_index,
196                                    const TensorList& input_list,
197                                    TensorList** output_list) {
198   // Attempt to forward the input tensor to the output if possible.
199   AllocatorAttributes attr;
200   attr.set_on_host(true);
201   std::unique_ptr<Tensor> maybe_output =
202       c->forward_input(input_index, output_index, DT_VARIANT, TensorShape{},
203                        c->input_memory_type(input_index), attr);
204   Tensor* output_tensor;
205   if (maybe_output != nullptr) {
206     // Woohoo, forwarding succeeded!
207     output_tensor = maybe_output.get();
208   } else {
209     // If forwarding is not possible allocate a new output tensor and copy
210     // the `input_list` to it.
211     TF_RETURN_IF_ERROR(
212         c->allocate_output(output_index, {}, &output_tensor, attr));
213     output_tensor->scalar<Variant>()() = input_list;
214   }
215   *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
216   return Status::OK();
217 }
218 
219 class EmptyTensorList : public OpKernel {
220  public:
EmptyTensorList(OpKernelConstruction * ctx)221   explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) {
222     OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_));
223   }
224 
Compute(OpKernelContext * ctx)225   void Compute(OpKernelContext* ctx) override {
226     const Tensor& max_num_elements_t = ctx->input(1);
227     OP_REQUIRES(
228         ctx, TensorShapeUtils::IsScalar(max_num_elements_t.shape()),
229         errors::InvalidArgument(
230             "max_num_elements expected to be a scalar ",
231             "but got shape: ", max_num_elements_t.shape().DebugString()));
232     Tensor* result;
233     AllocatorAttributes attr;
234     attr.set_on_host(true);
235     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
236     TensorList empty;
237     empty.element_dtype = element_dtype_;
238     empty.max_num_elements = max_num_elements_t.scalar<int32>()();
239     PartialTensorShape element_shape;
240     OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape));
241     empty.element_shape = element_shape;
242     result->scalar<Variant>()() = std::move(empty);
243   }
244 
245  private:
246   DataType element_dtype_;
247 };
248 
249 const char TensorList::kTypeName[] = "tensorflow::TensorList";
250 
251 REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
252                         EmptyTensorList);
253 
254 #if GOOGLE_CUDA
255 
256 REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
257                             .Device(DEVICE_GPU)
258                             .HostMemory("element_shape")
259                             .HostMemory("max_num_elements"),
260                         EmptyTensorList);
261 
262 #endif  // GOOGLE_CUDA
263 
264 class TensorListPushBack : public OpKernel {
265  public:
TensorListPushBack(OpKernelConstruction * c)266   explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) {
267     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
268   }
269 
~TensorListPushBack()270   ~TensorListPushBack() override {}
271 
Compute(OpKernelContext * c)272   void Compute(OpKernelContext* c) override {
273     const Tensor& input = c->input(1);
274     OP_REQUIRES(c, element_dtype_ == input.dtype(),
275                 errors::InvalidArgument("Invalid data types; list elements ",
276                                         DataTypeString(element_dtype_),
277                                         " but tried to append ",
278                                         DataTypeString(input.dtype())));
279 
280     const TensorList* l = nullptr;
281     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
282     OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()),
283                 errors::InvalidArgument(
284                     "Tried to append a tensor with incompatible shape to a "
285                     "list. Op element shape: ",
286                     input.shape().DebugString(),
287                     " list shape: ", l->element_shape.DebugString()));
288     OP_REQUIRES(c, element_dtype_ == l->element_dtype,
289                 errors::InvalidArgument("Invalid data types; op elements ",
290                                         DataTypeString(element_dtype_),
291                                         " but list elements ",
292                                         DataTypeString(l->element_dtype)));
293 
294     if (l->max_num_elements != -1) {
295       OP_REQUIRES(
296           c, l->tensors.size() < l->max_num_elements,
297           errors::InvalidArgument("Tried to push item into a full list",
298                                   " list size: ", l->tensors.size(),
299                                   " max_num_elements: ", l->max_num_elements));
300     }
301 
302     TensorList* output_list = nullptr;
303     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
304     output_list->tensors.push_back(input);
305   }
306 
307  private:
308   DataType element_dtype_;
309 };
310 
311 REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU),
312                         TensorListPushBack);
313 
314 #if GOOGLE_CUDA
315 
316 REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU),
317                         TensorListPushBack);
318 
319 #endif  // GOOGLE_CUDA
320 
321 class TensorListLength : public OpKernel {
322  public:
TensorListLength(OpKernelConstruction * c)323   explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {}
~TensorListLength()324   ~TensorListLength() override {}
325 
Compute(OpKernelContext * c)326   void Compute(OpKernelContext* c) override {
327     const TensorList* l = nullptr;
328     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
329     Tensor* result;
330     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
331     result->scalar<int32>()() = l->tensors.size();
332   }
333 };
334 
335 REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU),
336                         TensorListLength);
337 
338 #if GOOGLE_CUDA
339 
340 REGISTER_KERNEL_BUILDER(
341     Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"),
342     TensorListLength);
343 
344 #endif  // GOOGLE_CUDA
345 
346 class TensorListElementShape : public OpKernel {
347  public:
TensorListElementShape(OpKernelConstruction * c)348   explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}
349 
Compute(OpKernelContext * c)350   void Compute(OpKernelContext* c) override {
351     const TensorList* l = nullptr;
352     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
353     Tensor* result;
354     if (l->element_shape.unknown_rank()) {
355       OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &result));
356       if (result->dtype() == DT_INT32) {
357         result->scalar<int32>()() = -1;
358       } else {
359         result->scalar<int64>()() = -1;
360       }
361     } else {
362       OP_REQUIRES_OK(c, c->allocate_output(
363                             0, TensorShape{l->element_shape.dims()}, &result));
364       for (int i = 0; i < l->element_shape.dims(); ++i) {
365         if (result->dtype() == DT_INT32) {
366           result->flat<int32>()(i) = l->element_shape.dim_size(i);
367         } else {
368           result->flat<int64>()(i) = l->element_shape.dim_size(i);
369         }
370       }
371     }
372   }
373 };
374 
375 REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
376                         TensorListElementShape);
377 
378 class TensorListReserve : public OpKernel {
379  public:
TensorListReserve(OpKernelConstruction * c)380   explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
381     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
382   }
383 
Compute(OpKernelContext * c)384   void Compute(OpKernelContext* c) override {
385     PartialTensorShape element_shape;
386     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
387     int32 num_elements = c->input(1).scalar<int32>()();
388     TensorList output;
389     output.element_shape = element_shape;
390     output.element_dtype = element_dtype_;
391     output.tensors.resize(num_elements, Tensor(DT_INVALID));
392     Tensor* result;
393     AllocatorAttributes attr;
394     attr.set_on_host(true);
395     OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
396     result->scalar<Variant>()() = std::move(output);
397   }
398 
399  private:
400   DataType element_dtype_;
401 };
402 
403 REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
404                         TensorListReserve);
405 
406 #if GOOGLE_CUDA
407 
408 REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
409                             .Device(DEVICE_GPU)
410                             .HostMemory("element_shape")
411                             .HostMemory("num_elements"),
412                         TensorListReserve);
413 
414 #endif  // GOOGLE_CUDA
415 class TensorListResize : public OpKernel {
416  public:
TensorListResize(OpKernelConstruction * c)417   explicit TensorListResize(OpKernelConstruction* c) : OpKernel(c) {}
418 
Compute(OpKernelContext * c)419   void Compute(OpKernelContext* c) override {
420     const TensorList* input_list = nullptr;
421     OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
422     int32 size = c->input(1).scalar<int32>()();
423     OP_REQUIRES(
424         c, size >= 0,
425         errors::InvalidArgument(
426             "TensorListSlice expects size to be non-negative. Got: ", size));
427 
428     AllocatorAttributes attr;
429     attr.set_on_host(true);
430     std::unique_ptr<Tensor> maybe_result = c->forward_input(
431         0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
432     if (maybe_result != nullptr) {
433       maybe_result->scalar<Variant>()().get<TensorList>()->tensors.resize(
434           size, Tensor(DT_INVALID));
435     } else {
436       Tensor* result;
437       OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
438       TensorList output_list;
439       output_list.element_shape = input_list->element_shape;
440       output_list.element_dtype = input_list->element_dtype;
441       output_list.max_num_elements = input_list->max_num_elements;
442       if (size > input_list->tensors.size()) {
443         output_list.tensors.insert(output_list.tensors.begin(),
444                                    input_list->tensors.begin(),
445                                    input_list->tensors.end());
446         // Add DT_INVALID tensors to the end of the list if the requested size
447         // is larger than the list length.
448         output_list.tensors.resize(size, Tensor(DT_INVALID));
449       } else {
450         output_list.tensors.insert(output_list.tensors.begin(),
451                                    input_list->tensors.begin(),
452                                    input_list->tensors.begin() + size);
453       }
454       result->scalar<Variant>()() = std::move(output_list);
455     }
456   }
457 };
458 
459 REGISTER_KERNEL_BUILDER(Name("TensorListResize").Device(DEVICE_CPU),
460                         TensorListResize);
461 
462 #if GOOGLE_CUDA
463 
464 REGISTER_KERNEL_BUILDER(
465     Name("TensorListResize").Device(DEVICE_GPU).HostMemory("size"),
466     TensorListResize);
467 
468 #endif  // GOOGLE_CUDA
469 
470 class TensorListSetItem : public OpKernel {
471  public:
TensorListSetItem(OpKernelConstruction * c)472   explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
473     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
474   }
475 
Compute(OpKernelContext * c)476   void Compute(OpKernelContext* c) override {
477     const TensorList* l = nullptr;
478     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
479     OP_REQUIRES(c, element_dtype_ == l->element_dtype,
480                 errors::InvalidArgument("Invalid data types; op elements ",
481                                         DataTypeString(element_dtype_),
482                                         " but list elements ",
483                                         DataTypeString(l->element_dtype)));
484     int32 index = c->input(1).scalar<int32>()();
485     OP_REQUIRES(c, index < l->tensors.size(),
486                 errors::InvalidArgument("Trying to modify element ", index,
487                                         " in a list with ", l->tensors.size(),
488                                         " elements."));
489     const Tensor& value = c->input(2);
490     OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
491                 errors::InvalidArgument(
492                     "Tried to set a tensor with incompatible shape at a "
493                     "list index. Item element shape: ",
494                     value.shape().DebugString(),
495                     " list shape: ", l->element_shape.DebugString()));
496     TensorList* output_list = nullptr;
497     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
498     output_list->tensors[index] = value;
499   }
500 
501  private:
502   DataType element_dtype_;
503 };
504 
505 REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
506                         TensorListSetItem);
507 
508 #if GOOGLE_CUDA
509 
510 #define REGISTER_TENSOR_LIST_SET_ITEM_GPU(T)                      \
511   REGISTER_KERNEL_BUILDER(Name("TensorListSetItem")               \
512                               .TypeConstraint<T>("element_dtype") \
513                               .Device(DEVICE_GPU)                 \
514                               .HostMemory("index"),               \
515                           TensorListSetItem);
516 
517 TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
518 TF_CALL_complex64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
519 TF_CALL_complex128(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
520 TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
521 TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
522 REGISTER_TENSOR_LIST_SET_ITEM_GPU(bfloat16)
523 #undef REGISTER_TENSOR_LIST_SET_ITEM_GPU
524 
525 #endif  // GOOGLE_CUDA
526 
527 class TensorListConcatLists : public OpKernel {
528  public:
TensorListConcatLists(OpKernelConstruction * c)529   explicit TensorListConcatLists(OpKernelConstruction* c) : OpKernel(c) {
530     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
531   }
532 
Compute(OpKernelContext * c)533   void Compute(OpKernelContext* c) override {
534     const TensorShape& tl_a_shape = c->input(0).shape();
535     const TensorShape& tl_b_shape = c->input(1).shape();
536     OP_REQUIRES(
537         c, tl_a_shape == tl_b_shape,
538         errors::InvalidArgument("Incompatible input TensorList tensor shapes: ",
539                                 tl_a_shape.DebugString(), " vs. ",
540                                 tl_b_shape.DebugString()));
541     AllocatorAttributes attr;
542     std::unique_ptr<Tensor> tl_alias = c->forward_input(
543         0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tl_a_shape,
544         DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
545 
546     // tl_a may be aliased by tl_alias.
547     const Tensor& tl_a = c->input(0);
548     const Tensor& tl_b = c->input(1);
549 
550     Tensor* output;
551     if (tl_alias) {
552       c->set_output(0, *tl_alias);
553       output = tl_alias.get();
554     } else {
555       attr.set_on_host(true);
556       OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
557     }
558 
559     auto output_t = output->flat<Variant>();
560     auto tl_a_t = tl_a.flat<Variant>();
561     auto tl_b_t = tl_b.flat<Variant>();
562 
563     for (int64 b = 0; b < tl_a.NumElements(); ++b) {
564       const TensorList* l_a = tl_a_t(b).get<TensorList>();
565       const TensorList* l_b = tl_b_t(b).get<TensorList>();
566       OP_REQUIRES(
567           c, l_a != nullptr,
568           errors::InvalidArgument("input_a is not a TensorList at index ", b,
569                                   ".  Saw: '", tl_a_t(b).DebugString(), "'"));
570       OP_REQUIRES(
571           c, l_b != nullptr,
572           errors::InvalidArgument("input_b is not a TensorList at index ", b,
573                                   ".  Saw: '", tl_b_t(b).DebugString(), "'"));
574       OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
575                   errors::InvalidArgument(
576                       "input_a[", b, "].dtype != element_dtype.  Saw: ",
577                       DataTypeString(l_a->element_dtype), " vs. ",
578                       DataTypeString(element_dtype_)));
579       OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
580                   errors::InvalidArgument(
581                       "input_b[", b, "].dtype != element_dtype.  Saw: ",
582                       DataTypeString(l_b->element_dtype), " vs. ",
583                       DataTypeString(element_dtype_)));
584       OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
585                   errors::InvalidArgument(
586                       "input_a and input_b TensorList element shapes are not "
587                       "identical at index ",
588                       b, ".  Saw ", l_a->element_shape.DebugString(), " vs. ",
589                       l_b->element_shape.DebugString()));
590       if (tl_alias) {
591         TensorList* out = output_t(b).get<TensorList>();
592         DCHECK(out != nullptr) << "Expected output to alias input_a, but it "
593                                   "doesn't contain a TensorList at index "
594                                << b;
595         std::copy(l_b->tensors.begin(), l_b->tensors.end(),
596                   std::back_inserter(out->tensors));
597       } else {
598         TensorList out = *l_a;
599         std::copy(l_b->tensors.begin(), l_b->tensors.end(),
600                   std::back_inserter(out.tensors));
601         output_t(b) = std::move(out);
602       }
603     }
604   }
605 
606  private:
607   DataType element_dtype_;
608 };
609 
610 REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_CPU),
611                         TensorListConcatLists);
612 
613 #if GOOGLE_CUDA
614 
615 REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_GPU),
616                         TensorListConcatLists);
617 
618 #endif  // GOOGLE_CUDA
619 
620 #define REGISTER_TENSOR_LIST_OPS_CPU(T)                                    \
621   REGISTER_KERNEL_BUILDER(Name("TensorListStack")                          \
622                               .TypeConstraint<T>("element_dtype")          \
623                               .Device(DEVICE_CPU),                         \
624                           TensorListStack<CPUDevice, T>)                   \
625   REGISTER_KERNEL_BUILDER(Name("TensorListGather")                         \
626                               .TypeConstraint<T>("element_dtype")          \
627                               .Device(DEVICE_CPU),                         \
628                           TensorListGather<CPUDevice, T>)                  \
629   REGISTER_KERNEL_BUILDER(Name("TensorListConcat")                         \
630                               .TypeConstraint<T>("element_dtype")          \
631                               .Device(DEVICE_CPU),                         \
632                           TensorListConcat<CPUDevice, T>)                  \
633   REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2")                       \
634                               .TypeConstraint<T>("element_dtype")          \
635                               .Device(DEVICE_CPU),                         \
636                           TensorListConcat<CPUDevice, T>)                  \
637   REGISTER_KERNEL_BUILDER(Name("TensorListGetItem")                        \
638                               .TypeConstraint<T>("element_dtype")          \
639                               .Device(DEVICE_CPU),                         \
640                           TensorListGetItem<CPUDevice, T>)                 \
641   REGISTER_KERNEL_BUILDER(Name("TensorListPopBack")                        \
642                               .TypeConstraint<T>("element_dtype")          \
643                               .Device(DEVICE_CPU),                         \
644                           TensorListPopBack<CPUDevice, T>)                 \
645   REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")                     \
646                               .TypeConstraint<T>("element_dtype")          \
647                               .Device(DEVICE_CPU),                         \
648                           TensorListFromTensor<CPUDevice, T>)              \
649   REGISTER_KERNEL_BUILDER(Name("TensorListScatter")                        \
650                               .TypeConstraint<T>("element_dtype")          \
651                               .Device(DEVICE_CPU),                         \
652                           TensorListScatter<CPUDevice, T>)                 \
653   REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2")                      \
654                               .TypeConstraint<T>("element_dtype")          \
655                               .Device(DEVICE_CPU),                         \
656                           TensorListScatter<CPUDevice, T>)                 \
657   REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList")        \
658                               .TypeConstraint<T>("element_dtype")          \
659                               .Device(DEVICE_CPU),                         \
660                           TensorListScatterIntoExistingList<CPUDevice, T>) \
661   REGISTER_KERNEL_BUILDER(Name("TensorListSplit")                          \
662                               .TypeConstraint<T>("element_dtype")          \
663                               .Device(DEVICE_CPU),                         \
664                           TensorListSplit<CPUDevice, T>)                   \
665   REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")                  \
666                               .TypeConstraint<T>("element_dtype")          \
667                               .Device(DEVICE_CPU),                         \
668                           TensorListPushBackBatch<CPUDevice, T>)
669 
670 TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_OPS_CPU);
671 REGISTER_TENSOR_LIST_OPS_CPU(quint8);
672 REGISTER_TENSOR_LIST_OPS_CPU(qint8);
673 REGISTER_TENSOR_LIST_OPS_CPU(quint16);
674 REGISTER_TENSOR_LIST_OPS_CPU(qint16);
675 REGISTER_TENSOR_LIST_OPS_CPU(qint32);
676 REGISTER_TENSOR_LIST_OPS_CPU(Variant);
677 
678 #undef REGISTER_TENSOR_LIST_OPS_CPU
679 
680 #define REGISTER_TENSOR_LIST_OPS_CPU(T)
681 
682 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
683                                           TensorList,
684                                           TensorListBinaryAdd<CPUDevice>);
685 
686 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
687                                          DEVICE_CPU, TensorList,
688                                          TensorListZerosLike<CPUDevice>);
689 
690 }  // namespace tensorflow
691