1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // XLA TensorList operators.
17
18 #include <limits>
19 #include <vector>
20
21 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
22 #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/type_util.h"
25 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/literal.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/framework/bounds_check.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/partial_tensor_shape.h"
35 #include "tensorflow/core/framework/register_types.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/framework/tensor_types.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/platform/types.h"
42
43 namespace tensorflow {
44
45 namespace {
46
// GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist
// may carry and returns them in a 2D vector: XlaOp[ElementSize][DimSize]. If a
// dimension is static, a constant dimension is returned. If a dim is dynamic, a
// dynamic XlaOp representing the dynamic size is returned.
//
// Inputs read from `ctx`:
//   input 0 — the element-shape vector (compile-time constant values, but
//             individual entries may be marked dynamic);
//   input 1 — the leading (number-of-elements) dimension.
// `num_elements` is the statically-resolved leading dimension, used when
// input 1 is not dynamic.
xla::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
    XlaOpKernelContext* ctx, const xla::Shape& element_shape,
    const xla::Shape& list_shape, int64 num_elements) {
  std::vector<int64> dynamic_sizes;
  // The multiplier can be a dynamic value.
  TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
  // Per-entry dynamism of the element-shape vector (parallel to
  // dynamic_sizes; assumes one entry per element dimension — established by
  // callers that pass a fully-defined element_shape).
  std::vector<bool> dims_are_dynamic;
  TF_RETURN_IF_ERROR(
      ctx->ResolveInputDynamismIntoPredVector(0, &dims_are_dynamic));
  bool leading_dim_is_dynamic;
  TF_RETURN_IF_ERROR(
      ctx->ResolveInputDynamismIntoPred(1, &leading_dim_is_dynamic));
  std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
  // Set dynamic dimension size to 0 for initialization value.
  std::vector<xla::XlaOp> dynamic_dims;
  // Leading dimension: either the runtime value of input 1 or the static
  // num_elements constant.
  if (leading_dim_is_dynamic) {
    dynamic_dims.push_back(ctx->Input(1));
  } else {
    dynamic_dims.push_back(
        xla::ConstantR0<int32>(ctx->builder(), num_elements));
  }
  for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) {
    if (dims_are_dynamic[dim]) {
      // Extract the runtime size of this dimension from the shape vector and
      // reshape it to an S32 scalar, as expected by SetDimensionSize users.
      auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
      dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
      dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
      dynamic_dims.push_back(dynamic_dim_size);
    } else {
      dynamic_dims.push_back(
          xla::ConstantR0<int32>(ctx->builder(), dynamic_sizes[dim]));
    }
  }
  list_dynamic_dims.push_back(dynamic_dims);
  return list_dynamic_dims;
}
86
87 class TensorListLengthOp : public XlaOpKernel {
88 public:
TensorListLengthOp(OpKernelConstruction * ctx)89 explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
90
Compile(XlaOpKernelContext * ctx)91 void Compile(XlaOpKernelContext* ctx) override {
92 int64 leading_dim;
93 xla::XlaOp leading_dim_size;
94 bool leading_dim_is_dynamic;
95 OP_REQUIRES_OK(ctx, GetLeadingDimForTensorList(ctx->Input(0), &leading_dim,
96 &leading_dim_is_dynamic,
97 &leading_dim_size));
98 ctx->SetOutput(0, leading_dim_size);
99 }
100
101 private:
102 TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp);
103 };
104
105 REGISTER_XLA_OP(Name("TensorListLength").IsMetadataOp(), TensorListLengthOp);
106
107 // "input" is the shape input for EmptyTensorList/TensorListReserve ops.
108 // If "input" is a compile time constant and not "unknown rank" (-1), return
109 // its value in "*shape".
TryGetElementShapeFromInput(XlaOpKernelContext * ctx,xla::XlaOp input,xla::PrimitiveType dtype,bool * got_shape,xla::Shape * shape)110 Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, xla::XlaOp input,
111 xla::PrimitiveType dtype, bool* got_shape,
112 xla::Shape* shape) {
113 auto is_compile_time_constant_or = input.builder()->IsConstant(input);
114 TF_RETURN_IF_ERROR(is_compile_time_constant_or.status());
115
116 bool is_compile_time_constant = is_compile_time_constant_or.ValueOrDie();
117 if (!is_compile_time_constant) {
118 *got_shape = false;
119 return Status::OK();
120 }
121
122 PartialTensorShape partial_shape;
123 TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape(0, &partial_shape));
124 if (!partial_shape.IsFullyDefined()) {
125 *got_shape = false;
126 return Status::OK();
127 }
128
129 *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes());
130 *got_shape = true;
131 return Status::OK();
132 }
133
class TensorListReserveOp : public XlaOpKernel {
 public:
  explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
    // Only non-nested TensorList is supported for now.
    OP_REQUIRES(
        ctx, dtype_ != DT_VARIANT,
        errors::Unimplemented(
            "Only non-nested TensorList is supported for TensorListReserve."));
  }

  // Builds a TensorList with `num_elements` zero-initialized slots.  Input 0
  // is the element shape, input 1 the number of elements; both are
  // compile-time constants (input 1 may additionally be marked dynamic).
  void Compile(XlaOpKernelContext* ctx) override {
    int64 num_elements;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
    bool num_element_is_dynamic;
    OP_REQUIRES_OK(
        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
    // A negative count means the size was never set (e.g. an unbounded
    // TensorArray), which XLA cannot compile.
    OP_REQUIRES(
        ctx, num_elements >= 0,
        errors::InvalidArgument(
            "XLA compilation requires a fixed tensor list size. Set the number "
            "of elements. This could also happen if you're using a TensorArray "
            "in a while loop that does not have its maximum_iteration set, you "
            "can fix this by setting maximum_iteration to a suitable value."));

    // If element shape is compile time constant and it's not "unknown rank"
    // shape (-1), create an initialized TensorList. Otherwise create an
    // uninitialized TensorList.
    xla::XlaOp element_shape_handle = ctx->Input(0);
    xla::PrimitiveType type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
    bool got_shape;
    xla::Shape element_shape;
    OP_REQUIRES_OK(ctx,
                   TryGetElementShapeFromInput(ctx, element_shape_handle, type,
                                               &got_shape, &element_shape));
    if (got_shape) {
      xla::Shape list_shape;
      OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
                              element_shape, num_elements,
                              num_element_is_dynamic, &list_shape));
      // Set up dynamic dimension sizes to create the zero tensor.
      auto list_dynamic_dims_or = GetTensorListDynamicDims(
          ctx, element_shape, list_shape, num_elements);
      OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
      xla::XlaOp new_list;
      OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
                              ctx->builder(), list_shape,
                              list_dynamic_dims_or.ValueOrDie(), &new_list));
      xla::XlaOp result;
      // A reserved list starts "full": the push index points past the last
      // reserved slot.  NOTE(review): num_elements is narrowed int64->int32
      // here; assumes it fits in int32 — confirm upstream bounds.
      OP_REQUIRES_OK(
          ctx,
          SetTensorListPushIndex(
              new_list, xla::ConstantR0<int32>(ctx->builder(), num_elements),
              &result));
      ctx->SetTensorListOutput(0, result);
      return;
    }

    // Element shape unknown at compile time: emit an uninitialized list that
    // gets its shape on first write.
    xla::XlaOp result = BuildUninitializedTensorList(
        ctx->builder(), num_elements, num_element_is_dynamic, ctx->Input(1));
    ctx->SetTensorListOutput(0, result);
  }

 private:
  // element_dtype attribute; DT_VARIANT (nested lists) rejected in the ctor.
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp);
};

