/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/list_kernels.h"

#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

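// Converts a shape tensor `t` into a PartialTensorShape. `t` must be either a
// rank-1 tensor of int32/int64 dimension sizes or a scalar -1, which denotes
// a fully unknown shape.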
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
  if (t.shape() == TensorShape({})) {
    if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
        (t.dtype() == DT_INT64 && t.scalar<int64_t>()() == -1)) {
      *out = PartialTensorShape();
      return OkStatus();
    }
    return errors::InvalidArgument(
        "The only valid scalar shape tensor is the fully unknown shape "
        "specified as -1.");
  } else if (t.shape().dims() != 1) {
    return errors::InvalidArgument("Shape must be at most rank 1 but is rank ",
                                   t.shape().dims());
  }
  if (t.dtype() == DT_INT32) {
    return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
                                                t.NumElements(), out);
  } else if (t.dtype() == DT_INT64) {
    return PartialTensorShape::MakePartialShape(t.vec<int64_t>().data(),
                                                t.NumElements(), out);
  }
  return errors::InvalidArgument(
      "Expected an int32 or int64 shape tensor; found ",
      DataTypeString(t.dtype()));
}

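// Reads the element shape from input `index` of `c` and merges it with
// `tensor_list.element_shape`, storing the merged result in `element_shape`.
// Returns an error if the two shapes are incompatible.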
Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape) {
  TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(index), element_shape));
  // Check that `element_shape` and `tensor_list.element_shape` are
  // compatible and store the merged shape in `element_shape`.
  PartialTensorShape tmp = *element_shape;
  TF_RETURN_IF_ERROR(tmp.MergeWith(tensor_list.element_shape, element_shape));
  return OkStatus();
}

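// Extracts the TensorList stored in the scalar variant tensor at input
// `index` of `c`. Returns an error if that input is not a scalar or does not
// hold a TensorList.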
Status GetInputList(OpKernelContext* c, int index, const TensorList** list) {
  if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
    return errors::InvalidArgument("Input list must be a scalar. Saw: ",
                                   c->input(index).shape().DebugString());
  }
  const TensorList* l = c->input(index).scalar<Variant>()().get<TensorList>();
  if (l == nullptr) {
    return errors::InvalidArgument(
        "Input handle is not a list. Saw: '",
        c->input(index).scalar<Variant>()().DebugString(), "'");
  }
  *list = l;
  return OkStatus();
}

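// Either forwards the variant tensor at `input_index` to output
// `output_index` (when the underlying TensorList is uniquely owned) or
// allocates a new output variant holding a copy of `input_list`. On success,
// `*output_list` points at the mutable list backing the output.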
Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index,
                                   int32_t output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list) {
  // Attempt to forward the input tensor to the output if possible.
  std::unique_ptr<Tensor> maybe_output = c->forward_input(
      input_index, output_index, DT_VARIANT, TensorShape{},
      c->input_memory_type(input_index), AllocatorAttributes());
  Tensor* output_tensor;
  if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
      maybe_output->NumElements() == 1) {
    output_tensor = maybe_output.get();
    TensorList* tmp_out = output_tensor->scalar<Variant>()().get<TensorList>();
    if (tmp_out == nullptr) {
      return errors::InvalidArgument(
          "Expected input ", input_index, " to be a TensorList but saw ",
          output_tensor->scalar<Variant>()().TypeName());
    }
    if (tmp_out->RefCountIsOne()) {
      // Woohoo, forwarding succeeded!
      c->set_output(output_index, *output_tensor);
      *output_list = tmp_out;
      return OkStatus();
    }
  }

  // If forwarding is not possible, allocate a new output tensor and copy
  // the `input_list` to it.
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(
      c->allocate_output(output_index, {}, &output_tensor, attr));
  output_tensor->scalar<Variant>()() = input_list.Copy();

  *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
  return OkStatus();
}

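// Creates an empty TensorList with the configured element dtype, the element
// shape given by input 0, and the max_num_elements given by input 1, and
// emits it as a scalar DT_VARIANT output.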
class EmptyTensorList : public OpKernel {
 public:
  explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& max_num_elements_t = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(max_num_elements_t.shape()),
        errors::InvalidArgument(
            "max_num_elements expected to be a scalar ",
            "but got shape: ", max_num_elements_t.shape().DebugString()));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
    TensorList empty;
    empty.element_dtype = element_dtype_;
    empty.max_num_elements = max_num_elements_t.scalar<int32>()();
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape));
    empty.element_shape = element_shape;
    result->scalar<Variant>()() = std::move(empty);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
                        EmptyTensorList);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

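// Appends the tensor at input 1 to the TensorList handle at input 0 and
// produces the resulting list handle. Fails if the element dtype or shape is
// incompatible with the list, or if the list is already at max_num_elements.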
class TensorListPushBack : public OpKernel {
 public:
  explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  ~TensorListPushBack() override {}

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));

    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()),
                errors::InvalidArgument(
                    "Tried to append a tensor with incompatible shape to a "
                    "list. Op element shape: ",
                    input.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    if (l->max_num_elements != -1) {
      OP_REQUIRES(
          c, l->tensors().size() < l->max_num_elements,
          errors::InvalidArgument("Tried to push item into a full list",
                                  " list size: ", l->tensors().size(),
                                  " max_num_elements: ", l->max_num_elements));
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors().push_back(input);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU),
                        TensorListPushBack);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU),
                        TensorListPushBack);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_DEFAULT),
                        TensorListPushBack);

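// Returns the number of elements in the TensorList handle at input 0 as a
// scalar int32.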
class TensorListLength : public OpKernel {
 public:
  explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {}
  ~TensorListLength() override {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
    result->scalar<int32>()() = l->tensors().size();
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU),
                        TensorListLength);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"),
    TensorListLength);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_DEFAULT).HostMemory("length"),
    TensorListLength);

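// Emits the element shape of the list as a 1-D int32/int64 tensor, or as a
// scalar -1 when the element shape has unknown rank.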
class TensorListElementShape : public OpKernel {
 public:
  explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    if (l->element_shape.unknown_rank()) {
      OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &result));
      if (result->dtype() == DT_INT32) {
        result->scalar<int32>()() = -1;
      } else {
        result->scalar<int64_t>()() = -1;
      }
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(
                            0, TensorShape{l->element_shape.dims()}, &result));
      for (int i = 0; i < l->element_shape.dims(); ++i) {
        if (result->dtype() == DT_INT32) {
          result->flat<int32>()(i) = l->element_shape.dim_size(i);
        } else {
          result->flat<int64_t>()(i) = l->element_shape.dim_size(i);
        }
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
                        TensorListElementShape);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

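// Creates a TensorList with num_elements (input 1) uninitialized (DT_INVALID)
// entries and the element shape given by input 0, to be filled in later, for
// example with TensorListSetItem.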
class TensorListReserve : public OpKernel {
 public:
  explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
    OP_REQUIRES(
        c, TensorShapeUtils::IsScalar(c->input(1).shape()),
        errors::InvalidArgument(
            "The num_elements to reserve must be a scalar, but got shape: ",
            c->input(1).shape()));
    int32_t num_elements = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, num_elements >= 0,
                errors::InvalidArgument("The num_elements to reserve must be a "
                                        "non-negative number, but got ",
                                        num_elements));
    TensorList output;
    output.element_shape = element_shape;
    output.element_dtype = element_dtype_;
    output.tensors().resize(num_elements, Tensor(DT_INVALID));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
                        TensorListReserve);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

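// Resizes the list at input 0 to `size` (input 1) elements, truncating it or
// padding it with uninitialized (DT_INVALID) entries as needed.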
class TensorListResize : public OpKernel {
 public:
  explicit TensorListResize(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* input_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
    int32_t size = c->input(1).scalar<int32>()();
    OP_REQUIRES(
        c, size >= 0,
        errors::InvalidArgument(
            "TensorListResize expects size to be non-negative. Got: ", size));

    std::unique_ptr<Tensor> maybe_result =
        c->forward_input(0, 0, DT_VARIANT, TensorShape{},
                         c->input_memory_type(0), AllocatorAttributes());
    if (maybe_result != nullptr) {
      TensorList* out = maybe_result->scalar<Variant>()().get<TensorList>();
      if (out->RefCountIsOne()) {
        // We are able to forward the input.
        out->tensors().resize(size, Tensor(DT_INVALID));
        c->set_output(0, *maybe_result);
        return;
      }
    }

    // We were not able to forward the input. Will have to resize from scratch.
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    TensorList output_list;
    output_list.element_shape = input_list->element_shape;
    output_list.element_dtype = input_list->element_dtype;
    output_list.max_num_elements = input_list->max_num_elements;
    if (size > input_list->tensors().size()) {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().end());
      // Add DT_INVALID tensors to the end of the list if the requested size
      // is larger than the list length.
      output_list.tensors().resize(size, Tensor(DT_INVALID));
    } else {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().begin() + size);
    }
    result->scalar<Variant>()() = std::move(output_list);
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListResize").Device(DEVICE_CPU),
                        TensorListResize);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_GPU).HostMemory("size"),
    TensorListResize);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_DEFAULT).HostMemory("size"),
    TensorListResize);

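// Overwrites the element at `index` (input 1) of the list at input 0 with the
// tensor at input 2, forwarding the input list handle when possible.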
class TensorListSetItem : public OpKernel {
 public:
  explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32_t index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, 0 <= index && index < l->tensors().size(),
                errors::InvalidArgument("Trying to modify element ", index,
                                        " in a list with ", l->tensors().size(),
                                        " elements."));
    const Tensor& value = c->input(2);
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
                errors::InvalidArgument(
                    "Tried to set a tensor with incompatible shape at a "
                    "list index. Item element shape: ",
                    value.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors()[index] = value;
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
                        TensorListSetItem);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_SET_ITEM_GPU(T)                      \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem")               \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_GPU)                 \
                              .HostMemory("index"),               \
                          TensorListSetItem);

TF_CALL_GPU_ALL_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
REGISTER_TENSOR_LIST_SET_ITEM_GPU(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT(T)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem")               \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_DEFAULT)             \
                              .HostMemory("index"),               \
                          TensorListSetItem);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT);
REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_DEFAULT

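// Batch-wise concatenation of two variant tensors of TensorLists: for every
// index i, the elements of input_b[i] are appended to input_a[i]. Both inputs
// must have the same shape, and the lists at each index must have matching
// element dtypes and identical element shapes.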
class TensorListConcatLists : public OpKernel {
 public:
  explicit TensorListConcatLists(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorShape& tl_a_shape = c->input(0).shape();
    const TensorShape& tl_b_shape = c->input(1).shape();
    OP_REQUIRES(
        c, tl_a_shape == tl_b_shape,
        errors::InvalidArgument("Incompatible input TensorList tensor shapes: ",
                                tl_a_shape.DebugString(), " vs. ",
                                tl_b_shape.DebugString()));
    AllocatorAttributes attr;
    std::unique_ptr<Tensor> tl_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tl_a_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    // tl_a may be aliased by tl_alias.
    const Tensor& tl_a = c->input(0);
    const Tensor& tl_b = c->input(1);

    Tensor* output = nullptr;
    bool ok_to_alias = tl_alias != nullptr;
    if (tl_alias && tl_alias->dtype() == DT_VARIANT &&
        tl_alias->NumElements() > 0) {
      auto tl_a_t = tl_alias->flat<Variant>();
      for (int64_t i = 0; i < tl_alias->NumElements(); ++i) {
        TensorList* aliased = tl_a_t(i).get<TensorList>();
        if (aliased == nullptr || !aliased->RefCountIsOne()) {
          ok_to_alias = false;
          break;
        }
      }
      if (ok_to_alias) {
        c->set_output(0, *tl_alias);
        output = tl_alias.get();
      }
    }
    if (!ok_to_alias) {
      // Couldn't alias the entire Tensor. We'll be conservative and not try
      // to alias individual batch entries.
      attr.set_on_host(true);
      OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
    }

    auto output_t = output->flat<Variant>();
    auto tl_a_t = tl_a.flat<Variant>();
    auto tl_b_t = tl_b.flat<Variant>();

    for (int64_t i = 0; i < tl_a.NumElements(); ++i) {
      const TensorList* l_a = tl_a_t(i).get<TensorList>();
      const TensorList* l_b = tl_b_t(i).get<TensorList>();
      OP_REQUIRES(
          c, l_a != nullptr,
          errors::InvalidArgument("input_a is not a TensorList at index ", i,
                                  ". Saw: '", tl_a_t(i).DebugString(), "'"));
      OP_REQUIRES(
          c, l_b != nullptr,
          errors::InvalidArgument("input_b is not a TensorList at index ", i,
                                  ". Saw: '", tl_b_t(i).DebugString(), "'"));
      OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_a[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_a->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_b[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_b->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
                  errors::InvalidArgument(
                      "input_a and input_b TensorList element shapes are not "
                      "identical at index ",
                      i, ". Saw ", l_a->element_shape.DebugString(), " vs. ",
                      l_b->element_shape.DebugString()));
      if (ok_to_alias) {
        TensorList* out = output_t(i).get<TensorList>();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out->tensors()));
      } else {
        TensorList out = l_a->Copy();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out.tensors()));
        output_t(i) = std::move(out);
      }
    }
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_CPU),
                        TensorListConcatLists);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_GPU),
                        TensorListConcatLists);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_DEFAULT),
                        TensorListConcatLists);

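// Registers the per-element-dtype TensorList kernels (Stack, Gather,
// Concat/ConcatV2, GetItem, PopBack, FromTensor, Scatter/ScatterV2,
// ScatterIntoExistingList, Split, PushBackBatch) on the CPU device.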
#define REGISTER_TENSOR_LIST_OPS_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListStack<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListGather<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2")                       \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListGetItem<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListPopBack<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")                     \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListFromTensor<CPUDevice, T>)              \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2")                      \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList")        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListSplit<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")                  \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListPushBackBatch<CPUDevice, T>)

TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_OPS_CPU);
REGISTER_TENSOR_LIST_OPS_CPU(quint8);
REGISTER_TENSOR_LIST_OPS_CPU(qint8);
REGISTER_TENSOR_LIST_OPS_CPU(quint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint32);
REGISTER_TENSOR_LIST_OPS_CPU(Variant);

#undef REGISTER_TENSOR_LIST_OPS_CPU

#define REGISTER_TENSOR_LIST_OPS_CPU(T)

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          TensorList,
                                          TensorListBinaryAdd<CPUDevice>);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, TensorList,
                                         TensorListZerosLike<CPUDevice>);

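// Copies a TensorList between devices by copying its metadata and then each
// initialized element tensor via the provided per-tensor copy function.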
static Status TensorListDeviceCopy(
    const TensorList& from, TensorList* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  to->element_shape = from.element_shape;
  to->element_dtype = from.element_dtype;
  to->max_num_elements = from.max_num_elements;
  to->tensors().reserve(from.tensors().size());
  for (const Tensor& t : from.tensors()) {
    to->tensors().emplace_back(t.dtype());
    if (t.dtype() != DT_INVALID) {
      TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
    }
  }
  return OkStatus();
}

#define REGISTER_LIST_COPY(DIRECTION)                                         \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
                                                       TensorListDeviceCopy)

REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);

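// Registers the same per-element-dtype TensorList kernels on DEVICE_DEFAULT,
// pinning shape-, index-, and length-like inputs to host memory.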
#define REGISTER_TENSOR_LIST_OPS_DEFAULT(T)                                \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("element_shape")                 \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListStack<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("indices")                       \
                              .HostMemory("element_shape")                 \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListGather<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("lengths")                       \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2")                       \
                              .TypeConstraint<T>("element_dtype")          \
                              .HostMemory("leading_dims")                  \
                              .HostMemory("element_shape")                 \
                              .HostMemory("lengths")                       \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("index")                         \
                              .HostMemory("element_shape"),                \
                          TensorListGetItem<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape"),                \
                          TensorListPopBack<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")                  \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT),                     \
                          TensorListPushBackBatch<CPUDevice, T>)           \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")                     \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape"),                \
                          TensorListFromTensor<CPUDevice, T>)              \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape")                 \
                              .HostMemory("indices"),                      \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2")                      \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape")                 \
                              .HostMemory("num_elements")                  \
                              .HostMemory("indices"),                      \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList")        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("indices"),                      \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_DEFAULT)                      \
                              .HostMemory("element_shape")                 \
                              .HostMemory("lengths"),                      \
                          TensorListSplit<CPUDevice, T>)

TF_CALL_int32(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_int64(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_bfloat16(REGISTER_TENSOR_LIST_OPS_DEFAULT);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_OPS_DEFAULT);

#undef REGISTER_TENSOR_LIST_OPS_DEFAULT
}  // namespace tensorflow