/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/list_kernels.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// Variant-compatible type for a list of tensors. Instances are mutable, but
// they should never be mutated after being stored in a variant tensor.
TensorList::TensorList(const TensorList& other)
    : tensors(other.tensors),
      element_shape(other.element_shape),
      element_dtype(other.element_dtype),
      max_num_elements(other.max_num_elements) {}

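// Serializes the list into `data`. Tensors backed by real buffers are stored
// directly; empty (DT_INVALID) slots are recorded by index in the metadata
// string so that Decode() can restore them as placeholders.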
void TensorList::Encode(VariantTensorData* data) const {
  data->set_type_name(TypeName());
  std::vector<size_t> invalid_indices;
  for (size_t i = 0; i < tensors.size(); i++) {
    if (tensors.at(i).dtype() != DT_INVALID) {
      *data->add_tensors() = tensors.at(i);
    } else {
      invalid_indices.push_back(i);
    }
  }
  string metadata;
  // TODO(b/118838800): Add a proto for storing the metadata.
  // Metadata format:
  // <num_invalid_tensors><invalid_indices><element_dtype><max_num_elements>
  // <element_shape_proto>
  core::PutVarint64(&metadata, static_cast<uint64>(invalid_indices.size()));
  for (size_t i : invalid_indices) {
    core::PutVarint64(&metadata, static_cast<uint64>(i));
  }
  core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
  core::PutVarint64(&metadata, static_cast<uint64>(max_num_elements));
  TensorShapeProto element_shape_proto;
  element_shape.AsProto(&element_shape_proto);
  element_shape_proto.AppendToString(&metadata);
  data->set_metadata(metadata);
}

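// Copies a TensorList across devices element by element using the provided
// asynchronous copy functor. DT_INVALID placeholder entries are recreated on
// the destination without issuing a copy.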
static Status TensorListDeviceCopy(
    const TensorList& from, TensorList* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  to->element_shape = from.element_shape;
  to->element_dtype = from.element_dtype;
  to->max_num_elements = from.max_num_elements;
  to->tensors.reserve(from.tensors.size());
  for (const Tensor& t : from.tensors) {
    to->tensors.emplace_back(t.dtype());
    if (t.dtype() != DT_INVALID) {
      TF_RETURN_IF_ERROR(copy(t, &to->tensors.back()));
    }
  }
  return Status::OK();
}

#define REGISTER_LIST_COPY(DIRECTION)                                         \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
                                                       TensorListDeviceCopy)

REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);

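// Rebuilds the list from `data`, re-inserting DT_INVALID placeholders at the
// indices recorded in the metadata. Returns false if the serialized data is
// corrupted.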
bool TensorList::Decode(const VariantTensorData& data) {
  // TODO(srbs): Change the signature to Decode(VariantTensorData data) so
  // that we do not have to copy each tensor individually below. This would
  // require changing VariantTensorData::tensors() as well.
  string metadata;
  data.get_metadata(&metadata);
  uint64 scratch;
  StringPiece iter(metadata);
  std::vector<size_t> invalid_indices;
  core::GetVarint64(&iter, &scratch);
  size_t num_invalid_tensors = static_cast<size_t>(scratch);
  invalid_indices.resize(num_invalid_tensors);
  for (size_t i = 0; i < num_invalid_tensors; i++) {
    core::GetVarint64(&iter, &scratch);
    invalid_indices[i] = static_cast<size_t>(scratch);
  }

  size_t total_num_tensors = data.tensors().size() + num_invalid_tensors;
  tensors.reserve(total_num_tensors);
  std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin();
  std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
  for (size_t i = 0; i < total_num_tensors; i++) {
    if (invalid_indices_it != invalid_indices.end() &&
        *invalid_indices_it == i) {
      tensors.emplace_back(Tensor(DT_INVALID));
      invalid_indices_it++;
    } else if (tensors_it != data.tensors().end()) {
      tensors.emplace_back(*tensors_it);
      tensors_it++;
    } else {
      // VariantTensorData is corrupted.
      return false;
    }
  }

  core::GetVarint64(&iter, &scratch);
  element_dtype = static_cast<DataType>(scratch);
  core::GetVarint64(&iter, &scratch);
  max_num_elements = static_cast<int>(scratch);
  TensorShapeProto element_shape_proto;
  element_shape_proto.ParseFromString(string(iter.data(), iter.size()));
  element_shape = PartialTensorShape(element_shape_proto);
  return true;
}

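// Converts a shape tensor `t` into a PartialTensorShape. A scalar -1 denotes
// a fully unknown shape; otherwise `t` must be an int32 or int64 vector of
// dimension sizes (with -1 for unknown dimensions).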
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
  if (t.shape() == TensorShape({})) {
    if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
        (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1)) {
      *out = PartialTensorShape();
      return Status::OK();
    }
    return errors::InvalidArgument(
        "The only valid scalar shape tensor is the fully unknown shape "
        "specified as -1.");
  }
  if (t.dtype() == DT_INT32) {
    return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
                                                t.NumElements(), out);
  } else if (t.dtype() == DT_INT64) {
    return PartialTensorShape::MakePartialShape(t.vec<int64>().data(),
                                                t.NumElements(), out);
  }
  return errors::InvalidArgument(
      "Expected an int32 or int64 shape tensor; found ",
      DataTypeString(t.dtype()));
}

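// Reads the element shape from input `index` and merges it with the element
// shape already stored in `tensor_list`.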
Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape) {
  TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(index), element_shape));
  // Check that `element_shape` and `tensor_list.element_shape` are
  // compatible and store the merged shape in `element_shape`.
  PartialTensorShape tmp = *element_shape;
  TF_RETURN_IF_ERROR(tmp.MergeWith(tensor_list.element_shape, element_shape));
  return Status::OK();
}

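// Returns the TensorList stored in the scalar variant tensor at input
// `index`, or an InvalidArgument error if that input does not hold a list.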
Status GetInputList(OpKernelContext* c, int index, const TensorList** list) {
  if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
    return errors::InvalidArgument("Input list must be a scalar. Saw: ",
                                   c->input(index).shape().DebugString());
  }
  const TensorList* l = c->input(index).scalar<Variant>()().get<TensorList>();
  if (l == nullptr) {
    return errors::InvalidArgument(
        "Input handle is not a list. Saw: '",
        c->input(index).scalar<Variant>()().DebugString(), "'");
  }
  *list = l;
  return Status::OK();
}

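// Produces the TensorList for output `output_index`, either by forwarding the
// variant tensor at `input_index` or, if forwarding fails, by allocating a
// new output tensor and copying `input_list` into it.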
Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
                                   int32 output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list) {
  // Attempt to forward the input tensor to the output if possible.
  AllocatorAttributes attr;
  attr.set_on_host(true);
  std::unique_ptr<Tensor> maybe_output =
      c->forward_input(input_index, output_index, DT_VARIANT, TensorShape{},
                       c->input_memory_type(input_index), attr);
  Tensor* output_tensor;
  if (maybe_output != nullptr) {
    // Woohoo, forwarding succeeded!
    output_tensor = maybe_output.get();
  } else {
    // If forwarding is not possible, allocate a new output tensor and copy
    // the `input_list` to it.
    TF_RETURN_IF_ERROR(
        c->allocate_output(output_index, {}, &output_tensor, attr));
    output_tensor->scalar<Variant>()() = input_list;
  }
  *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
  return Status::OK();
}

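// Creates an empty TensorList with the requested element dtype, element shape
// and (optional) max_num_elements bound.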
class EmptyTensorList : public OpKernel {
 public:
  explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& max_num_elements_t = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(max_num_elements_t.shape()),
        errors::InvalidArgument(
            "max_num_elements expected to be a scalar ",
            "but got shape: ", max_num_elements_t.shape().DebugString()));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
    TensorList empty;
    empty.element_dtype = element_dtype_;
    empty.max_num_elements = max_num_elements_t.scalar<int32>()();
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape));
    empty.element_shape = element_shape;
    result->scalar<Variant>()() = std::move(empty);
  }

 private:
  DataType element_dtype_;
};

const char TensorList::kTypeName[] = "tensorflow::TensorList";

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
                        EmptyTensorList);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

#endif  // GOOGLE_CUDA

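// Appends a tensor to the list after validating its dtype and shape against
// the list and, if max_num_elements is set, the remaining capacity.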
class TensorListPushBack : public OpKernel {
 public:
  explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  ~TensorListPushBack() override {}

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));

    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()),
                errors::InvalidArgument(
                    "Tried to append a tensor with incompatible shape to a "
                    "list. Op element shape: ",
                    input.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    if (l->max_num_elements != -1) {
      OP_REQUIRES(
          c, l->tensors.size() < l->max_num_elements,
          errors::InvalidArgument("Tried to push item into a full list;",
                                  " list size: ", l->tensors.size(),
                                  " max_num_elements: ", l->max_num_elements));
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors.push_back(input);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU),
                        TensorListPushBack);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU),
                        TensorListPushBack);

#endif  // GOOGLE_CUDA

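// Returns the number of elements in the list as a scalar int32.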
class TensorListLength : public OpKernel {
 public:
  explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {}
  ~TensorListLength() override {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
    result->scalar<int32>()() = l->tensors.size();
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU),
                        TensorListLength);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"),
    TensorListLength);

#endif  // GOOGLE_CUDA

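// Emits the list's element shape as a shape tensor; a scalar -1 is produced
// when the rank is unknown.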
class TensorListElementShape : public OpKernel {
 public:
  explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    if (l->element_shape.unknown_rank()) {
      OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &result));
      if (result->dtype() == DT_INT32) {
        result->scalar<int32>()() = -1;
      } else {
        result->scalar<int64>()() = -1;
      }
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(
                            0, TensorShape{l->element_shape.dims()}, &result));
      for (int i = 0; i < l->element_shape.dims(); ++i) {
        if (result->dtype() == DT_INT32) {
          result->flat<int32>()(i) = l->element_shape.dim_size(i);
        } else {
          result->flat<int64>()(i) = l->element_shape.dim_size(i);
        }
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
                        TensorListElementShape);

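// Creates a list of `num_elements` DT_INVALID placeholders with the given
// element shape and dtype.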
class TensorListReserve : public OpKernel {
 public:
  explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
    int32 num_elements = c->input(1).scalar<int32>()();
    TensorList output;
    output.element_shape = element_shape;
    output.element_dtype = element_dtype_;
    output.tensors.resize(num_elements, Tensor(DT_INVALID));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
                        TensorListReserve);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

#endif  // GOOGLE_CUDA
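
// Resizes the list to `size` elements, truncating it or padding it with
// DT_INVALID placeholders as needed.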
class TensorListResize : public OpKernel {
 public:
  explicit TensorListResize(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* input_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
    int32 size = c->input(1).scalar<int32>()();
    OP_REQUIRES(
        c, size >= 0,
        errors::InvalidArgument(
            "TensorListResize expects size to be non-negative. Got: ", size));

    AllocatorAttributes attr;
    attr.set_on_host(true);
    std::unique_ptr<Tensor> maybe_result = c->forward_input(
        0, 0, DT_VARIANT, TensorShape{}, c->input_memory_type(0), attr);
    if (maybe_result != nullptr) {
      maybe_result->scalar<Variant>()().get<TensorList>()->tensors.resize(
          size, Tensor(DT_INVALID));
    } else {
      Tensor* result;
      OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
      TensorList output_list;
      output_list.element_shape = input_list->element_shape;
      output_list.element_dtype = input_list->element_dtype;
      output_list.max_num_elements = input_list->max_num_elements;
      if (size > input_list->tensors.size()) {
        output_list.tensors.insert(output_list.tensors.begin(),
                                   input_list->tensors.begin(),
                                   input_list->tensors.end());
        // Pad the end of the list with DT_INVALID tensors when the requested
        // size is larger than the current list length.
        output_list.tensors.resize(size, Tensor(DT_INVALID));
      } else {
        output_list.tensors.insert(output_list.tensors.begin(),
                                   input_list->tensors.begin(),
                                   input_list->tensors.begin() + size);
      }
      result->scalar<Variant>()() = std::move(output_list);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListResize").Device(DEVICE_CPU),
                        TensorListResize);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_GPU).HostMemory("size"),
    TensorListResize);

#endif  // GOOGLE_CUDA

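// Overwrites the element at `index` with the given tensor after validating
// its dtype and shape against the list.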
class TensorListSetItem : public OpKernel {
 public:
  explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32 index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, index < l->tensors.size(),
                errors::InvalidArgument("Trying to modify element ", index,
                                        " in a list with ", l->tensors.size(),
                                        " elements."));
    const Tensor& value = c->input(2);
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
                errors::InvalidArgument(
                    "Tried to set a tensor with incompatible shape at a "
                    "list index. Item element shape: ",
                    value.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors[index] = value;
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
                        TensorListSetItem);

#if GOOGLE_CUDA

#define REGISTER_TENSOR_LIST_SET_ITEM_GPU(T)                      \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem")               \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_GPU)                 \
                              .HostMemory("index"),               \
                          TensorListSetItem);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_complex64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_complex128(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
REGISTER_TENSOR_LIST_SET_ITEM_GPU(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_GPU

#endif  // GOOGLE_CUDA

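// Concatenates two same-shaped tensors of TensorLists element-wise: for each
// index b, the tensors of input_b[b] are appended to those of input_a[b].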
class TensorListConcatLists : public OpKernel {
 public:
  explicit TensorListConcatLists(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorShape& tl_a_shape = c->input(0).shape();
    const TensorShape& tl_b_shape = c->input(1).shape();
    OP_REQUIRES(
        c, tl_a_shape == tl_b_shape,
        errors::InvalidArgument("Incompatible input TensorList tensor shapes: ",
                                tl_a_shape.DebugString(), " vs. ",
                                tl_b_shape.DebugString()));
    AllocatorAttributes attr;
    std::unique_ptr<Tensor> tl_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tl_a_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    // tl_a may be aliased by tl_alias.
    const Tensor& tl_a = c->input(0);
    const Tensor& tl_b = c->input(1);

    Tensor* output;
    if (tl_alias) {
      c->set_output(0, *tl_alias);
      output = tl_alias.get();
    } else {
      attr.set_on_host(true);
      OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
    }

    auto output_t = output->flat<Variant>();
    auto tl_a_t = tl_a.flat<Variant>();
    auto tl_b_t = tl_b.flat<Variant>();

    for (int64 b = 0; b < tl_a.NumElements(); ++b) {
      const TensorList* l_a = tl_a_t(b).get<TensorList>();
      const TensorList* l_b = tl_b_t(b).get<TensorList>();
      OP_REQUIRES(
          c, l_a != nullptr,
          errors::InvalidArgument("input_a is not a TensorList at index ", b,
                                  ". Saw: '", tl_a_t(b).DebugString(), "'"));
      OP_REQUIRES(
          c, l_b != nullptr,
          errors::InvalidArgument("input_b is not a TensorList at index ", b,
                                  ". Saw: '", tl_b_t(b).DebugString(), "'"));
      OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_a[", b, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_a->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_b[", b, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_b->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
                  errors::InvalidArgument(
                      "input_a and input_b TensorList element shapes are not "
                      "identical at index ",
                      b, ". Saw ", l_a->element_shape.DebugString(), " vs. ",
                      l_b->element_shape.DebugString()));
      if (tl_alias) {
        TensorList* out = output_t(b).get<TensorList>();
        DCHECK(out != nullptr) << "Expected output to alias input_a, but it "
                                  "doesn't contain a TensorList at index "
                               << b;
        std::copy(l_b->tensors.begin(), l_b->tensors.end(),
                  std::back_inserter(out->tensors));
      } else {
        TensorList out = *l_a;
        std::copy(l_b->tensors.begin(), l_b->tensors.end(),
                  std::back_inserter(out.tensors));
        output_t(b) = std::move(out);
      }
    }
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_CPU),
                        TensorListConcatLists);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_GPU),
                        TensorListConcatLists);

#endif  // GOOGLE_CUDA

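// Registers the templated TensorList kernels (stack, gather, concat, get/pop,
// scatter, split, push-back-batch, ...) for each supported CPU element type.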
#define REGISTER_TENSOR_LIST_OPS_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListStack<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListGather<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2")                       \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListGetItem<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListPopBack<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")                     \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListFromTensor<CPUDevice, T>)              \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2")                      \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList")        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListSplit<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")                  \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListPushBackBatch<CPUDevice, T>)

TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_OPS_CPU);
REGISTER_TENSOR_LIST_OPS_CPU(quint8);
REGISTER_TENSOR_LIST_OPS_CPU(qint8);
REGISTER_TENSOR_LIST_OPS_CPU(quint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint32);
REGISTER_TENSOR_LIST_OPS_CPU(Variant);

#undef REGISTER_TENSOR_LIST_OPS_CPU

#define REGISTER_TENSOR_LIST_OPS_CPU(T)

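// Register element-wise add and zeros-like handlers for TensorList variants
// on the CPU device.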
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          TensorList,
                                          TensorListBinaryAdd<CPUDevice>);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, TensorList,
                                         TensorListZerosLike<CPUDevice>);

}  // namespace tensorflow