• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // XLA TensorList operators.
17 
18 #include <limits>
19 #include <vector>
20 
21 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
22 #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/type_util.h"
25 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/literal.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/framework/bounds_check.h"
33 #include "tensorflow/core/framework/op_kernel.h"
34 #include "tensorflow/core/framework/partial_tensor_shape.h"
35 #include "tensorflow/core/framework/register_types.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/framework/tensor_types.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/platform/types.h"
42 
43 namespace tensorflow {
44 
45 namespace {
46 
47 // GetTensorListDynamicDims collects the dynamic dimensions that a tensorlist
48 // may carry and returns them in a 2D vector: XlaOp[ElementSize][DimSize]. If a
49 // dimension is static, a constant dimension is returned. If a dim is dynamic, a
50 // dynamic XlaOp representing the dynamic size is returned.
GetTensorListDynamicDims(XlaOpKernelContext * ctx,const xla::Shape & element_shape,const xla::Shape & list_shape,int64 num_elements)51 xla::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
52     XlaOpKernelContext* ctx, const xla::Shape& element_shape,
53     const xla::Shape& list_shape, int64 num_elements) {
54   std::vector<int64> dynamic_sizes;
55   // The multiplier can be a dynamic value.
56   TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
57   std::vector<bool> dims_are_dynamic;
58   TF_RETURN_IF_ERROR(
59       ctx->ResolveInputDynamismIntoPredVector(0, &dims_are_dynamic));
60   bool leading_dim_is_dynamic;
61   TF_RETURN_IF_ERROR(
62       ctx->ResolveInputDynamismIntoPred(1, &leading_dim_is_dynamic));
63   std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
64   // Set dynamic dimension size to 0 for initialization value.
65   std::vector<xla::XlaOp> dynamic_dims;
66   if (leading_dim_is_dynamic) {
67     dynamic_dims.push_back(ctx->Input(1));
68   } else {
69     dynamic_dims.push_back(
70         xla::ConstantR0<int32>(ctx->builder(), num_elements));
71   }
72   for (int64 dim = 0; dim < element_shape.dimensions_size(); ++dim) {
73     if (dims_are_dynamic[dim]) {
74       auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
75       dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
76       dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
77       dynamic_dims.push_back(dynamic_dim_size);
78     } else {
79       dynamic_dims.push_back(
80           xla::ConstantR0<int32>(ctx->builder(), dynamic_sizes[dim]));
81     }
82   }
83   list_dynamic_dims.push_back(dynamic_dims);
84   return list_dynamic_dims;
85 }
86 
87 class TensorListLengthOp : public XlaOpKernel {
88  public:
TensorListLengthOp(OpKernelConstruction * ctx)89   explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
90 
Compile(XlaOpKernelContext * ctx)91   void Compile(XlaOpKernelContext* ctx) override {
92     int64 leading_dim;
93     xla::XlaOp leading_dim_size;
94     bool leading_dim_is_dynamic;
95     OP_REQUIRES_OK(ctx, GetLeadingDimForTensorList(ctx->Input(0), &leading_dim,
96                                                    &leading_dim_is_dynamic,
97                                                    &leading_dim_size));
98     ctx->SetOutput(0, leading_dim_size);
99   }
100 
101  private:
102   TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp);
103 };
104 
105 REGISTER_XLA_OP(Name("TensorListLength").IsMetadataOp(), TensorListLengthOp);
106 
107 // "input" is the shape input for EmptyTensorList/TensorListReserve ops.
108 // If "input" is a compile time constant and not "unknown rank" (-1), return
109 // its value in "*shape".
TryGetElementShapeFromInput(XlaOpKernelContext * ctx,xla::XlaOp input,xla::PrimitiveType dtype,bool * got_shape,xla::Shape * shape)110 Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, xla::XlaOp input,
111                                    xla::PrimitiveType dtype, bool* got_shape,
112                                    xla::Shape* shape) {
113   auto is_compile_time_constant_or = input.builder()->IsConstant(input);
114   TF_RETURN_IF_ERROR(is_compile_time_constant_or.status());
115 
116   bool is_compile_time_constant = is_compile_time_constant_or.ValueOrDie();
117   if (!is_compile_time_constant) {
118     *got_shape = false;
119     return Status::OK();
120   }
121 
122   PartialTensorShape partial_shape;
123   TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape(0, &partial_shape));
124   if (!partial_shape.IsFullyDefined()) {
125     *got_shape = false;
126     return Status::OK();
127   }
128 
129   *shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes());
130   *got_shape = true;
131   return Status::OK();
132 }
133 
134 class TensorListReserveOp : public XlaOpKernel {
135  public:
TensorListReserveOp(OpKernelConstruction * ctx)136   explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
137     OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
138     // Only non-nested TensorList is supported for now.
139     OP_REQUIRES(
140         ctx, dtype_ != DT_VARIANT,
141         errors::Unimplemented(
142             "Only non-nested TensorList is supported for TensorListReserve."));
143   }
144 
Compile(XlaOpKernelContext * ctx)145   void Compile(XlaOpKernelContext* ctx) override {
146     int64 num_elements;
147     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
148     bool num_element_is_dynamic;
149     OP_REQUIRES_OK(
150         ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
151     OP_REQUIRES(
152         ctx, num_elements >= 0,
153         errors::InvalidArgument(
154             "XLA compilation requires a fixed tensor list size. Set the number "
155             "of elements. This could also happen if you're using a TensorArray "
156             "in a while loop that does not have its maximum_iteration set, you "
157             "can fix this by setting maximum_iteration to a suitable value."));
158 
159     // If element shape is compile time constant and it's not "unknown rank"
160     // shape (-1), create an initialized TensorList. Otherwise create an
161     // uninitialized TensorList.
162     xla::XlaOp element_shape_handle = ctx->Input(0);
163     xla::PrimitiveType type;
164     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
165     bool got_shape;
166     xla::Shape element_shape;
167     OP_REQUIRES_OK(ctx,
168                    TryGetElementShapeFromInput(ctx, element_shape_handle, type,
169                                                &got_shape, &element_shape));
170     if (got_shape) {
171       xla::Shape list_shape;
172       OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
173                               element_shape, num_elements,
174                               num_element_is_dynamic, &list_shape));
175       // Set up dynamic dimension sizes to create the zero tensor.
176       auto list_dynamic_dims_or = GetTensorListDynamicDims(
177           ctx, element_shape, list_shape, num_elements);
178       OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
179       xla::XlaOp new_list;
180       OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
181                               ctx->builder(), list_shape,
182                               list_dynamic_dims_or.ValueOrDie(), &new_list));
183       xla::XlaOp result;
184       OP_REQUIRES_OK(
185           ctx,
186           SetTensorListPushIndex(
187               new_list, xla::ConstantR0<int32>(ctx->builder(), num_elements),
188               &result));
189       ctx->SetTensorListOutput(0, result);
190       return;
191     }
192 
193     xla::XlaOp result = BuildUninitializedTensorList(
194         ctx->builder(), num_elements, num_element_is_dynamic, ctx->Input(1));
195     ctx->SetTensorListOutput(0, result);
196   }
197 
198  private:
199   DataType dtype_;
200 
201   TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp);
202 };
203 
204 REGISTER_XLA_OP(Name("TensorListReserve")
205                     .CompileTimeConstantInput("element_shape")
206                     .CompileTimeConstantInput("num_elements"),
207                 TensorListReserveOp);
208 
209 class EmptyTensorListOp : public XlaOpKernel {
210  public:
EmptyTensorListOp(OpKernelConstruction * ctx)211   explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
212     OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
213   }
214 
Compile(XlaOpKernelContext * ctx)215   void Compile(XlaOpKernelContext* ctx) override {
216     int64 max_num_elements;
217     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &max_num_elements));
218     bool num_element_is_dynamic;
219     OP_REQUIRES_OK(
220         ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
221     OP_REQUIRES(ctx, max_num_elements >= 0,
222                 errors::InvalidArgument(
223                     "XLA compilation requires a fixed tensor list size. Set "
224                     "the max number of elements. This could also happen if "
225                     "you're using a TensorArray in a while loop that does not "
226                     "have its maximum_iteration set, you can fix this by "
227                     "setting maximum_iteration to a suitable value."));
228 
229     if (dtype_ != DT_VARIANT) {
230       // We are creating a non-nested TensorList.
231       // If element shape is compile time constant and it's not "unknown
232       // rank" shape (-1), create an initialized TensorList. Otherwise
233       // create an uninitialized TensorList.
234       xla::XlaOp element_shape_handle = ctx->Input(0);
235       xla::PrimitiveType type;
236       OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
237       bool got_shape;
238       xla::Shape element_shape;
239       OP_REQUIRES_OK(
240           ctx, TryGetElementShapeFromInput(ctx, element_shape_handle, type,
241                                            &got_shape, &element_shape));
242       if (got_shape) {
243         xla::Shape list_shape;
244         OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
245                                 element_shape, max_num_elements,
246                                 num_element_is_dynamic, &list_shape));
247         // Set up dynamic dimension sizes to create the zero tensor.
248         auto list_dynamic_dims_or = GetTensorListDynamicDims(
249             ctx, element_shape, list_shape, max_num_elements);
250         OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
251 
252         xla::XlaOp result;
253         OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
254                                 ctx->builder(), list_shape,
255                                 list_dynamic_dims_or.ValueOrDie(), &result));
256 
257         ctx->SetTensorListOutput(0, result);
258         return;
259       }
260     }
261 
262     // We are creating a nested TensorList or a non-nested TensorList with
263     // unknown shape. Just create an uninitialized TensorList.
264     xla::XlaOp result =
265         BuildUninitializedTensorList(ctx->builder(), max_num_elements,
266                                      num_element_is_dynamic, ctx->Input(1));
267     ctx->SetTensorListOutput(0, result);
268   }
269 
270  private:
271   DataType dtype_;
272 
273   TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp);
274 };
275 
276 REGISTER_XLA_OP(Name("EmptyTensorList")
277                     .CompileTimeConstantInput("element_shape")
278                     .CompileTimeConstantInput("max_num_elements")
279                     .AllowVariantTypes(),
280                 EmptyTensorListOp);
281 
282 class TensorListElementShapeOp : public XlaOpKernel {
283  public:
TensorListElementShapeOp(OpKernelConstruction * ctx)284   explicit TensorListElementShapeOp(OpKernelConstruction* ctx)
285       : XlaOpKernel(ctx) {
286     OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_));
287   }
288 
Compile(XlaOpKernelContext * ctx)289   void Compile(XlaOpKernelContext* ctx) override {
290     // Check that the TensorList is initialized.
291     bool is_initialized;
292     OP_REQUIRES_OK(ctx,
293                    (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
294     OP_REQUIRES(ctx, is_initialized,
295                 errors::InvalidArgument("TensorList is not initialized"));
296 
297     // Only non-nested TensorList is supported for now.
298     bool is_nested;
299     OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
300     OP_REQUIRES(ctx, !is_nested,
301                 errors::Unimplemented("Only non-nested TensorList is supported "
302                                       "for TensorListElementShape."));
303 
304     // For non-nested TensorList, element shape is the buffer shape without
305     // the first dimension.
306     xla::XlaBuilder* b = ctx->builder();
307     xla::Shape list_shape;
308     OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &list_shape));
309     list_shape.DeleteDimension(0);
310 
311     switch (shape_type_) {
312       case DT_INT64:
313         ctx->SetOutput(0, xla::ConstantR1<int64>(b, list_shape.dimensions()));
314         break;
315       case DT_INT32: {
316         std::vector<int32> size;
317         for (int64 s : list_shape.dimensions()) {
318           size.push_back(s);
319         }
320         ctx->SetOutput(0, xla::ConstantR1<int32>(b, size));
321         break;
322       }
323       default:
324         ctx->CtxFailure(
325             errors::InvalidArgument("Unsupported shape type requested"));
326         return;
327     }
328   }
329 
330  private:
331   DataType shape_type_;
332 
333   TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp);
334 };
335 
336 REGISTER_XLA_OP(Name("TensorListElementShape").IsMetadataOp(),
337                 TensorListElementShapeOp);
338 
339 class TensorListGetItemOp : public XlaOpKernel {
340  public:
TensorListGetItemOp(OpKernelConstruction * ctx)341   explicit TensorListGetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
342     OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
343   }
344 
Compile(XlaOpKernelContext * ctx)345   void Compile(XlaOpKernelContext* ctx) override {
346     // Check that the TensorList is initialized.
347     bool is_initialized;
348     OP_REQUIRES_OK(ctx,
349                    (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
350     OP_REQUIRES(ctx, is_initialized,
351                 errors::InvalidArgument("TensorList is not initialized"));
352 
353     // Only non-nested TensorList is supported for now.
354     bool is_nested;
355     OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
356     OP_REQUIRES(ctx, !is_nested,
357                 errors::Unimplemented("Only non-nested TensorList is supported "
358                                       "for TensorListGetItem."));
359 
360     xla::XlaOp list = ctx->Input(0);
361     xla::XlaOp index = ctx->Input(1);
362 
363     xla::XlaOp result;
364     OP_REQUIRES_OK(ctx, ExecuteTensorListGetItem(list, index, &result));
365 
366     ctx->SetOutput(0, result);
367   }
368 
369  private:
370   DataType dtype_;
371 
372   TF_DISALLOW_COPY_AND_ASSIGN(TensorListGetItemOp);
373 };
374 
375 REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp);
376 
377 class TensorListGatherOp : public XlaOpKernel {
378  public:
TensorListGatherOp(OpKernelConstruction * ctx)379   explicit TensorListGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
380     OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
381   }
382 
Compile(XlaOpKernelContext * ctx)383   void Compile(XlaOpKernelContext* ctx) override {
384     // Check that the TensorList is initialized.
385     bool is_initialized;
386     OP_REQUIRES_OK(ctx,
387                    (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
388     OP_REQUIRES(ctx, is_initialized,
389                 errors::InvalidArgument("TensorList is not initialized"));
390 
391     // Only non-nested TensorList is supported for now.
392     bool is_nested;
393     OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
394     OP_REQUIRES(ctx, !is_nested,
395                 errors::Unimplemented("Only non-nested TensorList is supported "
396                                       "for TensorListGather."));
397 
398     DataType indices_type = ctx->input_type(1);
399 
400     const TensorShape indices_shape = ctx->InputShape(1);
401     OP_REQUIRES(ctx, indices_shape.dims() == 1,
402                 errors::InvalidArgument("indices must be rank 1"));
403 
404     xla::XlaOp list = ctx->Input(0);
405     xla::XlaOp indices = ctx->Input(1);
406 
407     xla::XlaOp buffer;
408     OP_REQUIRES_OK(ctx, GetTensorListBuffer(list, &buffer));
409     xla::Shape buffer_xla_shape;
410     OP_REQUIRES_OK(ctx, GetTensorListBufferShape(list, &buffer_xla_shape));
411     TensorShape buffer_shape;
412     OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(buffer_xla_shape, &buffer_shape));
413 
414     xla::XlaOp result;
415     OP_REQUIRES_OK(
416         ctx, XlaGather(buffer, buffer_shape, indices, indices_shape, /*axis=*/0,
417                        /*indices_are_nd=*/false, dtype_, indices_type,
418                        ctx->builder(), &result));
419     ctx->SetOutput(0, result);
420   }
421 
422  private:
423   DataType dtype_;
424 
425   TF_DISALLOW_COPY_AND_ASSIGN(TensorListGatherOp);
426 };
427 
428 REGISTER_XLA_OP(Name("TensorListGather"), TensorListGatherOp);
429 
430 class TensorListStackOp : public XlaOpKernel {
431  public:
TensorListStackOp(OpKernelConstruction * ctx)432   explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
433 
Compile(XlaOpKernelContext * ctx)434   void Compile(XlaOpKernelContext* ctx) override {
435     // Check that the TensorList is initialized.
436     bool is_initialized;
437     OP_REQUIRES_OK(ctx,
438                    (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
439     OP_REQUIRES(ctx, is_initialized,
440                 errors::InvalidArgument("TensorList is not initialized"));
441 
442     // Only non-nested TensorList is supported for now.
443     bool is_nested;
444     OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
445     OP_REQUIRES(ctx, !is_nested,
446                 errors::Unimplemented("Only non-nested TensorList is supported "
447                                       "for TensorListGetItem."));
448 
449     xla::XlaOp buffer;
450     OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer));
451     ctx->SetOutput(0, buffer);
452   }
453 
454  private:
455   TF_DISALLOW_COPY_AND_ASSIGN(TensorListStackOp);
456 };
457 
458 REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp);
459 
460 class TensorListConcatOp : public XlaOpKernel {
461  public:
TensorListConcatOp(OpKernelConstruction * ctx)462   explicit TensorListConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
463 
Compile(XlaOpKernelContext * ctx)464   void Compile(XlaOpKernelContext* ctx) override {
465     xla::XlaOp input = ctx->Input(0);
466 
467     // Check that the TensorList is initialized.
468     bool is_initialized;
469     OP_REQUIRES_OK(ctx, (IsTensorListInitialized(input, &is_initialized)));
470     OP_REQUIRES(ctx, is_initialized,
471                 errors::InvalidArgument("TensorList is not initialized"));
472 
473     // Only non-nested TensorList is supported for now.
474     bool is_nested;
475     OP_REQUIRES_OK(ctx, IsNestedTensorList(input, &is_nested));
476     OP_REQUIRES(ctx, !is_nested,
477                 errors::Unimplemented("Only non-nested TensorList is supported "
478                                       "for TensorListConcat."));
479 
480     xla::XlaOp buffer;
481     OP_REQUIRES_OK(ctx, GetTensorListBuffer(input, &buffer));
482 
483     xla::XlaBuilder* b = input.builder();
484     auto shape_or = b->GetShape(buffer);
485     OP_REQUIRES_OK(ctx, shape_or.status());
486     xla::Shape element_shape = shape_or.ConsumeValueOrDie();
487     std::vector<int64> element_dims =
488         xla::SpanToVector(element_shape.dimensions());
489     OP_REQUIRES(
490         ctx, element_dims.size() > 1,
491         errors::Unimplemented("TensorList of scalars is not supported"));
492     int64 num_elements = element_dims[0];
493     int64 tensor_lengths = element_dims[1];
494 
495     std::vector<int64> new_dims = {num_elements * tensor_lengths};
496 
497     for (int i = 2; i < element_dims.size(); i++) {
498       new_dims.push_back(element_dims[i]);
499     }
500 
501     xla::XlaOp out = xla::Reshape(buffer, new_dims);
502     ctx->SetOutput(0, out);
503 
504     // Second output is a tensor of lengths of returned tensors.
505     xla::XlaOp lengths = xla::ConstantR1(b, num_elements, tensor_lengths);
506     ctx->SetOutput(1, lengths);
507   }
508 
509  private:
510   TF_DISALLOW_COPY_AND_ASSIGN(TensorListConcatOp);
511 };
512 
513 REGISTER_XLA_OP(Name("TensorListConcatV2"), TensorListConcatOp);
514 
515 class TensorListSplitOp : public XlaOpKernel {
516  public:
TensorListSplitOp(OpKernelConstruction * ctx)517   explicit TensorListSplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
518     OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
519     // Only non-nested TensorList is supported for now.
520     OP_REQUIRES(
521         ctx, dtype_ != DT_VARIANT,
522         errors::Unimplemented(
523             "Only non-nested TensorList is supported for TensorListReserve."));
524   }
525 
Compile(XlaOpKernelContext * ctx)526   void Compile(XlaOpKernelContext* ctx) override {
527     xla::XlaOp input_tensor = ctx->Input(0);
528 
529     xla::XlaBuilder* b = input_tensor.builder();
530     auto shape_or = b->GetShape(input_tensor);
531     OP_REQUIRES_OK(ctx, shape_or.status());
532     xla::Shape element_shape = shape_or.ConsumeValueOrDie();
533     std::vector<int64> element_dims =
534         xla::SpanToVector(element_shape.dimensions());
535     OP_REQUIRES(
536         ctx, !element_dims.empty(),
537         errors::Unimplemented("Element dimensions have to be non-empty"));
538 
539     std::vector<int64> lengths;
540     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths));
541     OP_REQUIRES(ctx, !lengths.empty(),
542                 errors::Unimplemented("Length has to be non-empty"));
543     int64 length = lengths[0];
544     for (int64 len : lengths) {
545       OP_REQUIRES(ctx, len == length,
546                   errors::Unimplemented("All lengths have to be the same"));
547     }
548     OP_REQUIRES(
549         ctx, element_dims[0] % length == 0,
550         errors::Unimplemented("Buffer size has to be a multiple of length"));
551     std::vector<int64> new_dims = {element_dims[0] / length, length};
552     for (int i = 1; i < element_dims.size(); i++) {
553       new_dims.push_back(element_dims[i]);
554     }
555 
556     xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims);
557 
558     xla::XlaOp result;
559     OP_REQUIRES_OK(ctx, ExecuteTensorListFromTensor(length, reshaped, &result));
560     ctx->SetTensorListOutput(0, result);
561   }
562 
563  private:
564   DataType dtype_;
565 
566   TF_DISALLOW_COPY_AND_ASSIGN(TensorListSplitOp);
567 };
568 
569 REGISTER_XLA_OP(Name("TensorListSplit")
570                     .CompileTimeConstantInput("element_shape")
571                     .CompileTimeConstantInput("lengths"),
572                 TensorListSplitOp);
573 
574 class TensorListFromTensorOp : public XlaOpKernel {
575  public:
TensorListFromTensorOp(OpKernelConstruction * ctx)576   explicit TensorListFromTensorOp(OpKernelConstruction* ctx)
577       : XlaOpKernel(ctx) {}
578 
Compile(XlaOpKernelContext * ctx)579   void Compile(XlaOpKernelContext* ctx) override {
580     const TensorShape& tensor_shape = ctx->InputShape(0);
581     int num_elements = tensor_shape.dim_size(0);
582     const xla::XlaOp tensor = ctx->Input(0);
583     xla::XlaOp result;
584     OP_REQUIRES_OK(ctx,
585                    ExecuteTensorListFromTensor(num_elements, tensor, &result));
586     auto list_shape_or = ctx->builder()->GetShape(result);
587     ctx->SetTensorListOutput(0, result);
588   }
589 
590  private:
591   TF_DISALLOW_COPY_AND_ASSIGN(TensorListFromTensorOp);
592 };
593 
594 REGISTER_XLA_OP(
595     Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"),
596     TensorListFromTensorOp);
597 
598 class TensorListSetItemOp : public XlaOpKernel {
599  public:
TensorListSetItemOp(OpKernelConstruction * ctx)600   explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
601 
Compile(XlaOpKernelContext * ctx)602   void Compile(XlaOpKernelContext* ctx) override {
603     xla::XlaOp list = ctx->Input(0);
604     xla::XlaOp index = ctx->Input(1);
605     xla::XlaOp element = ctx->Input(2);
606     xla::XlaOp initialized_list;
607     OP_REQUIRES_OK(ctx, GetInitializedTensorListForElement(
608                             list, element, /*element_is_tensor_list=*/false,
609                             &initialized_list));
610 
611     // Only non-nested TensorList is supported for now.
612     bool is_nested;
613     OP_REQUIRES_OK(ctx, IsNestedTensorList(initialized_list, &is_nested));
614     OP_REQUIRES(ctx, !is_nested,
615                 errors::Unimplemented("Only non-nested TensorList is supported "
616                                       "for TensorListSetItem."));
617 
618     xla::XlaOp result;
619     OP_REQUIRES_OK(ctx, ExecuteTensorListSetItem(initialized_list, index,
620                                                  element, &result));
621 
622     ctx->SetTensorListOutput(0, result);
623   }
624 
625  private:
626   TF_DISALLOW_COPY_AND_ASSIGN(TensorListSetItemOp);
627 };
628 
629 REGISTER_XLA_OP(Name("TensorListSetItem"), TensorListSetItemOp);
630 
631 class TensorListPushBackOp : public XlaOpKernel {
632  public:
TensorListPushBackOp(OpKernelConstruction * ctx)633   explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
634 
Compile(XlaOpKernelContext * ctx)635   void Compile(XlaOpKernelContext* ctx) override {
636     xla::XlaOp list = ctx->Input(0);
637     xla::XlaOp element = ctx->Input(1);
638     bool element_is_tensor_list = IsTensorListInput(ctx, 1);
639     xla::XlaOp initialized_list;
640     OP_REQUIRES_OK(
641         ctx, GetInitializedTensorListForElement(
642                  list, element, element_is_tensor_list, &initialized_list));
643 
644     xla::XlaOp result;
645     OP_REQUIRES_OK(ctx,
646                    ExecuteTensorListPushBack(initialized_list, element,
647                                              element_is_tensor_list, &result));
648 
649     ctx->SetTensorListOutput(0, result);
650   }
651 
652  private:
653   TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp);
654 };
655 
656 REGISTER_XLA_OP(Name("TensorListPushBack").AllowVariantTypes(),
657                 TensorListPushBackOp);
658 
659 class TensorListPopBackOp : public XlaOpKernel {
660  public:
TensorListPopBackOp(OpKernelConstruction * ctx)661   explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
662 
Compile(XlaOpKernelContext * ctx)663   void Compile(XlaOpKernelContext* ctx) override {
664     // Check that the TensorList is initialized.
665     bool is_initialized;
666     OP_REQUIRES_OK(ctx,
667                    (IsTensorListInitialized(ctx->Input(0), &is_initialized)));
668     OP_REQUIRES(ctx, is_initialized,
669                 errors::InvalidArgument("TensorList is not initialized"));
670 
671     xla::XlaOp list = ctx->Input(0);
672     xla::XlaOp list_result, element_result;
673     bool element_is_tensor_list;
674     OP_REQUIRES_OK(ctx,
675                    ExecuteTensorListPopBack(list, &list_result, &element_result,
676                                             &element_is_tensor_list));
677 
678     ctx->SetTensorListOutput(0, list_result);
679     if (element_is_tensor_list) {
680       ctx->SetTensorListOutput(1, element_result);
681     } else {
682       ctx->SetOutput(1, element_result);
683     }
684   }
685 
686  private:
687   DataType dtype_;
688 
689   TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp);
690 };
691 
692 REGISTER_XLA_OP(Name("TensorListPopBack").AllowVariantTypes(),
693                 TensorListPopBackOp);
694 
695 }  // anonymous namespace
696 }  // namespace tensorflow
697