/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/list_kernels.h"

#include <limits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

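// Parses the shape tensor `t` into `out`. A scalar -1 (int32 or int64)
// denotes a fully unknown shape; otherwise `t` must be an int32 or int64
// vector whose entries are the dimension sizes (-1 for unknown dimensions).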
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
  if (t.shape() == TensorShape({})) {
    if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
        (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1)) {
      *out = PartialTensorShape();
      return Status::OK();
    }
    return errors::InvalidArgument(
        "The only valid scalar shape tensor is the fully unknown shape "
        "specified as -1.");
  }
  if (t.dtype() == DT_INT32) {
    return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
                                                t.NumElements(), out);
  } else if (t.dtype() == DT_INT64) {
    return PartialTensorShape::MakePartialShape(t.vec<int64>().data(),
                                                t.NumElements(), out);
  }
  return errors::InvalidArgument(
      "Expected an int32 or int64 shape tensor; found ",
      DataTypeString(t.dtype()));
}

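// Reads the element shape from input `index`, merges it with
// `tensor_list.element_shape`, and stores the merged shape in
// `element_shape`. Returns an error if the two shapes are incompatible.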
Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape) {
  TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(index), element_shape));
  // Check that `element_shape` and `tensor_list.element_shape` are
  // compatible and store the merged shape in `element_shape`.
  PartialTensorShape tmp = *element_shape;
  TF_RETURN_IF_ERROR(tmp.MergeWith(tensor_list.element_shape, element_shape));
  return Status::OK();
}

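// Retrieves the TensorList stored in the scalar variant tensor at input
// `index`. Returns an error if the input is not a scalar or does not hold a
// TensorList.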
Status GetInputList(OpKernelContext* c, int index, const TensorList** list) {
  if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
    return errors::InvalidArgument("Input list must be a scalar. Saw: ",
                                   c->input(index).shape().DebugString());
  }
  const TensorList* l = c->input(index).scalar<Variant>()().get<TensorList>();
  if (l == nullptr) {
    return errors::InvalidArgument(
        "Input handle is not a list. Saw: '",
        c->input(index).scalar<Variant>()().DebugString(), "'");
  }
  *list = l;
  return Status::OK();
}

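// Tries to forward the variant tensor at `input_index` to `output_index` so
// the underlying TensorList can be mutated in place (only done when its
// reference count is one). Otherwise allocates a new host output tensor
// holding a copy of `input_list`. In both cases `*output_list` points at the
// list backing the output.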
Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
                                   int32 output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list) {
  // Attempt to forward the input tensor to the output if possible.
  std::unique_ptr<Tensor> maybe_output = c->forward_input(
      input_index, output_index, DT_VARIANT, TensorShape{},
      c->input_memory_type(input_index), AllocatorAttributes());
  Tensor* output_tensor;
  if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
      maybe_output->NumElements() == 1) {
    output_tensor = maybe_output.get();
    TensorList* tmp_out = output_tensor->scalar<Variant>()().get<TensorList>();
    if (tmp_out == nullptr) {
      return errors::InvalidArgument(
          "Expected input ", input_index, " to be a TensorList but saw ",
          output_tensor->scalar<Variant>()().TypeName());
    }
    if (tmp_out->RefCountIsOne()) {
      // Woohoo, forwarding succeeded!
      c->set_output(output_index, *output_tensor);
      *output_list = tmp_out;
      return Status::OK();
    }
  }

  // If forwarding is not possible, allocate a new output tensor and copy
  // the `input_list` to it.
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(
      c->allocate_output(output_index, {}, &output_tensor, attr));
  output_tensor->scalar<Variant>()() = input_list.Copy();

  *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
  return Status::OK();
}

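// Creates an empty TensorList with the given element dtype, element shape,
// and maximum number of elements (-1 for unbounded), and stores it in a
// scalar variant output allocated on the host.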
class EmptyTensorList : public OpKernel {
 public:
  explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& max_num_elements_t = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(max_num_elements_t.shape()),
        errors::InvalidArgument(
            "max_num_elements expected to be a scalar ",
            "but got shape: ", max_num_elements_t.shape().DebugString()));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
    TensorList empty;
    empty.element_dtype = element_dtype_;
    empty.max_num_elements = max_num_elements_t.scalar<int32>()();
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape));
    empty.element_shape = element_shape;
    result->scalar<Variant>()() = std::move(empty);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
                        EmptyTensorList);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

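// Appends a tensor to the list, after validating that its dtype and shape are
// compatible with the list and that the list is not already at
// max_num_elements (when set). Forwards the input list when possible to avoid
// a copy.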
class TensorListPushBack : public OpKernel {
 public:
  explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  ~TensorListPushBack() override {}

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));

    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()),
                errors::InvalidArgument(
                    "Tried to append a tensor with incompatible shape to a "
                    "list. Op element shape: ",
                    input.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    if (l->max_num_elements != -1) {
      OP_REQUIRES(
          c, l->tensors().size() < l->max_num_elements,
          errors::InvalidArgument("Tried to push item into a full list",
                                  " list size: ", l->tensors().size(),
                                  " max_num_elements: ", l->max_num_elements));
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors().push_back(input);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU),
                        TensorListPushBack);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU),
                        TensorListPushBack);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

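// Returns the number of tensors currently stored in the list as a scalar
// int32.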
class TensorListLength : public OpKernel {
 public:
  explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {}
  ~TensorListLength() override {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
    result->scalar<int32>()() = l->tensors().size();
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU),
                        TensorListLength);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"),
    TensorListLength);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

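// Returns the list's element shape as an int32 or int64 tensor (depending on
// the requested output dtype): a scalar -1 when the rank is unknown,
// otherwise a vector of dimension sizes.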
class TensorListElementShape : public OpKernel {
 public:
  explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    if (l->element_shape.unknown_rank()) {
      OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &result));
      if (result->dtype() == DT_INT32) {
        result->scalar<int32>()() = -1;
      } else {
        result->scalar<int64>()() = -1;
      }
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(
                            0, TensorShape{l->element_shape.dims()}, &result));
      for (int i = 0; i < l->element_shape.dims(); ++i) {
        if (result->dtype() == DT_INT32) {
          result->flat<int32>()(i) = l->element_shape.dim_size(i);
        } else {
          result->flat<int64>()(i) = l->element_shape.dim_size(i);
        }
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
                        TensorListElementShape);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

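// Creates a TensorList with `num_elements` uninitialized (DT_INVALID) entries
// and the given element shape and dtype.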
class TensorListReserve : public OpKernel {
 public:
  explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
    int32 num_elements = c->input(1).scalar<int32>()();
    TensorList output;
    output.element_shape = element_shape;
    output.element_dtype = element_dtype_;
    output.tensors().resize(num_elements, Tensor(DT_INVALID));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
                        TensorListReserve);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

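// Resizes the input list to `size` elements. When shrinking, trailing
// elements are dropped; when growing, uninitialized (DT_INVALID) entries are
// appended. Mutates the list in place if the input can be forwarded.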
class TensorListResize : public OpKernel {
 public:
  explicit TensorListResize(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* input_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
    int32 size = c->input(1).scalar<int32>()();
    OP_REQUIRES(
        c, size >= 0,
        errors::InvalidArgument(
            "TensorListResize expects size to be non-negative. Got: ", size));

    std::unique_ptr<Tensor> maybe_result =
        c->forward_input(0, 0, DT_VARIANT, TensorShape{},
                         c->input_memory_type(0), AllocatorAttributes());
    if (maybe_result != nullptr) {
      TensorList* out = maybe_result->scalar<Variant>()().get<TensorList>();
      if (out->RefCountIsOne()) {
        // We are able to forward the input.
        out->tensors().resize(size, Tensor(DT_INVALID));
        c->set_output(0, *maybe_result);
        return;
      }
    }

    // We were not able to forward the input. Will have to resize from scratch.
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    TensorList output_list;
    output_list.element_shape = input_list->element_shape;
    output_list.element_dtype = input_list->element_dtype;
    output_list.max_num_elements = input_list->max_num_elements;
    if (size > input_list->tensors().size()) {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().end());
      // Add DT_INVALID tensors to the end of the list if the requested size
      // is larger than the list length.
      output_list.tensors().resize(size, Tensor(DT_INVALID));
    } else {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().begin() + size);
    }
    result->scalar<Variant>()() = std::move(output_list);
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListResize").Device(DEVICE_CPU),
                        TensorListResize);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_GPU).HostMemory("size"),
    TensorListResize);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

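// Overwrites the element at `index` with the given tensor, after validating
// the index and the element's dtype and shape. Forwards the input list when
// possible to avoid a copy.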
class TensorListSetItem : public OpKernel {
 public:
  explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32 index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, index < l->tensors().size(),
                errors::InvalidArgument("Trying to modify element ", index,
                                        " in a list with ",
                                        l->tensors().size(), " elements."));
    const Tensor& value = c->input(2);
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
                errors::InvalidArgument(
                    "Tried to set a tensor with incompatible shape at a "
                    "list index. Item element shape: ",
                    value.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors()[index] = value;
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
                        TensorListSetItem);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_SET_ITEM_GPU(T) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("index"), \
                          TensorListSetItem);

TF_CALL_GPU_ALL_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
REGISTER_TENSOR_LIST_SET_ITEM_GPU(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

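// Element-wise concatenation of two batches of TensorLists: for each index i,
// the tensors of input_b[i] are appended to those of input_a[i]. Requires
// matching batch shapes, matching element dtypes, and identical element
// shapes. The forwarded input is aliased in place only when every list in
// input_a has a reference count of one.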
class TensorListConcatLists : public OpKernel {
 public:
  explicit TensorListConcatLists(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorShape& tl_a_shape = c->input(0).shape();
    const TensorShape& tl_b_shape = c->input(1).shape();
    OP_REQUIRES(
        c, tl_a_shape == tl_b_shape,
        errors::InvalidArgument("Incompatible input TensorList tensor shapes: ",
                                tl_a_shape.DebugString(), " vs. ",
                                tl_b_shape.DebugString()));
    AllocatorAttributes attr;
    std::unique_ptr<Tensor> tl_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tl_a_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    // tl_a may be aliased by tl_alias.
    const Tensor& tl_a = c->input(0);
    const Tensor& tl_b = c->input(1);

    Tensor* output = nullptr;
    bool ok_to_alias = tl_alias != nullptr;
    if (tl_alias && tl_alias->dtype() == DT_VARIANT &&
        tl_alias->NumElements() > 0) {
      auto tl_a_t = tl_alias->flat<Variant>();
      for (int64 i = 0; i < tl_alias->NumElements(); ++i) {
        TensorList* aliased = tl_a_t(i).get<TensorList>();
        if (aliased == nullptr || !aliased->RefCountIsOne()) {
          ok_to_alias = false;
          break;
        }
      }
      if (ok_to_alias) {
        c->set_output(0, *tl_alias);
        output = tl_alias.get();
      }
    }
    if (!ok_to_alias) {
      // Couldn't alias the entire Tensor. We'll be conservative and not try
      // to alias individual batch entries.
      attr.set_on_host(true);
      OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
    }

    auto output_t = output->flat<Variant>();
    auto tl_a_t = tl_a.flat<Variant>();
    auto tl_b_t = tl_b.flat<Variant>();

    for (int64 i = 0; i < tl_a.NumElements(); ++i) {
      const TensorList* l_a = tl_a_t(i).get<TensorList>();
      const TensorList* l_b = tl_b_t(i).get<TensorList>();
      OP_REQUIRES(
          c, l_a != nullptr,
          errors::InvalidArgument("input_a is not a TensorList at index ", i,
                                  ". Saw: '", tl_a_t(i).DebugString(), "'"));
      OP_REQUIRES(
          c, l_b != nullptr,
          errors::InvalidArgument("input_b is not a TensorList at index ", i,
                                  ". Saw: '", tl_b_t(i).DebugString(), "'"));
      OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_a[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_a->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_b[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_b->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
                  errors::InvalidArgument(
                      "input_a and input_b TensorList element shapes are not "
                      "identical at index ",
                      i, ". Saw ", l_a->element_shape.DebugString(), " vs. ",
                      l_b->element_shape.DebugString()));
      if (ok_to_alias) {
        TensorList* out = output_t(i).get<TensorList>();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out->tensors()));
      } else {
        TensorList out = l_a->Copy();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out.tensors()));
        output_t(i) = std::move(out);
      }
    }
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_CPU),
                        TensorListConcatLists);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_GPU),
                        TensorListConcatLists);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

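// Registers the CPU kernels for the templated TensorList ops declared in
// list_kernels.h (stack, gather, concat, get/pop, from-tensor, scatter,
// split, push-back-batch) for each supported element type.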
#define REGISTER_TENSOR_LIST_OPS_CPU(T) \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListStack<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListGather<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListConcat<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListConcat<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListGetItem<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListPopBack<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListFromTensor<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListScatter<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListScatter<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListSplit<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch") \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_CPU), \
                          TensorListPushBackBatch<CPUDevice, T>)

TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_OPS_CPU);
REGISTER_TENSOR_LIST_OPS_CPU(quint8);
REGISTER_TENSOR_LIST_OPS_CPU(qint8);
REGISTER_TENSOR_LIST_OPS_CPU(quint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint32);
REGISTER_TENSOR_LIST_OPS_CPU(Variant);

#undef REGISTER_TENSOR_LIST_OPS_CPU

#define REGISTER_TENSOR_LIST_OPS_CPU(T)

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          TensorList,
                                          TensorListBinaryAdd<CPUDevice>);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, TensorList,
                                         TensorListZerosLike<CPUDevice>);

}  // namespace tensorflow