REGISTER_XLA_OP(Name("TensorListReserve")
                    .CompileTimeConstantInput("element_shape")
                    .CompileTimeConstantInput("num_elements"),
                TensorListReserveOp);
208
class EmptyTensorListOp : public XlaOpKernel {
 public:
  explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
  }

  // Creates an empty TensorList with capacity `max_num_elements` (input 1).
  // Unlike TensorListReserve, DT_VARIANT elements (nested lists) are allowed
  // here and always produce an uninitialized list.
  void Compile(XlaOpKernelContext* ctx) override {
    int64 max_num_elements;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
    bool num_element_is_dynamic;
    OP_REQUIRES_OK(
        ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
    // XLA needs a fixed capacity up front; -1 ("unbounded") is rejected.
    OP_REQUIRES(ctx, max_num_elements >= 0,
                errors::InvalidArgument(
                    "XLA compilation requires a fixed tensor list size. Set "
                    "the max number of elements. This could also happen if "
                    "you're using a TensorArray in a while loop that does not "
                    "have its maximum_iteration set, you can fix this by "
                    "setting maximum_iteration to a suitable value."));

    if (dtype_ != DT_VARIANT) {
      // We are creating a non-nested TensorList.
      // If element shape is compile time constant and it's not "unknown
      // rank" shape (-1), create an initialized TensorList. Otherwise
      // create an uninitialized TensorList.
      xla::XlaOp element_shape_handle = ctx->Input(0);
      xla::PrimitiveType type;
      OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
      bool got_shape;
      xla::Shape element_shape;
      OP_REQUIRES_OK(
          ctx, TryGetElementShapeFromInput(ctx, element_shape_handle, type,
                                           &got_shape, &element_shape));
      if (got_shape) {
        xla::Shape list_shape;
        OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
                                element_shape, max_num_elements,
                                num_element_is_dynamic, &list_shape));
        // Set up dynamic dimension sizes to create the zero tensor.
        auto list_dynamic_dims_or = GetTensorListDynamicDims(
            ctx, element_shape, list_shape, max_num_elements);
        OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());

        xla::XlaOp result;
        OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
                                ctx->builder(), list_shape,
                                list_dynamic_dims_or.ValueOrDie(), &result));

        ctx->SetTensorListOutput(0, result);
        return;
      }
    }

    // We are creating a nested TensorList or a non-nested TensorList with
    // unknown shape. Just create an uninitialized TensorList.
    xla::XlaOp result =
        BuildUninitializedTensorList(ctx->builder(), max_num_elements,
                                     num_element_is_dynamic, ctx->Input(1));
    ctx->SetTensorListOutput(0, result);
  }

 private:
  // element_dtype attribute; DT_VARIANT means the list holds nested lists.
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp);
};

