/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/tensor_ops_util.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// Variant-compatible type for a list of tensors. This is mutable, but
// instances should never be mutated after being stored in a variant tensor.
struct TensorList {
 public:
  TensorList() {}
  TensorList(const TensorList& other);

  static const char kTypeName[];
  string TypeName() const { return kTypeName; }

  void Encode(VariantTensorData* data) const;

  bool Decode(const VariantTensorData& data);

  // TODO(apassos) fill this out
  string DebugString() const { return "TensorList"; }

  std::vector<Tensor> tensors;
  PartialTensorShape element_shape;
  DataType element_dtype;
  // The maximum allowed size of `tensors`. Defaults to -1, meaning that the
  // size of `tensors` is unbounded.
  int max_num_elements = -1;
};
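
// A minimal illustrative sketch (not part of the original header) of how a
// TensorList is typically held: the list itself lives inside a scalar
// DT_VARIANT tensor, and its fields describe the element dtype and shape.
// Names such as `list` and `variant_tensor` below are hypothetical.
//
//   TensorList list;
//   list.element_dtype = DT_FLOAT;
//   list.element_shape = PartialTensorShape({2});          // each element: [2]
//   list.tensors.push_back(Tensor(DT_FLOAT, TensorShape({2})));
//
//   Tensor variant_tensor(DT_VARIANT, TensorShape({}));    // scalar variant
//   variant_tensor.scalar<Variant>()() = std::move(list);  // store the list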

Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);

Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape);

Status GetInputList(OpKernelContext* c, int index, const TensorList** list);

Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
                                   int32 output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list);
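
// A rough usage sketch, assuming the helpers above behave as their names
// suggest (illustrative only; `c` is an OpKernelContext* inside a kernel's
// Compute method and the mutation shown is hypothetical):
//
//   const TensorList* input_list = nullptr;
//   OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
//   TensorList* output_list = nullptr;
//   OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *input_list,
//                                                 &output_list));
//   // output_list may now be mutated before the kernel returns, e.g.:
//   output_list->tensors.push_back(
//       Tensor(input_list->element_dtype, TensorShape({})));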

template <typename Device, typename T>
class TensorListStack : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;
  explicit TensorListStack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
    OP_REQUIRES_OK(c, c->GetAttr("num_elements", &num_elements_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    if (num_elements_ != -1) {
      OP_REQUIRES(c, tensor_list->tensors.size() == num_elements_,
                  errors::InvalidArgument(
                      "Operation expected a list with ", num_elements_,
                      " elements but got a list with ",
                      tensor_list->tensors.size(), " elements."));
    }
    PartialTensorShape partial_element_shape;
    OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 1,
                                               &partial_element_shape));
    OP_REQUIRES(
        c,
        partial_element_shape.IsFullyDefined() ||
            !tensor_list->tensors.empty(),
        errors::InvalidArgument("Tried to stack elements of an empty ",
                                "list with non-fully-defined element_shape: ",
                                partial_element_shape.DebugString()));

    // Check that the `element_shape` input tensor is compatible with the
    // shapes of the element tensors.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      for (int i = 0; i < tensor_list->tensors.size(); ++i) {
        const Tensor& t = tensor_list->tensors[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = partial_element_shape;
          OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
        }
      }
    }

    // Compute the shape of the output tensor by prepending the leading dim to
    // the element_shape.
    TensorShape element_shape;
    OP_REQUIRES(c, partial_element_shape.AsTensorShape(&element_shape),
                errors::InvalidArgument(
                    "Tried to stack list which only contains uninitialized ",
                    "tensors and has a non-fully-defined element_shape: ",
                    partial_element_shape.DebugString()));
    TensorShape output_shape = element_shape;
    output_shape.InsertDim(0, tensor_list->tensors.size());
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(tensor_list->tensors.size());
    Tensor zeros;
    for (const auto& t : tensor_list->tensors) {
      if (t.dtype() != DT_INVALID) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            t.shaped<T, 2>({1, t.NumElements()})));
      } else {
        if (!zeros.NumElements()) {
          AllocatorAttributes attr;
          if (element_dtype_ == DT_VARIANT) {
            attr.set_on_host(true);
          }
          OP_REQUIRES_OK(
              c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
          functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
                                               zeros.flat<T>());
        }
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA
    ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
  }

 private:
  int num_elements_;
  DataType element_dtype_;
};
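
// Shape bookkeeping sketch for the stack operation above (illustrative only,
// mirroring the Compute body with hypothetical values):
//
//   // list->element_shape == [2, 3], list->tensors.size() == 4
//   TensorShape element_shape({2, 3});
//   TensorShape output_shape = element_shape;
//   output_shape.InsertDim(0, 4);  // output_shape is now [4, 2, 3]
//
// Uninitialized list slots contribute a zero-filled tensor of the (fully
// defined) element shape to the stacked output.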

template <typename Device, typename T>
class TensorListGetItem : public OpKernel {
 public:
  explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32 index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, index < l->tensors.size(),
                errors::InvalidArgument("Trying to access element ", index,
                                        " in a list with ", l->tensors.size(),
                                        " elements."));
    if (l->tensors[index].dtype() != DT_INVALID) {
      c->set_output(0, l->tensors[index]);
    } else {
      PartialTensorShape partial_element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *l, 2, &partial_element_shape));
      TensorShape element_shape;
      // If l->element_shape and the element_shape input are both not fully
      // defined, try to infer the shape from the other list elements. This
      // requires that all initialized list elements have the same shape.
      // NOTE(srbs): This might be a performance bottleneck since we are
      // iterating over the entire list here. This is necessary for feature
      // parity with TensorArray.read. TensorArray has a mode in which all
      // elements are required to be of the same shape; TensorList does not.
      // In that mode TensorArray sets the array's element_shape on the first
      // write call. We could do something similar here if needed.
      if (!partial_element_shape.IsFullyDefined()) {
        for (const Tensor& t : l->tensors) {
          if (t.dtype() != DT_INVALID) {
            PartialTensorShape tmp = partial_element_shape;
            OP_REQUIRES_OK(c,
                           tmp.MergeWith(t.shape(), &partial_element_shape));
          }
        }
      }
      OP_REQUIRES(
          c, partial_element_shape.AsTensorShape(&element_shape),
          errors::InvalidArgument("Trying to read an uninitialized tensor but ",
                                  "element_shape is not fully defined: ",
                                  partial_element_shape.DebugString(),
                                  " and no list element is set."));
      Tensor* result;
      AllocatorAttributes attr;
      if (element_dtype_ == DT_VARIANT) {
        attr.set_on_host(true);
      }
      OP_REQUIRES_OK(c, c->allocate_output(0, element_shape, &result, attr));
      functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
                                           result->flat<T>());
    }
  }

 private:
  DataType element_dtype_;
};

template <typename Device, typename T>
class TensorListPopBack : public OpKernel {
 public:
  explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    OP_REQUIRES(c, !l->tensors.empty(),
                errors::InvalidArgument("Trying to pop from an empty list."));

    const Tensor& t = l->tensors.back();
    if (t.dtype() != DT_INVALID) {
      c->set_output(1, t);
    } else {
      PartialTensorShape partial_element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *l, 1, &partial_element_shape));
      TensorShape element_shape;
      OP_REQUIRES(
          c, partial_element_shape.AsTensorShape(&element_shape),
          errors::InvalidArgument("Trying to read an uninitialized tensor but ",
                                  "element_shape is not fully defined: ",
                                  partial_element_shape.DebugString()));
      Tensor* result;
      AllocatorAttributes attr;
      if (element_dtype_ == DT_VARIANT) {
        attr.set_on_host(true);
      }
      OP_REQUIRES_OK(c, c->allocate_output(1, element_shape, &result, attr));
      functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
                                           result->flat<T>());
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors.pop_back();
  }

 private:
  DataType element_dtype_;
};

template <typename Device, typename T>
class TensorListConcat : public OpKernel {
 public:
  using ConstMatrixVector =
      std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>;
  explicit TensorListConcat(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
    // TODO(skyewm): the HasAttr check can be removed once the
    // element_shape_except_first_dim attr has been checked in for 2 weeks
    // (around 1/14/2019).
    if (c->HasAttr("element_shape")) {
      PartialTensorShape element_shape;
      OP_REQUIRES_OK(c, c->GetAttr("element_shape", &element_shape));
      if (!element_shape.unknown_rank()) {
        element_shape_except_first_dim_ = PartialTensorShape(
            gtl::ArraySlice<int64>(element_shape.dim_sizes()).subspan(1));
      }
    }
  }

  void Compute(OpKernelContext* c) override {
    // Check that the input Variant tensor is indeed a TensorList and has the
    // correct element type.
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    // The leading dimension of all list elements, if they all share one.
    // This is used as the leading dim of uninitialized tensors in the list
    // if leading_dims is not provided.
    int64 first_dim = -1;
    if (c->num_inputs() > 1) {
      // TensorListConcatV2
      PartialTensorShape element_shape;
      OP_REQUIRES_OK(
          c, GetElementShapeFromInput(c, *tensor_list, 1, &element_shape));
      OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
                  errors::InvalidArgument(
                      "Concat requires elements to be at least vectors, ",
                      "found scalars instead."));
      // Split `element_shape` into `first_dim` and
      // `element_shape_except_first_dim_`.
      first_dim = element_shape.dim_size(0);
      element_shape_except_first_dim_ = element_shape;
      element_shape_except_first_dim_.RemoveDim(0);
    }
    // If the TensorList is empty, element_shape_except_first_dim_ must be
    // fully defined.
    OP_REQUIRES(c,
                !tensor_list->tensors.empty() ||
                    element_shape_except_first_dim_.IsFullyDefined(),
                errors::InvalidArgument(
                    "All except the first dimension must be fully defined ",
                    "when concatenating an empty tensor list. element_shape: ",
                    element_shape_except_first_dim_.DebugString()));
    // 1. Check that the `element_shape_except_first_dim_` input tensor is
    //    compatible with the shapes of the element tensors.
    // 2. Check that the elements have the same shape except the first dim.
    // 3. If `first_dim` is known, check that it is compatible with the leading
    //    dims of all elements.
    // 4. If `first_dim` is unknown (-1), check whether all initialized
    //    elements have the same leading dim and if so set `first_dim` to that
    //    value.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      bool check_dim = (first_dim == -1);
      int64 inferred_first_dim = first_dim;
      for (int i = 0; i < tensor_list->tensors.size(); ++i) {
        const Tensor& t = tensor_list->tensors[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = element_shape_except_first_dim_;
          OP_REQUIRES(
              c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
              errors::InvalidArgument("Concat saw a scalar shape at index ", i,
                                      " but requires at least vectors."));
          TensorShape shape_except_first_dim = TensorShape(
              gtl::ArraySlice<int64>(t.shape().dim_sizes()).subspan(1));
          OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim,
                                          &element_shape_except_first_dim_));
          OP_REQUIRES(c, first_dim == -1 || first_dim == t.shape().dim_size(0),
                      errors::InvalidArgument(
                          "First entry of element_shape input does not match ",
                          "the first dim of list element at index: ", i,
                          " Expected: ", first_dim,
                          " Actual: ", t.shape().dim_size(0)));
          if (check_dim) {
            if (inferred_first_dim == -1) {
              inferred_first_dim = t.shape().dim_size(0);
            } else if (inferred_first_dim != t.shape().dim_size(0)) {
              inferred_first_dim = -1;
              check_dim = false;
            }
          }
        }
      }
      first_dim = inferred_first_dim;
    }
    TensorShape output_shape;
    OP_REQUIRES(
        c, element_shape_except_first_dim_.AsTensorShape(&output_shape),
        errors::InvalidArgument(
            "Trying to concat list with only uninitialized tensors ",
            "but element_shape_except_first_dim_ is not fully defined: ",
            element_shape_except_first_dim_.DebugString()));
    // Build the lengths_tensor and the leading dim of the output tensor by
    // iterating over all element tensors.
    Tensor* lengths_tensor = nullptr;
    OP_REQUIRES_OK(
        c,
        c->allocate_output(
            1, TensorShape({static_cast<int64>(tensor_list->tensors.size())}),
            &lengths_tensor));
    auto lengths_tensor_vec = lengths_tensor->vec<int64>();
    int64 leading_dim = 0;
    for (size_t i = 0; i < tensor_list->tensors.size(); i++) {
      int64 dim;
      if (tensor_list->tensors[i].dtype() != DT_INVALID) {
        dim = tensor_list->tensors[i].shape().dim_size(0);
      } else {
        // If leading_dims is not provided or does not contain an entry for
        // index i, use the inferred `first_dim` if set.
        if ((c->num_inputs() <= 2 || i >= c->input(2).NumElements()) &&
            first_dim != -1) {
          dim = first_dim;
        } else {
          OP_REQUIRES(c, c->num_inputs() > 2,
                      errors::InvalidArgument(
                          "Concatenating lists with uninitialized tensors is ",
                          "not supported in this version of TensorListConcat. ",
                          "Consider updating your GraphDef to run the newer ",
                          "version."));
          OP_REQUIRES(c, i < c->input(2).NumElements(),
                      errors::InvalidArgument(
                          "List contains uninitialized tensor at index ", i,
                          " but leading_dims has only ",
                          c->input(2).NumElements(), " elements."));
          dim = c->input(2).vec<int64>()(i);
        }
      }
      leading_dim += dim;
      lengths_tensor_vec(i) = dim;
    }
    output_shape.InsertDim(0, leading_dim);
    Tensor* output;
    // Allocate the output tensor and fill it up with the concatenated element
    // tensors.
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(tensor_list->tensors.size());
    // Store the zeros tensors in a vector to prevent them from being GC'ed
    // until concat is complete.
    std::vector<Tensor> zeros_vec;
    for (int i = 0; i < tensor_list->tensors.size(); i++) {
      const Tensor& element_tensor = tensor_list->tensors[i];
      if (element_tensor.dtype() != DT_INVALID) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
      } else {
        AllocatorAttributes attr;
        if (element_dtype_ == DT_VARIANT) {
          attr.set_on_host(true);
        }
        TensorShape element_shape = output_shape;
        element_shape.set_dim(0, lengths_tensor_vec(i));
        zeros_vec.emplace_back();
        Tensor& zeros = zeros_vec.back();
        OP_REQUIRES_OK(
            c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
        functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
                                             zeros.flat<T>());
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA
    ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
  }

 private:
  DataType element_dtype_;
  PartialTensorShape element_shape_except_first_dim_;
};

template <typename Device, typename T>
class TensorListSplit : public OpKernel {
 public:
  TensorListSplit(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
    OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
                errors::InvalidArgument(
                    "TensorListSplit requires element_shape to be at least of ",
                    "rank 1, but saw: ", element_shape.DebugString()));
    TensorList output_list;
    const Tensor& input_tensor = c->input(0);
    output_list.element_dtype = input_tensor.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    TensorShape tensor_shape_without_first_dim(input_tensor.shape());
    tensor_shape_without_first_dim.RemoveDim(0);
    PartialTensorShape element_shape_without_first_dim;
    if (!element_shape.unknown_rank()) {
      element_shape_without_first_dim =
          PartialTensorShape(element_shape.dim_sizes());
      element_shape_without_first_dim.RemoveDim(0);
    }
    OP_REQUIRES(c,
                element_shape_without_first_dim.IsCompatibleWith(
                    tensor_shape_without_first_dim),
                errors::InvalidArgument(
                    "tensor shape ", input_tensor.shape().DebugString(),
                    " is not compatible with element_shape ",
                    element_shape.DebugString()));
    output_list.element_shape = element_shape;
    const Tensor& lengths = c->input(2);
    OP_REQUIRES(c, TensorShapeUtils::IsVector(lengths.shape()),
                errors::InvalidArgument(
                    "Expected lengths to be a vector, received shape: ",
                    lengths.shape().DebugString()));
    output_list.tensors.reserve(lengths.shape().dim_size(0));
    int64 start = 0;
    int64 end = 0;
    for (int i = 0; i < lengths.shape().dim_size(0); ++i) {
      int64 length = lengths.vec<int64>()(i);
      OP_REQUIRES(
          c, length >= 0,
          errors::InvalidArgument("Invalid value in lengths: ", length));
      end = start + length;
      OP_REQUIRES(c, end <= input_tensor.shape().dim_size(0),
                  errors::InvalidArgument("Attempting to slice [", start, ", ",
                                          end, "] from tensor with length ",
                                          input_tensor.shape().dim_size(0)));
      Tensor tmp = input_tensor.Slice(start, end);
      start = end;
      // TODO(apassos) maybe not always align; but weird compiler bugs seem to
      // prevent this.
      Tensor aligned;
      OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
      aligned.flat<T>().device(c->eigen_device<Device>()) =
          tmp.unaligned_flat<T>();
      output_list.tensors.emplace_back(aligned);
    }
    OP_REQUIRES(c, end == input_tensor.shape().dim_size(0),
                errors::InvalidArgument(
                    "Unused values in tensor. Length of tensor: ",
                    input_tensor.shape().dim_size(0), " Values used: ", end));
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

template <typename Device, typename T>
class TensorListGather : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;
  explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* tensor_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
    OP_REQUIRES(
        c, element_dtype_ == tensor_list->element_dtype,
        errors::InvalidArgument(
            "Invalid data types; op elements ", DataTypeString(element_dtype_),
            " but list elements ", DataTypeString(tensor_list->element_dtype)));
    const Tensor& indices = c->input(1);
    PartialTensorShape partial_element_shape;
    OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 2,
                                               &partial_element_shape));
    OP_REQUIRES(
        c, partial_element_shape.IsFullyDefined() || indices.NumElements() > 0,
        errors::InvalidArgument("Tried to gather 0 elements from "
                                "a list with non-fully-defined element_shape: ",
                                partial_element_shape.DebugString()));

    // Check that the `element_shape` input tensor is compatible with the
    // shapes of the element tensors.
    if (!tensor_list->element_shape.IsFullyDefined()) {
      for (int index = 0; index < indices.NumElements(); ++index) {
        const int i = indices.flat<int32>()(index);
        const Tensor& t = tensor_list->tensors[i];
        if (t.dtype() != DT_INVALID) {
          PartialTensorShape tmp = partial_element_shape;
          OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
        }
      }
    }

    // Compute the shape of the output tensor by prepending the leading dim to
    // the element_shape.
    TensorShape element_shape;
    OP_REQUIRES(
        c, partial_element_shape.AsTensorShape(&element_shape),
        errors::InvalidArgument("Tried to gather uninitialized tensors from a ",
                                "list with non-fully-defined element_shape: ",
                                partial_element_shape.DebugString()));
    TensorShape output_shape = element_shape;
    output_shape.InsertDim(0, indices.NumElements());
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() == 0) {
      return;
    }

    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(indices.NumElements());
    Tensor zeros;
    for (int index = 0; index < indices.NumElements(); ++index) {
      const int i = indices.flat<int32>()(index);
      OP_REQUIRES(
          c, i < tensor_list->tensors.size(),
          errors::InvalidArgument("Index ", i, " out of range; list only has ",
                                  tensor_list->tensors.size(), " elements."));
      const Tensor& t = tensor_list->tensors[i];
      if (t.dtype() != DT_INVALID) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            t.shaped<T, 2>({1, t.NumElements()})));
      } else {
        if (!zeros.NumElements()) {
          AllocatorAttributes attr;
          if (element_dtype_ == DT_VARIANT) {
            attr.set_on_host(true);
          }
          OP_REQUIRES_OK(
              c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
          functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
                                               zeros.flat<T>());
        }
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            const_cast<const Tensor&>(zeros).shaped<T, 2>(
                {1, zeros.NumElements()})));
      }
    }
    auto output_flat = output->shaped<T, 2>({1, output->NumElements()});

#if GOOGLE_CUDA
    if (std::is_same<Device, Eigen::GpuDevice>::value) {
      ConcatGPU<T>(c, inputs_flat, output, &output_flat);
      return;
    }
#endif  // GOOGLE_CUDA
    ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
  }

 private:
  DataType element_dtype_;
};

template <typename Device, typename T>
class TensorListFromTensor : public OpKernel {
 public:
  TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
    TensorList output_list;
    const Tensor& t = c->input(0);
    output_list.element_dtype = t.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    t.shape().DebugString()));
    TensorShape output_shape(t.shape());
    output_shape.RemoveDim(0);
    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
                errors::InvalidArgument(
                    "Specified a list with shape ", element_shape.DebugString(),
                    " from a tensor with shape ", output_shape.DebugString()));
    output_list.element_shape = element_shape;
    output_list.tensors.reserve(t.shape().dim_size(0));
    for (int i = 0; i < t.shape().dim_size(0); ++i) {
      Tensor tmp = t.Slice(i, i + 1);
      TensorShape tmp_shape = tmp.shape();
      tmp_shape.RemoveDim(0);
      OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
                  errors::Unknown("Unexpected shape error."));
      // TODO(apassos) maybe not always align; but weird compiler bugs seem to
      // prevent this.
      Tensor aligned;
      OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
      aligned.flat<T>().device(c->eigen_device<Device>()) =
          tmp.unaligned_flat<T>();
      output_list.tensors.push_back(aligned);
    }
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};
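
// Illustrative semantics only (not part of the original header):
// TensorListFromTensor unstacks a tensor along its leading dimension, so a
// float input of shape [3, 2] with a compatible element_shape of [2] yields a
// TensorList holding three float tensors, each of shape [2]. A hypothetical
// sketch:
//
//   // input:  t.shape() == [3, 2], element_shape == [2]
//   // output: output_list.tensors.size() == 3,
//   //         output_list.tensors[i].shape() == [2] for i in {0, 1, 2}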

// Scatters values in `value` into `list`. Assumes that `indices` are valid.
template <typename Device, typename T>
Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices,
               TensorList* list) {
  for (int index = 0; index < indices.NumElements(); ++index) {
    const int i = indices.flat<int32>()(index);
    Tensor tmp = value.Slice(index, index + 1);
    TensorShape tmp_shape = tmp.shape();
    tmp_shape.RemoveDim(0);
    if (!tmp.CopyFrom(tmp, tmp_shape)) {
      return errors::Unknown("Unexpected shape error.");
    }
    // TODO(apassos) maybe not always align; but weird compiler bugs seem to
    // prevent this.
    Tensor aligned;
    TF_RETURN_IF_ERROR(c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
    // TODO(apassos) do all slices in a single kernel invocation instead of
    // many small ones.
    aligned.flat<T>().device(c->eigen_device<Device>()) =
        tmp.unaligned_flat<T>();
    std::swap(list->tensors[i], aligned);
  }
  return Status::OK();
}
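
// A rough picture of the mapping Scatter performs, assuming `indices` holds
// {2, 0} and `value` has leading dimension 2 (hypothetical values):
//
//   // row 0 of `value` -> list->tensors[2]
//   // row 1 of `value` -> list->tensors[0]
//
// Callers are expected to have resized `list->tensors` so that every index in
// `indices` is in range, as the scatter kernels below do.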

template <typename Device, typename T>
class TensorListScatterIntoExistingList : public OpKernel {
 public:
  TensorListScatterIntoExistingList(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    const Tensor& input_tensor = c->input(1);
    const Tensor& indices = c->input(2);

    // Check that inputs are valid.
    OP_REQUIRES(c, input_tensor.dtype() == l->element_dtype,
                errors::InvalidArgument(
                    "Invalid data types; input tensor type: ",
                    DataTypeString(input_tensor.dtype()),
                    " list element_type: ", DataTypeString(l->element_dtype)));
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    OP_REQUIRES(c, TensorShapeUtils::IsVector(indices.shape()),
                errors::InvalidArgument(
                    "Expected indices to be a vector, but received shape: ",
                    indices.shape().DebugString()));
    OP_REQUIRES(
        c, indices.NumElements() == input_tensor.shape().dim_size(0),
        errors::InvalidArgument(
            "Expected len(indices) == tensor.shape[0], but saw: ",
            indices.NumElements(), " vs. ", input_tensor.shape().dim_size(0)));

    // Resize the list if needed to accommodate all indices.
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    const auto indices_vec = indices.vec<int32>();
    int32 max_index =
        (indices.NumElements() == 0)
            ? -1
            : *std::max_element(indices_vec.data(),
                                indices_vec.data() + indices.NumElements());
    if (max_index + 1 > output_list->tensors.size()) {
      output_list->tensors.resize(max_index + 1);
    }

    // Scatter the values.
    OP_REQUIRES_OK(c,
                   Scatter<Device, T>(c, input_tensor, indices, output_list));
  }
};

template <typename Device, typename T>
class TensorListScatter : public OpKernel {
 public:
  TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    Tensor* output_tensor;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
    Tensor indices = c->input(1);
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
    // TensorListScatterV2 passes the num_elements input, TensorListScatter
    // does not.
    int num_elements = c->num_inputs() >= 4 ? c->input(3).scalar<int>()() : -1;
    OP_REQUIRES(c, num_elements >= -1,
                errors::InvalidArgument(
                    "TensorListScatter expects num_elements >= -1, found: ",
                    num_elements));
    TensorList output_list;
    const Tensor& input_tensor = c->input(0);
    output_list.element_dtype = input_tensor.dtype();
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
                errors::InvalidArgument(
                    "Tensor must be at least a vector, but saw shape: ",
                    input_tensor.shape().DebugString()));
    TensorShape output_shape(input_tensor.shape());
    output_shape.RemoveDim(0);
    OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
                errors::InvalidArgument(
                    "Specified a list with shape ", element_shape.DebugString(),
                    " from a tensor with shape ", output_shape.DebugString()));
    output_list.element_shape = element_shape;

    OP_REQUIRES(c, indices.NumElements() == input_tensor.shape().dim_size(0),
                errors::InvalidArgument(
                    "Invalid number of rows in input tensor. Expected: ",
                    indices.NumElements(),
                    " Actual: ", input_tensor.shape().dim_size(0)));

    // Validate indices and resize output_list.tensors to fit the highest
    // index.
    {
      int highest_index = -1;
      for (int index = 0; index < indices.NumElements(); ++index) {
        const int i = indices.flat<int32>()(index);
        OP_REQUIRES(
            c, i >= 0,
            errors::InvalidArgument(
                "Indices in TensorListScatter must all be non-negative."));
        OP_REQUIRES(c, num_elements == -1 || i < num_elements,
                    errors::InvalidArgument(
                        "TensorListScatter: Trying to scatter at index ", i,
                        " in list with size ", num_elements));
        if (i > highest_index) {
          highest_index = i;
        }
      }
      output_list.tensors.resize(std::max(highest_index + 1, num_elements),
                                 Tensor(DT_INVALID));
    }

    OP_REQUIRES_OK(c,
                   Scatter<Device, T>(c, input_tensor, indices, &output_list));
    output_tensor->scalar<Variant>()() = std::move(output_list);
  }
};

template <typename Device>
Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
                           const TensorList& b, TensorList* out) {
  if (a.element_dtype != b.element_dtype) {
    return errors::InvalidArgument(
        "Trying to add two lists of tensors of different dtypes. One is ",
        DataTypeString(a.element_dtype), " and the other is ",
        DataTypeString(b.element_dtype));
  }
  out->element_dtype = a.element_dtype;
  if (!a.element_shape.IsCompatibleWith(b.element_shape)) {
    return errors::InvalidArgument(
        "Trying to add two lists of tensors with incompatible element shapes. "
        "One is ",
        a.element_shape.DebugString(), " and the other is ",
        b.element_shape.DebugString());
  }

  TF_RETURN_IF_ERROR(
      a.element_shape.MergeWith(b.element_shape, &out->element_shape));
  if (a.tensors.size() != b.tensors.size()) {
    return errors::InvalidArgument(
        "Trying to add two lists of tensors with different lengths. One is ",
        a.tensors.size(), " and the other is ", b.tensors.size());
  }
  out->tensors.reserve(a.tensors.size());
  for (int i = 0; i < a.tensors.size(); ++i) {
    const Tensor& a_tensor = a.tensors[i];
    const Tensor& b_tensor = b.tensors[i];
    Tensor out_tensor;
    TF_RETURN_IF_ERROR(
        BinaryAddTensors<Device>(c, a_tensor, b_tensor, &out_tensor));
    out->tensors.push_back(out_tensor);
  }
  return Status::OK();
}
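
// A rough sketch of the elementwise semantics above (hypothetical values, and
// assuming BinaryAddTensors adds the two element tensors as its name
// suggests): adding a list holding {[1, 2], [3]} to a list holding
// {[10, 20], [30]} yields a list holding {[11, 22], [33]}; lists with
// different lengths, dtypes, or incompatible element shapes are rejected.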

template <typename Device>
Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
                           TensorList* y) {
  y->element_dtype = x.element_dtype;
  y->element_shape = x.element_shape;
  y->tensors.reserve(x.tensors.size());
  for (const Tensor& t : x.tensors) {
    Tensor out_tensor;
    TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, t, &out_tensor));
    y->tensors.emplace_back(out_tensor);
  }
  return Status::OK();
}

template <typename Device, typename T>
class TensorListPushBackBatch : public OpKernel {
 public:
  explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
                errors::InvalidArgument(
                    "Expected tensor to be at least a vector, but saw shape: ",
                    input.shape().DebugString()));

    const TensorShape& tls_shape = c->input(0).shape();

    // For purposes of input forwarding, we want the least restrictive
    // AllocatorAttributes possible. If we need to allocate later,
    // we'll request the DT_VARIANT be allocated on host.
    AllocatorAttributes attr;

    std::unique_ptr<Tensor> tls_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    const Tensor& tls = tls_alias ? *tls_alias : c->input(0);

    OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
                errors::InvalidArgument(
                    "Expected input_handles dtype to be Variant, but saw: ",
                    DataTypeString(tls.dtype())));
    OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
                errors::InvalidArgument(
                    "Expected input_handles to be a vector, but saw shape: ",
                    tls_shape.DebugString()));
    const int64 batch_size = tls.NumElements();
    OP_REQUIRES(c, input.dim_size(0) == batch_size,
                errors::InvalidArgument(
                    "Expected tensor.shape[0] == input_handles.size, but saw ",
                    input.dim_size(0), " vs. ", batch_size));
    auto tls_t = tls.vec<Variant>();

    TensorShape input_element_shape = input.shape();
    input_element_shape.RemoveDim(0);
    std::vector<const TensorList*> tl_batch;
    for (int64 b = 0; b < batch_size; ++b) {
      const TensorList* l = tls_t(b).get<TensorList>();
      OP_REQUIRES(c, l != nullptr,
                  errors::InvalidArgument("Input handle at index ", b,
                                          " is not a list. Saw: '",
                                          tls_t(b).DebugString(), "'"));
      OP_REQUIRES(
          c, l->element_shape.IsCompatibleWith(input_element_shape),
          errors::InvalidArgument(
              "Tried to append a tensor with incompatible shape to a "
              "list at index ",
              b, ". Op element shape: ", input_element_shape.DebugString(),
              " list shape: ", l->element_shape.DebugString()));
      OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                  errors::InvalidArgument(
                      "Invalid data type at index ", b, "; op elements ",
                      DataTypeString(element_dtype_), " but list elements ",
                      DataTypeString(l->element_dtype)));
      tl_batch.push_back(l);
    }

    Tensor* result;

    if (tls_alias) {
      result = tls_alias.get();
      c->set_output(0, *result);
    } else {
      // DT_VARIANT tensors are always allocated on host.
      AllocatorAttributes attr;
      attr.set_on_host(true);
      OP_REQUIRES_OK(
          c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
    }

    if (batch_size == 0) {
      return;
    }

    auto input_t = input.flat_outer_dims<T, 2>();
    auto result_t = result->vec<Variant>();

    for (int64 b = 0; b < batch_size; ++b) {
      if (!tls_alias) {
        result_t(b) = *tl_batch[b];
      }
      TensorList* output = result_t(b).get<TensorList>();
      DCHECK(output != nullptr);
      Tensor* frame;
      PersistentTensor tmp;
      OP_REQUIRES_OK(c, c->allocate_persistent(element_dtype_,
                                               input_element_shape, &tmp,
                                               &frame));
      if (input_element_shape.num_elements() > 0) {
        auto frame_t = frame->flat<T>();
        frame_t.device(c->eigen_device<Device>()) =
            input_t.template chip<0>(b);
      }
      output->tensors.push_back(std::move(*frame));
    }
  }

 private:
  DataType element_dtype_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_