/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/tensor_list.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/util/tensor_ops_util.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// Builds a PartialTensorShape from the value stored in tensor `t`.
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);

// Computes the element shape for a list op from the shape input at `index`
// and the element shape recorded in `tensor_list`.
Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape);

// Retrieves the TensorList stored in input `index`.
Status GetInputList(OpKernelContext* c, int index, const TensorList** list);

// Forwards the list at `input_index` to output `output_index` when it can be
// mutated in place, otherwise creates a copy; `*output_list` points at the
// mutable result.
Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
                                   int32 output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list);

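// Stacks all elements of a TensorList into a single tensor with one extra
// leading dimension equal to the list length; e.g. a list of N elements of
// shape [d0, d1] produces an output of shape [N, d0, d1]. Uninitialized
// elements are emitted as zeros of the (fully defined) element shape.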
template <typename Device, typename T>
class TensorListStack : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;
  explicit TensorListStack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
    OP_REQUIRES_OK(c, c->GetAttr("num_elements", &num_elements_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    if (num_elements_ != -1) {
      OP_REQUIRES(c, tensor_list->tensors().size() == num_elements_,
                  errors::InvalidArgument(
                      "Operation expected a list with ", num_elements_,
                      " elements but got a list with ",
                      tensor_list->tensors().size(), " elements."));
    }
    PartialTensorShape partial_element_shape;
    OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 1,
                                               &partial_element_shape));
    OP_REQUIRES(
        c,
        partial_element_shape.IsFullyDefined() ||
            !tensor_list->tensors().empty(),
        errors::InvalidArgument("Tried to stack elements of an empty ",
                                "list with non-fully-defined element_shape: ",
                                partial_element_shape.DebugString()));

    // Check that `element_shape` input tensor is compatible with the shapes of
    // element tensors.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      for (int i = 0; i < tensor_list->tensors().size(); ++i) {
        const Tensor& t = tensor_list->tensors()[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = partial_element_shape;
          OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
        }
      }
    }

    // Compute the shape of the output tensor by pre-pending the leading dim to
    // the element_shape.
    TensorShape element_shape;
    OP_REQUIRES(c, partial_element_shape.AsTensorShape(&element_shape),
                errors::InvalidArgument(
                    "Tried to stack list which only contains uninitialized ",
                    "tensors and has a non-fully-defined element_shape: ",
                    partial_element_shape.DebugString()));
    TensorShape output_shape = element_shape;
    output_shape.InsertDim(0, tensor_list->tensors().size());
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(tensor_list->tensors().size());
    Tensor zeros;
    for (const auto& t : tensor_list->tensors()) {
      if (t.dtype() != DT_INVALID) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            t.shaped<T, 2>({1, t.NumElements()})));
      } else {
        if (!zeros.NumElements()) {
          AllocatorAttributes attr;
          if (element_dtype_ == DT_VARIANT) {
            attr.set_on_host(true);
          }
          OP_REQUIRES_OK(
              c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
          functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
                                               zeros.flat<T>());
        }
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
  }

 private:
  int num_elements_;
  DataType element_dtype_;
};

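// Returns the element stored at a given index of a TensorList. If that
// element is uninitialized, a zero tensor of the element shape is returned
// instead; the shape is taken from the element_shape input, the list's
// element_shape, or inferred from the other initialized elements.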
template <typename Device, typename T>
class TensorListGetItem : public OpKernel {
 public:
  explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32 index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, index < l->tensors().size(),
                errors::InvalidArgument("Trying to access element ", index,
                                        " in a list with ", l->tensors().size(),
                                        " elements."));
    if (l->tensors()[index].dtype() != DT_INVALID) {
      c->set_output(0, l->tensors()[index]);
    } else {
      PartialTensorShape partial_element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *l, 2, &partial_element_shape));
      TensorShape element_shape;
      // If l->element_shape and the element_shape input are both not fully
      // defined, try to infer the shape from other list elements. This
      // requires that all initialized list elements have the same shape.
      // NOTE(srbs): This might be a performance bottleneck since we are
      // iterating over the entire list here. This is necessary for feature
      // parity with TensorArray.read. TensorArray has a mode in which all
      // elements are required to be of the same shape, TensorList does not.
      // In that mode TensorArray sets the array's element_shape on the first
      // write call. We could do something similar here if needed.
      if (!partial_element_shape.IsFullyDefined()) {
        for (const Tensor& t : l->tensors()) {
          if (t.dtype() != DT_INVALID) {
            PartialTensorShape tmp = partial_element_shape;
            OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
          }
        }
      }
      OP_REQUIRES(
          c, partial_element_shape.AsTensorShape(&element_shape),
          errors::InvalidArgument("Trying to read an uninitialized tensor but ",
                                  "element_shape is not fully defined: ",
                                  partial_element_shape.DebugString(),
                                  " and no list element is set."));
      Tensor* result;
      AllocatorAttributes attr;
      if (element_dtype_ == DT_VARIANT) {
        attr.set_on_host(true);
      }
      OP_REQUIRES_OK(c, c->allocate_output(0, element_shape, &result, attr));
      functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
                                           result->flat<T>());
    }
  }

 private:
  DataType element_dtype_;
};

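// Removes the last element of a TensorList: output 1 is that element (or
// zeros of the fully defined element shape if it was uninitialized) and
// output 0 is the list with the element popped off.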
template <typename Device, typename T>
class TensorListPopBack : public OpKernel {
 public:
  explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    OP_REQUIRES(c, !l->tensors().empty(),
                errors::InvalidArgument("Trying to pop from an empty list."));

    const Tensor& t = l->tensors().back();
    if (t.dtype() != DT_INVALID) {
      c->set_output(1, t);
    } else {
      PartialTensorShape partial_element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *l, 1, &partial_element_shape));
      TensorShape element_shape;
      OP_REQUIRES(
          c, partial_element_shape.AsTensorShape(&element_shape),
          errors::InvalidArgument("Trying to read an uninitialized tensor but ",
                                  "element_shape is not fully defined: ",
                                  partial_element_shape.DebugString()));
      Tensor* result;
      AllocatorAttributes attr;
      if (element_dtype_ == DT_VARIANT) {
        attr.set_on_host(true);
      }
      OP_REQUIRES_OK(c, c->allocate_output(1, element_shape, &result, attr));
      functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
                                           result->flat<T>());
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors().pop_back();
  }

 private:
  DataType element_dtype_;
};

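// Concatenates all elements of a TensorList along their first dimension and
// additionally outputs the leading dimension of each element so the result
// can later be split back apart; e.g. elements of shapes [2, d] and [3, d]
// concatenate to a [5, d] tensor with a lengths output of [2, 3].
// Uninitialized elements are replaced by zeros whose leading dimension comes
// from the inferred first_dim or, for TensorListConcatV2, the leading_dims
// input.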
template <typename Device, typename T>
class TensorListConcat : public OpKernel {
 public:
  using ConstMatrixVector =
      std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>;
  explicit TensorListConcat(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
    if (c->HasAttr("element_shape")) {
      OP_REQUIRES_OK(c, c->GetAttr("element_shape", &element_shape_));
    }
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape_except_first_dim;
    if (!element_shape_.unknown_rank()) {
      element_shape_except_first_dim = PartialTensorShape(
          gtl::ArraySlice<int64>(element_shape_.dim_sizes()).subspan(1));
    }
    // Check that the input Variant tensor is indeed a TensorList and has the
    // correct element type.
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    // The leading dimension of all list elements if they are all the same.
    // This is used as the leading dim of uninitialized tensors in the list
    // if leading_dims is not provided.
    int64 first_dim = -1;
    if (c->num_inputs() > 1) {
      // TensorListConcatV2
      PartialTensorShape element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *tensor_list, 1, &element_shape));
      OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
                  errors::InvalidArgument(
                      "Concat requires elements to be at least vectors, ",
                      "found scalars instead."));
      // Split `element_shape` into `first_dim` and
      // `element_shape_except_first_dim`.
      first_dim = element_shape.dim_size(0);
      element_shape_except_first_dim = element_shape;
      element_shape_except_first_dim.RemoveDim(0);
    }
    // If the TensorList is empty, element_shape_except_first_dim must be fully
    // defined.
    OP_REQUIRES(c,
                !tensor_list->tensors().empty() ||
                    element_shape_except_first_dim.IsFullyDefined(),
                errors::InvalidArgument(
                    "All except the first dimension must be fully defined ",
                    "when concating an empty tensor list. element_shape: ",
                    element_shape_except_first_dim.DebugString()));
    // 1. Check that `element_shape_except_first_dim` input tensor is
    //    compatible with the shapes of element tensors.
    // 2. Check that the elements have the same shape except the first dim.
    // 3. If `first_dim` is known, check that it is compatible with the leading
    //    dims of all elements.
    // 4. If `first_dim` is unknown (-1), check whether all initialized
    //    elements have the same leading dim and if so set `first_dim` to that
    //    value.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      bool check_dim = (first_dim == -1);
      int64 inferred_first_dim = first_dim;
      for (int i = 0; i < tensor_list->tensors().size(); ++i) {
        const Tensor& t = tensor_list->tensors()[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = element_shape_except_first_dim;
          OP_REQUIRES(
              c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
              errors::InvalidArgument("Concat saw a scalar shape at index ", i,
                                      " but requires at least vectors."));
          TensorShape shape_except_first_dim = TensorShape(
              gtl::ArraySlice<int64>(t.shape().dim_sizes()).subspan(1));
          OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim,
                                          &element_shape_except_first_dim));
          OP_REQUIRES(c, first_dim == -1 || first_dim == t.shape().dim_size(0),
                      errors::InvalidArgument(
                          "First entry of element_shape input does not match ",
                          "the first dim of list element at index: ", i,
                          " Expected: ", first_dim,
                          " Actual: ", t.shape().dim_size(0)));
          if (check_dim) {
            if (inferred_first_dim == -1) {
              inferred_first_dim = t.shape().dim_size(0);
            } else if (inferred_first_dim != t.shape().dim_size(0)) {
              inferred_first_dim = -1;
              check_dim = false;
            }
          }
        }
      }
      first_dim = inferred_first_dim;
    }
    TensorShape output_shape;
    OP_REQUIRES(c, element_shape_except_first_dim.AsTensorShape(&output_shape),
                errors::InvalidArgument(
                    "Trying to concat list with only uninitialized tensors ",
                    "but element_shape_except_first_dim is not fully defined: ",
                    element_shape_except_first_dim.DebugString()));
    // Build the lengths_tensor and leading dim of the output tensor by
    // iterating over all element tensors.
    Tensor* lengths_tensor = nullptr;
    OP_REQUIRES_OK(
        c,
        c->allocate_output(
            1, TensorShape({static_cast<int64>(tensor_list->tensors().size())}),
            &lengths_tensor));
    auto lengths_tensor_vec = lengths_tensor->vec<int64>();
    int64 leading_dim = 0;
    for (size_t i = 0; i < tensor_list->tensors().size(); i++) {
      int64 dim;
      if (tensor_list->tensors()[i].dtype() != DT_INVALID) {
        dim = tensor_list->tensors()[i].shape().dim_size(0);
      } else {
        // If leading_dims is not provided or does not contain an entry for
        // index i use the inferred `first_dim` if set.
        if ((c->num_inputs() <= 2 || i >= c->input(2).NumElements()) &&
            first_dim != -1) {
          dim = first_dim;
        } else {
          OP_REQUIRES(c, c->num_inputs() > 2,
                      errors::InvalidArgument(
                          "Concating lists with uninitialized tensors is not ",
                          "supported in this version of TensorListConcat. ",
                          "Consider updating your GraphDef to run the newer ",
                          "version."));
          OP_REQUIRES(c, i < c->input(2).NumElements(),
                      errors::InvalidArgument(
                          "List contains uninitialized tensor at index ", i,
                          " but leading_dims has only ",
                          c->input(2).NumElements(), " elements."));
          dim = c->input(2).vec<int64>()(i);
        }
      }
      leading_dim += dim;
      lengths_tensor_vec(i) = dim;
    }
    output_shape.InsertDim(0, leading_dim);
    Tensor* output;
    // Allocate the output tensor and fill it up with the concated element
    // tensors.
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(tensor_list->tensors().size());
    // Store the zeros tensors in a vector to prevent them from being GC'ed
    // till concat is complete.
    std::vector<Tensor> zeros_vec;
    for (int i = 0; i < tensor_list->tensors().size(); i++) {
      const Tensor& element_tensor = tensor_list->tensors()[i];
      if (element_tensor.dtype() != DT_INVALID) {
        if (element_tensor.NumElements() > 0) {
          inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
              element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
        }
      } else {
        AllocatorAttributes attr;
        if (element_dtype_ == DT_VARIANT) {
          attr.set_on_host(true);
        }
        TensorShape element_shape = output_shape;
        element_shape.set_dim(0, lengths_tensor_vec(i));
        zeros_vec.emplace_back();
        Tensor& zeros = zeros_vec.back();
        OP_REQUIRES_OK(
            c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
        functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
                                             zeros.flat<T>());
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
  }

 private:
  DataType element_dtype_;
  PartialTensorShape element_shape_;
};

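// Splits a tensor along its first dimension into a TensorList whose elements
// have the leading dimensions given by the `lengths` input; e.g. a [5, d]
// tensor with lengths [2, 3] becomes a list holding a [2, d] and a [3, d]
// tensor. The lengths must be non-negative and sum to the tensor's first
// dimension.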
template <typename Device, typename T>
class TensorListSplit : public OpKernel {
 public:
  TensorListSplit(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
    OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
                errors::InvalidArgument(
                    "TensorListSplit requires element_shape to be at least of ",
                    "rank 1, but saw: ", element_shape.DebugString()));
    TensorList output_list;
    const Tensor& input_tensor = c->input(0);
    output_list.element_dtype = input_tensor.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    TensorShape tensor_shape_without_first_dim(input_tensor.shape());
    tensor_shape_without_first_dim.RemoveDim(0);
    PartialTensorShape element_shape_without_first_dim;
    if (!element_shape.unknown_rank()) {
      element_shape_without_first_dim =
          PartialTensorShape(element_shape.dim_sizes());
      element_shape_without_first_dim.RemoveDim(0);
    }
    OP_REQUIRES(c,
                element_shape_without_first_dim.IsCompatibleWith(
                    tensor_shape_without_first_dim),
                errors::InvalidArgument(
                    "tensor shape ", input_tensor.shape().DebugString(),
                    " is not compatible with element_shape ",
                    element_shape.DebugString()));
    output_list.element_shape = element_shape;
    const Tensor& lengths = c->input(2);
    OP_REQUIRES(c, TensorShapeUtils::IsVector(lengths.shape()),
                errors::InvalidArgument(
                    "Expected lengths to be a vector, received shape: ",
                    lengths.shape().DebugString()));
    output_list.tensors().reserve(lengths.shape().dim_size(0));
    int64 start = 0;
    int64 end = 0;
    for (int i = 0; i < lengths.shape().dim_size(0); ++i) {
      int64 length = lengths.vec<int64>()(i);
      OP_REQUIRES(
          c, length >= 0,
          errors::InvalidArgument("Invalid value in lengths: ", length));
      end = start + length;
      OP_REQUIRES(c, end <= input_tensor.shape().dim_size(0),
                  errors::InvalidArgument("Attempting to slice [", start, ", ",
                                          end, "] from tensor with length ",
                                          input_tensor.shape().dim_size(0)));
      Tensor tmp = input_tensor.Slice(start, end);
      start = end;
      // TODO(apassos) maybe not always align; but weird compiler bugs seem to
      // prevent this.
      Tensor aligned;
      OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
      aligned.flat<T>().device(c->eigen_device<Device>()) =
          tmp.unaligned_flat<T>();
      output_list.tensors().emplace_back(aligned);
    }
    OP_REQUIRES(c, end == input_tensor.shape().dim_size(0),
                errors::InvalidArgument(
                    "Unused values in tensor. Length of tensor: ",
                    input_tensor.shape().dim_size(0), " Values used: ", end));
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

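// Stacks the elements at the given indices of a TensorList into a single
// tensor whose leading dimension is the number of indices (a gather-style
// variant of TensorListStack). Uninitialized elements are emitted as zeros
// of the fully defined element shape.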
template <typename Device, typename T>
class TensorListGather : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;
  explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    const Tensor& indices = c->input(1);
    PartialTensorShape partial_element_shape;
    OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 2,
                                               &partial_element_shape));
    OP_REQUIRES(
        c, partial_element_shape.IsFullyDefined() || indices.NumElements() > 0,
        errors::InvalidArgument("Tried to gather 0-elements from "
                                "a list with non-fully-defined shape: ",
                                partial_element_shape.DebugString()));

    // Check that `element_shape` input tensor is compatible with the shapes of
    // element tensors.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      for (int index = 0; index < indices.NumElements(); ++index) {
        const int i = indices.flat<int32>()(index);
        const Tensor& t = tensor_list->tensors()[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = partial_element_shape;
          OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
        }
      }
    }

    // Compute the shape of the output tensor by pre-pending the leading dim to
    // the element_shape.
    TensorShape element_shape;
    OP_REQUIRES(
        c, partial_element_shape.AsTensorShape(&element_shape),
        errors::InvalidArgument("Tried to gather uninitialized tensors from a ",
                                "list with non-fully-defined element_shape: ",
                                partial_element_shape.DebugString()));
    TensorShape output_shape = element_shape;
    output_shape.InsertDim(0, indices.NumElements());
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(indices.NumElements());
    Tensor zeros;
    for (int index = 0; index < indices.NumElements(); ++index) {
      const int i = indices.flat<int32>()(index);
      OP_REQUIRES(
          c, i < tensor_list->tensors().size(),
          errors::InvalidArgument("Index ", i, " out of range; list only has ",
                                  tensor_list->tensors().size(), " elements."));
      const Tensor& t = tensor_list->tensors()[i];
      if (t.dtype() != DT_INVALID) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            t.shaped<T, 2>({1, t.NumElements()})));
      } else {
        if (!zeros.NumElements()) {
          AllocatorAttributes attr;
          if (element_dtype_ == DT_VARIANT) {
            attr.set_on_host(true);
          }
          OP_REQUIRES_OK(
              c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
          functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
                                               zeros.flat<T>());
        }
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
  }

 private:
  DataType element_dtype_;
};

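// Builds a TensorList by unstacking the input tensor along its first
// dimension; each slice is copied into an aligned temporary and stored as one
// list element.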
template <typename Device, typename T>
class TensorListFromTensor : public OpKernel {
 public:
  TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
    TensorList output_list;
    const Tensor& t = c->input(0);
    output_list.element_dtype = t.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    t.shape().DebugString()));
    TensorShape output_shape(t.shape());
    output_shape.RemoveDim(0);
    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
                errors::InvalidArgument(
                    "Specified a list with shape ", element_shape.DebugString(),
                    " from a tensor with shape ", output_shape.DebugString()));
    output_list.element_shape = element_shape;
    output_list.tensors().reserve(t.shape().dim_size(0));
    for (int i = 0; i < t.shape().dim_size(0); ++i) {
      Tensor tmp = t.Slice(i, i + 1);
      TensorShape tmp_shape = tmp.shape();
      tmp_shape.RemoveDim(0);
      OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
                  errors::Unknown("Unexpected shape error."));
      // TODO(apassos) maybe not always align; but weird compiler bugs seem to
      // prevent this.
      Tensor aligned;
      OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
      aligned.flat<T>().device(c->eigen_device<Device>()) =
          tmp.unaligned_flat<T>();
      output_list.tensors().push_back(aligned);
    }
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

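// Helper shared by TensorListScatter and TensorListScatterIntoExistingList.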
// Scatters values in `value` into `list`. Assumes that `indices` are valid.
template <typename Device, typename T>
Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices,
               TensorList* list) {
  for (int index = 0; index < indices.NumElements(); ++index) {
    const int i = indices.flat<int32>()(index);
    Tensor tmp = value.Slice(index, index + 1);
    TensorShape tmp_shape = tmp.shape();
    tmp_shape.RemoveDim(0);
    if (!tmp.CopyFrom(tmp, tmp_shape)) {
      return errors::Unknown("Unexpected shape error.");
    }
    // TODO(apassos) maybe not always align; but weird compiler bugs seem to
    // prevent this.
    Tensor aligned;
    TF_RETURN_IF_ERROR(c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
    // TODO(apassos) do all slices in a single kernel invocation instead of
    // many small ones.
    aligned.flat<T>().device(c->eigen_device<Device>()) =
        tmp.unaligned_flat<T>();
    std::swap(list->tensors()[i], aligned);
  }
  return Status::OK();
}

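// Scatters the rows of the input tensor into an existing TensorList at the
// given indices, growing the list as needed to fit the largest index.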
template <typename Device, typename T>
class TensorListScatterIntoExistingList : public OpKernel {
 public:
  TensorListScatterIntoExistingList(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    const Tensor& input_tensor = c->input(1);
    const Tensor& indices = c->input(2);

    // Check that inputs are valid.
    OP_REQUIRES(c, input_tensor.dtype() == l->element_dtype,
                errors::InvalidArgument(
                    "Invalid data types; input tensor type: ",
                    DataTypeString(input_tensor.dtype()),
                    " list element_type: ", DataTypeString(l->element_dtype)));
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    OP_REQUIRES(c, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument(
                    "Expected indices to be a vector, but received shape: ",
                    indices.shape().DebugString()));
    OP_REQUIRES(
        c, indices.NumElements() == input_tensor.shape().dim_size(0),
        errors::InvalidArgument(
            "Expected len(indices) == tensor.shape[0], but saw: ",
            indices.NumElements(), " vs. ", input_tensor.shape().dim_size(0)));

    // Resize the list if needed to accommodate all indices.
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    const auto indices_vec = indices.vec<int32>();
    int32 max_index =
        (indices.NumElements() == 0)
            ? -1
            : *std::max_element(indices_vec.data(),
                                indices_vec.data() + indices.NumElements());
    if (max_index + 1 > output_list->tensors().size()) {
      output_list->tensors().resize(max_index + 1);
    }

    // Scatter the values.
    OP_REQUIRES_OK(c,
                   Scatter<Device, T>(c, input_tensor, indices, output_list));
  }
};

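// Creates a new TensorList and scatters the rows of the input tensor into it
// at the given indices. Indices must be non-negative and, when the
// num_elements input is provided (TensorListScatterV2), smaller than
// num_elements; positions that are not written remain uninitialized.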
template <typename Device, typename T>
class TensorListScatter : public OpKernel {
 public:
  TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    Tensor indices = c->input(1);
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
    // TensorListScatterV2 passes the num_elements input, TensorListScatter
    // does not.
    int num_elements = c->num_inputs() >= 4 ? c->input(3).scalar<int>()() : -1;
    OP_REQUIRES(c, num_elements >= -1,
                errors::InvalidArgument(
                    "TensorListScatter expects num_elements >= -1, found: ",
                    num_elements));
    TensorList output_list;
    const Tensor& input_tensor = c->input(0);
    output_list.element_dtype = input_tensor.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    TensorShape output_shape(input_tensor.shape());
    output_shape.RemoveDim(0);
    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
                errors::InvalidArgument(
                    "Specified a list with shape ", element_shape.DebugString(),
                    " from a tensor with shape ", output_shape.DebugString()));
    output_list.element_shape = element_shape;

    OP_REQUIRES(c, indices.NumElements() == input_tensor.shape().dim_size(0),
                errors::InvalidArgument(
                    "Invalid number of rows in input tensor. Expected: ",
                    indices.NumElements(),
                    " Actual: ", input_tensor.shape().dim_size(0)));

    // Validate indices and resize output_list.tensors to fit the highest
    // index.
    {
      int highest_index = -1;
      for (int index = 0; index < indices.NumElements(); ++index) {
        const int i = indices.flat<int32>()(index);
        OP_REQUIRES(
            c, i >= 0,
            errors::InvalidArgument(
                "Indices in TensorListScatter must all be non-negative."));
        OP_REQUIRES(c, num_elements == -1 || i < num_elements,
                    errors::InvalidArgument(
                        "TensorListScatter: Trying to scatter at index ", i,
                        " in list with size ", num_elements));
        if (i > highest_index) {
          highest_index = i;
        }
      }
      output_list.tensors().resize(std::max(highest_index + 1, num_elements),
                                   Tensor(DT_INVALID));
    }

    OP_REQUIRES_OK(c,
                   Scatter<Device, T>(c, input_tensor, indices, &output_list));
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

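// Adds two TensorLists element-wise. The lists must have the same length and
// dtype and compatible element shapes; the result's element shape is the
// merge of the two input element shapes.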
template <typename Device>
Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
                           const TensorList& b, TensorList* out) {
  if (a.element_dtype != b.element_dtype) {
    return errors::InvalidArgument(
        "Trying to add two lists of tensors of different dtypes. One is ",
        DataTypeString(a.element_dtype), " and the other is ",
        DataTypeString(b.element_dtype));
  }
  out->element_dtype = a.element_dtype;
  if (!a.element_shape.IsCompatibleWith(b.element_shape)) {
    return errors::InvalidArgument(
        "Trying to add two lists of tensors with incompatible element shapes. "
        "One is ",
        a.element_shape.DebugString(), " and the other is ",
        b.element_shape.DebugString());
  }

  TF_RETURN_IF_ERROR(
      a.element_shape.MergeWith(b.element_shape, &out->element_shape));
  if (a.tensors().size() != b.tensors().size()) {
    return errors::InvalidArgument(
        "Trying to add two lists of tensors with different lengths. One is ",
        a.tensors().size(), " and the other is ", b.tensors().size());
  }
  out->tensors().reserve(a.tensors().size());
  for (int i = 0; i < a.tensors().size(); ++i) {
    const Tensor& a_tensor = a.tensors()[i];
    const Tensor& b_tensor = b.tensors()[i];
    Tensor out_tensor;
    TF_RETURN_IF_ERROR(
        BinaryAddTensors<Device>(c, a_tensor, b_tensor, &out_tensor));
    out->tensors().push_back(out_tensor);
  }
  return Status::OK();
}

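// Builds a TensorList with the same dtype, element shape, and length as `x`
// whose elements are zeros-like copies of the corresponding elements of `x`.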
template <typename Device>
Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
                           TensorList* y) {
  y->element_dtype = x.element_dtype;
  y->element_shape = x.element_shape;
  y->tensors().reserve(x.tensors().size());
  for (const Tensor& t : x.tensors()) {
    Tensor out_tensor;
    TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, t, &out_tensor));
    y->tensors().emplace_back(out_tensor);
  }
  return Status::OK();
}

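// Appends row b of the input tensor to the b-th TensorList in a batch of list
// handles. The handles tensor is aliased and the lists are mutated in place
// when it is safe to do so (the input can be forwarded and each list handle
// holds the only reference); otherwise the lists are copied first.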
template <typename Device, typename T>
class TensorListPushBackBatch : public OpKernel {
 public:
  explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
                errors::InvalidArgument(
                    "Expected tensor to be at least a vector, but saw shape: ",
                    input.shape().DebugString()));

    const TensorShape& tls_shape = c->input(0).shape();

    // For purposes of input forwarding, we want the least restrictive
    // AllocatorAttributes possible. If we need to allocate later,
    // we'll request the DT_VARIANT be allocated on host.
    AllocatorAttributes attr;

    std::unique_ptr<Tensor> tls_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    bool ok_to_alias = tls_alias != nullptr;
    if (tls_alias && tls_alias->dtype() == DT_VARIANT &&
        tls_alias->NumElements() > 0) {
      auto alias_t = tls_alias->flat<Variant>();
      for (int i = 0; i < tls_alias->NumElements(); ++i) {
        TensorList* tl_i = alias_t(i).get<TensorList>();
        if (tl_i == nullptr || !tl_i->RefCountIsOne()) {
          ok_to_alias = false;
          break;
        }
      }
    }
    const Tensor& tls = ok_to_alias ? *tls_alias : c->input(0);

    OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Expected input_handles dtype to be Variant, but saw: ",
                    DataTypeString(tls.dtype())));
    OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
                errors::InvalidArgument(
                    "Expected input_handles to be a vector, but saw shape: ",
                    tls_shape.DebugString()));
    const int64 batch_size = tls.NumElements();
    OP_REQUIRES(c, input.dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "Expected tensor.shape[0] == input_handles.size, but saw ",
                    input.dim_size(0), " vs. ", batch_size));
    auto tls_t = tls.vec<Variant>();

    TensorShape input_element_shape = input.shape();
    input_element_shape.RemoveDim(0);
    std::vector<const TensorList*> tl_batch;
    for (int64 b = 0; b < batch_size; ++b) {
      const TensorList* l = tls_t(b).get<TensorList>();
      OP_REQUIRES(c, l != nullptr,
                  errors::InvalidArgument("Input handle at index ", b,
                                          " is not a list. Saw: '",
                                          tls_t(b).DebugString(), "'"));
      OP_REQUIRES(
          c, l->element_shape.IsCompatibleWith(input_element_shape),
          errors::InvalidArgument(
              "Tried to append a tensor with incompatible shape to a "
              "list at index ",
              b, ". Op element shape: ", input_element_shape.DebugString(),
              " list shape: ", l->element_shape.DebugString()));
      OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                  errors::InvalidArgument(
                      "Invalid data type at index ", b, "; op elements ",
                      DataTypeString(element_dtype_), " but list elements ",
                      DataTypeString(l->element_dtype)));
      tl_batch.push_back(l);
    }

    Tensor* result;

    if (ok_to_alias) {
      result = tls_alias.get();
      c->set_output(0, *result);
    } else {
      // DT_VARIANT tensors always allocated on host.
      AllocatorAttributes attr;
      attr.set_on_host(true);
      OP_REQUIRES_OK(
          c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
    }

    if (batch_size == 0) {
      return;
    }

    auto input_t = input.flat_outer_dims<T, 2>();
    auto result_t = result->vec<Variant>();

    for (int64 b = 0; b < batch_size; ++b) {
      if (!ok_to_alias) {
        result_t(b) = tl_batch[b]->Copy();
      }
      TensorList* output = result_t(b).get<TensorList>();
      DCHECK(output != nullptr);
      Tensor* frame;
      PersistentTensor tmp;
      OP_REQUIRES_OK(c, c->allocate_persistent(
                            element_dtype_, input_element_shape, &tmp, &frame));
      if (input_element_shape.num_elements() > 0) {
        auto frame_t = frame->flat<T>();
        frame_t.device(c->eigen_device<Device>()) = input_t.template chip<0>(b);
      }
      output->tensors().push_back(std::move(*frame));
    }
  }

 private:
  DataType element_dtype_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_