REGISTER_XLA_OP(Name("EmptyTensorList")
                    .CompileTimeConstantInput("element_shape")
                    .CompileTimeConstantInput("max_num_elements")
                    .AllowVariantTypes(),
                EmptyTensorListOp);
281
282 class TensorListElementShapeOp : public XlaOpKernel {
283 public:
TensorListElementShapeOp(OpKernelConstruction * ctx)284 explicit TensorListElementShapeOp(OpKernelConstruction* ctx)
285 : XlaOpKernel(ctx) {
286 OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_));
287 }
288
Compile(XlaOpKernelContext * ctx)289 void Compile(XlaOpKernelContext* ctx) override {
290 // Check that the TensorList is initialized.
291 bool is_initialized;
292 OP_REQUIRES_OK(ctx,
293 (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
294 OP_REQUIRES(ctx, is_initialized,
295 errors::InvalidArgument("TensorList is not initialized"));
296
297 // Only non-nested TensorList is supported for now.
298 bool is_nested;
299 OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
300 OP_REQUIRES(ctx, !is_nested,
301 errors::Unimplemented("Only non-nested TensorList is supported "
302 "for TensorListElementShape."));
303
304 // For non-nested TensorList, element shape is the buffer shape without
305 // the first dimension.
306 xla::XlaBuilder* b = ctx->builder();
307 xla::Shape list_shape;
308 OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &list_shape));
309 list_shape.DeleteDimension(0);
310
311 switch (shape_type_) {
312 case DT_INT64:
313 ctx->SetOutput(0, xla::ConstantR1<int64>(b, list_shape.dimensions()));
314 break;
315 case DT_INT32: {
316 std::vector<int32> size;
317 for (int64 s : list_shape.dimensions()) {
318 size.push_back(s);
319 }
320 ctx->SetOutput(0, xla::ConstantR1<int32>(b, size));
321 break;
322 }
323 default:
324 ctx->CtxFailure(
325 errors::InvalidArgument("Unsupported shape type requested"));
326 return;
327 }
328 }
329
330 private:
331 DataType shape_type_;
332
333 TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp);
334 };
335
336 REGISTER_XLA_OP(Name("TensorListElementShape").IsMetadataOp(),
337 TensorListElementShapeOp);
338
339 class TensorListGetItemOp : public XlaOpKernel {
340 public:
TensorListGetItemOp(OpKernelConstruction * ctx)341 explicit TensorListGetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
342 OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
343 }
344
Compile(XlaOpKernelContext * ctx)345 void Compile(XlaOpKernelContext* ctx) override {
346 // Check that the TensorList is initialized.
347 bool is_initialized;
348 OP_REQUIRES_OK(ctx,
349 (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
350 OP_REQUIRES(ctx, is_initialized,
351 errors::InvalidArgument("TensorList is not initialized"));
352
353 // Only non-nested TensorList is supported for now.
354 bool is_nested;
355 OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
356 OP_REQUIRES(ctx, !is_nested,
357 errors::Unimplemented("Only non-nested TensorList is supported "
358 "for TensorListGetItem."));
359
360 xla::XlaOp list = ctx->Input(0);
361 xla::XlaOp index = ctx->Input(1);
362
363 xla::XlaOp result;
364 OP_REQUIRES_OK(ctx, ExecuteTensorListGetItem(list, index, &result));
365
366 ctx->SetOutput(0, result);
367 }
368
369 private:
370 DataType dtype_;
371
372 TF_DISALLOW_COPY_AND_ASSIGN(TensorListGetItemOp);
373 };
374
375 REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp);
376
377 class TensorListGatherOp : public XlaOpKernel {
378 public:
TensorListGatherOp(OpKernelConstruction * ctx)379 explicit TensorListGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
380 OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
381 }
382
Compile(XlaOpKernelContext * ctx)383 void Compile(XlaOpKernelContext* ctx) override {
384 // Check that the TensorList is initialized.
385 bool is_initialized;
386 OP_REQUIRES_OK(ctx,
387 (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
388 OP_REQUIRES(ctx, is_initialized,
389 errors::InvalidArgument("TensorList is not initialized"));
390
391 // Only non-nested TensorList is supported for now.
392 bool is_nested;
393 OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
394 OP_REQUIRES(ctx, !is_nested,
395 errors::Unimplemented("Only non-nested TensorList is supported "
396 "for TensorListGather."));
397
398 DataType indices_type = ctx->input_type(1);
399
400 const TensorShape indices_shape = ctx->InputShape(1);
401 OP_REQUIRES(ctx, indices_shape.dims() == 1,
402 errors::InvalidArgument("indices must be rank 1"));
403
404 xla::XlaOp list = ctx->Input(0);
405 xla::XlaOp indices = ctx->Input(1);
406
407 xla::XlaOp buffer;
408 OP_REQUIRES_OK(ctx, GetTensorListBuffer(list, &buffer));
409 xla::Shape buffer_xla_shape;
410 OP_REQUIRES_OK(ctx, GetTensorListBufferShape(list, &buffer_xla_shape));
411 TensorShape buffer_shape;
412 OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(buffer_xla_shape, &buffer_shape));
413
414 xla::XlaOp result;
415 OP_REQUIRES_OK(
416 ctx, XlaGather(buffer, buffer_shape, indices, indices_shape, /*axis=*/0,
417 /*indices_are_nd=*/false, dtype_, indices_type,
418 ctx->builder(), &result));
419 ctx->SetOutput(0, result);
420 }
421
422 private:
423 DataType dtype_;
424
425 TF_DISALLOW_COPY_AND_ASSIGN(TensorListGatherOp);
426 };
427
428 REGISTER_XLA_OP(Name("TensorListGather"), TensorListGatherOp);
429
430 class TensorListStackOp : public XlaOpKernel {
431 public:
TensorListStackOp(OpKernelConstruction * ctx)432 explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
433
Compile(XlaOpKernelContext * ctx)434 void Compile(XlaOpKernelContext* ctx) override {
435 // Check that the TensorList is initialized.
436 bool is_initialized;
437 OP_REQUIRES_OK(ctx,
438 (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
439 OP_REQUIRES(ctx, is_initialized,
440 errors::InvalidArgument("TensorList is not initialized"));
441
442 // Only non-nested TensorList is supported for now.
443 bool is_nested;
444 OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
445 OP_REQUIRES(ctx, !is_nested,
446 errors::Unimplemented("Only non-nested TensorList is supported "
447 "for TensorListGetItem."));
448
449 xla::XlaOp buffer;
450 OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer));
451 ctx->SetOutput(0, buffer);
452 }
453
454 private:
455 TF_DISALLOW_COPY_AND_ASSIGN(TensorListStackOp);
456 };
457
458 REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp);
459
460 class TensorListConcatOp : public XlaOpKernel {
461 public:
TensorListConcatOp(OpKernelConstruction * ctx)462 explicit TensorListConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
463
Compile(XlaOpKernelContext * ctx)464 void Compile(XlaOpKernelContext* ctx) override {
465 xla::XlaOp input = ctx->Input(0);
466
467 // Check that the TensorList is initialized.
468 bool is_initialized;
469 OP_REQUIRES_OK(ctx, (IsTensorListInitialized(input, &is_initialized)));
470 OP_REQUIRES(ctx, is_initialized,
471 errors::InvalidArgument("TensorList is not initialized"));
472
473 // Only non-nested TensorList is supported for now.
474 bool is_nested;
475 OP_REQUIRES_OK(ctx, IsNestedTensorList(input, &is_nested));
476 OP_REQUIRES(ctx, !is_nested,
477 errors::Unimplemented("Only non-nested TensorList is supported "
478 "for TensorListConcat."));
479
480 xla::XlaOp buffer;
481 OP_REQUIRES_OK(ctx, GetTensorListBuffer(input, &buffer));
482
483 xla::XlaBuilder* b = input.builder();
484 auto shape_or = b->GetShape(buffer);
485 OP_REQUIRES_OK(ctx, shape_or.status());
486 xla::Shape element_shape = shape_or.ConsumeValueOrDie();
487 std::vector<int64> element_dims =
488 xla::SpanToVector(element_shape.dimensions());
489 OP_REQUIRES(
490 ctx, element_dims.size() > 1,
491 errors::Unimplemented("TensorList of scalars is not supported"));
492 int64 num_elements = element_dims[0];
493 int64 tensor_lengths = element_dims[1];
494
495 std::vector<int64> new_dims = {num_elements * tensor_lengths};
496
497 for (int i = 2; i < element_dims.size(); i++) {
498 new_dims.push_back(element_dims[i]);
499 }
500
501 xla::XlaOp out = xla::Reshape(buffer, new_dims);
502 ctx->SetOutput(0, out);
503
504 // Second output is a tensor of lengths of returned tensors.
505 xla::XlaOp lengths = xla::ConstantR1(b, num_elements, tensor_lengths);
506 ctx->SetOutput(1, lengths);
507 }
508
509 private:
510 TF_DISALLOW_COPY_AND_ASSIGN(TensorListConcatOp);
511 };
512
513 REGISTER_XLA_OP(Name("TensorListConcatV2"), TensorListConcatOp);
514
515 class TensorListSplitOp : public XlaOpKernel {
516 public:
TensorListSplitOp(OpKernelConstruction * ctx)517 explicit TensorListSplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
518 OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
519 // Only non-nested TensorList is supported for now.
520 OP_REQUIRES(
521 ctx, dtype_ != DT_VARIANT,
522 errors::Unimplemented(
523 "Only non-nested TensorList is supported for TensorListReserve."));
524 }
525
Compile(XlaOpKernelContext * ctx)526 void Compile(XlaOpKernelContext* ctx) override {
527 xla::XlaOp input_tensor = ctx->Input(0);
528
529 xla::XlaBuilder* b = input_tensor.builder();
530 auto shape_or = b->GetShape(input_tensor);
531 OP_REQUIRES_OK(ctx, shape_or.status());
532 xla::Shape element_shape = shape_or.ConsumeValueOrDie();
533 std::vector<int64> element_dims =
534 xla::SpanToVector(element_shape.dimensions());
535 OP_REQUIRES(
536 ctx, !element_dims.empty(),
537 errors::Unimplemented("Element dimensions have to be non-empty"));
538
539 std::vector<int64> lengths;
540 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths));
541 OP_REQUIRES(ctx, !lengths.empty(),
542 errors::Unimplemented("Length has to be non-empty"));
543 int64 length = lengths[0];
544 for (int64 len : lengths) {
545 OP_REQUIRES(ctx, len == length,
546 errors::Unimplemented("All lengths have to be the same"));
547 }
548 OP_REQUIRES(
549 ctx, element_dims[0] % length == 0,
550 errors::Unimplemented("Buffer size has to be a multiple of length"));
551 std::vector<int64> new_dims = {element_dims[0] / length, length};
552 for (int i = 1; i < element_dims.size(); i++) {
553 new_dims.push_back(element_dims[i]);
554 }
555
556 xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims);
557
558 xla::XlaOp result;
559 OP_REQUIRES_OK(ctx, ExecuteTensorListFromTensor(length, reshaped, &result));
560 ctx->SetTensorListOutput(0, result);
561 }
562
563 private:
564 DataType dtype_;
565
566 TF_DISALLOW_COPY_AND_ASSIGN(TensorListSplitOp);
567 };
568
569 REGISTER_XLA_OP(Name("TensorListSplit")
570 .CompileTimeConstantInput("element_shape")
571 .CompileTimeConstantInput("lengths"),
572 TensorListSplitOp);
573
574 class TensorListFromTensorOp : public XlaOpKernel {
575 public:
TensorListFromTensorOp(OpKernelConstruction * ctx)576 explicit TensorListFromTensorOp(OpKernelConstruction* ctx)
577 : XlaOpKernel(ctx) {}
578
Compile(XlaOpKernelContext * ctx)579 void Compile(XlaOpKernelContext* ctx) override {
580 const TensorShape& tensor_shape = ctx->InputShape(0);
581 int num_elements = tensor_shape.dim_size(0);
582 const xla::XlaOp tensor = ctx->Input(0);
583 xla::XlaOp result;
584 OP_REQUIRES_OK(ctx,
585 ExecuteTensorListFromTensor(num_elements, tensor, &result));
586 auto list_shape_or = ctx->builder()->GetShape(result);
587 ctx->SetTensorListOutput(0, result);
588 }
589
590 private:
591 TF_DISALLOW_COPY_AND_ASSIGN(TensorListFromTensorOp);
592 };
593
594 REGISTER_XLA_OP(
595 Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"),
596 TensorListFromTensorOp);
597
598 class TensorListSetItemOp : public XlaOpKernel {
599 public:
TensorListSetItemOp(OpKernelConstruction * ctx)600 explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
601
Compile(XlaOpKernelContext * ctx)602 void Compile(XlaOpKernelContext* ctx) override {
603 xla::XlaOp list = ctx->Input(0);
604 xla::XlaOp index = ctx->Input(1);
605 xla::XlaOp element = ctx->Input(2);
606 xla::XlaOp initialized_list;
607 OP_REQUIRES_OK(ctx, GetInitializedTensorListForElement(
608 list, element, /*element_is_tensor_list=*/false,
609 &initialized_list));
610
611 // Only non-nested TensorList is supported for now.
612 bool is_nested;
613 OP_REQUIRES_OK(ctx, IsNestedTensorList(initialized_list, &is_nested));
614 OP_REQUIRES(ctx, !is_nested,
615 errors::Unimplemented("Only non-nested TensorList is supported "
616 "for TensorListSetItem."));
617
618 xla::XlaOp result;
619 OP_REQUIRES_OK(ctx, ExecuteTensorListSetItem(initialized_list, index,
620 element, &result));
621
622 ctx->SetTensorListOutput(0, result);
623 }
624
625 private:
626 TF_DISALLOW_COPY_AND_ASSIGN(TensorListSetItemOp);
627 };
628
629 REGISTER_XLA_OP(Name("TensorListSetItem"), TensorListSetItemOp);
630
631 class TensorListPushBackOp : public XlaOpKernel {
632 public:
TensorListPushBackOp(OpKernelConstruction * ctx)633 explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
634
Compile(XlaOpKernelContext * ctx)635 void Compile(XlaOpKernelContext* ctx) override {
636 xla::XlaOp list = ctx->Input(0);
637 xla::XlaOp element = ctx->Input(1);
638 bool element_is_tensor_list = IsTensorListInput(ctx, 1);
639 xla::XlaOp initialized_list;
640 OP_REQUIRES_OK(
641 ctx, GetInitializedTensorListForElement(
642 list, element, element_is_tensor_list, &initialized_list));
643
644 xla::XlaOp result;
645 OP_REQUIRES_OK(ctx,
646 ExecuteTensorListPushBack(initialized_list, element,
647 element_is_tensor_list, &result));
648
649 ctx->SetTensorListOutput(0, result);
650 }
651
652 private:
653 TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp);
654 };
655
656 REGISTER_XLA_OP(Name("TensorListPushBack").AllowVariantTypes(),
657 TensorListPushBackOp);
658
659 class TensorListPopBackOp : public XlaOpKernel {
660 public:
TensorListPopBackOp(OpKernelConstruction * ctx)661 explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
662
Compile(XlaOpKernelContext * ctx)663 void Compile(XlaOpKernelContext* ctx) override {
664 // Check that the TensorList is initialized.
665 bool is_initialized;
666 OP_REQUIRES_OK(ctx,
667 (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
668 OP_REQUIRES(ctx, is_initialized,
669 errors::InvalidArgument("TensorList is not initialized"));
670
671 xla::XlaOp list = ctx->Input(0);
672 xla::XlaOp list_result, element_result;
673 bool element_is_tensor_list;
674 OP_REQUIRES_OK(ctx,
675 ExecuteTensorListPopBack(list, &list_result, &element_result,
676 &element_is_tensor_list));
677
678 ctx->SetTensorListOutput(0, list_result);
679 if (element_is_tensor_list) {
680 ctx->SetTensorListOutput(1, element_result);
681 } else {
682 ctx->SetOutput(1, element_result);
683 }
684 }
685
686 private:
687 DataType dtype_;
688
689 TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp);
690 };
691
692 REGISTER_XLA_OP(Name("TensorListPopBack").AllowVariantTypes(),
693 TensorListPopBackOp);
694
695 } // anonymous namespace
696 } // namespace tensorflow
697