/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"

// TensorList is represented by a tuple.
// - The first part of the tuple is a buffer containing all the tensors.
// - The following parts are push indices for all nested levels of
//   TensorLists. The last part is the push index for the outermost
//   TensorList.
//
// TensorList, as its name suggests, is conceptually a list of tensors. In the
// actual representation of a non-nested TensorList, the buffer shape is
// [tensor_list_size, element_shape]. We will call tensor_list_size the
// "leading dimension" below. Notice that the leading dimension must be a
// compile-time constant, since it's part of the buffer shape.
//
// Example: consider a 3-level nested TensorList whose element type is scalar.
// Assume the inner TensorList has leading dimension 4, the middle TensorList
// has 3, and the outer TensorList has 3.
// Assume that a lower-case letter means there is data in that position, and
// "." means there is no data in that position.
// First element of the outer TensorList:
// [ a . . . ]
// [ b c . . ]
// [ d e f . ]
// Second element of the outer TensorList:
// [ g h i . ]
// [ j k . . ]
// [ . . . . ]
// Third element: not pushed yet.
//
// The first part of the tuple is an array of shape [3, 3, 4] containing the
// data. The second part is an array of shape [3, 3]; each element is the push
// index for an inner TensorList. In this case, its values are:
// [ 1 2 3 ]
// [ 3 2 . ]
// [ . . . ]
// The third part is an array of shape [3]; each element is the push index for
// a middle TensorList. In this case, its values are:
// [ 3 ]
// [ 2 ]
// [ . ]
// The fourth (and last) part is a scalar. It's the push index for the outer
// TensorList. In this case, its value is 2.
//
// Now imagine we need to push the following element to the outer TensorList:
// [ l . . . ]
// [ m n . . ]
// [ . . . . ]
// This element is represented by a tuple of 3 parts:
// The first part is all the data.
// The second part is the push indices for the inner TensorLists, which is
// [ 1 2 . ].
// The third part is the push index for the middle TensorList, which is 2.
// Now let's do the push.
// First, we append its data to the outer TensorList's data.
// Then we deal with push indices. Similar to the data, we append the push
// indices for each level of TensorList.
// For the inner TensorLists: append the push indices of the pushed element.
// [ 1 2 3 ]               [ 1 2 3 ]
// [ 3 2 . ] +           = [ 3 2 . ]
// [ . . . ]   [ 1 2 . ]   [ 1 2 . ]
// For the middle TensorList: append the push index of the pushed element.
// [ 3 ]           [ 3 ]
// [ 2 ] +       = [ 2 ]
// [ . ]   [ 2 ]   [ 2 ]
// For the outer TensorList: just add 1.
// 2 + 1 = 3
//
// Popping an element from the outer TensorList follows a similar process.
// The first part of the element is the data. We get it by slicing the buffer
// at the outer TensorList's push index minus 1 (3 - 1 = 2).
// The second part is the push indices for the inner TensorLists. We get it by
// slicing the inner push indices at the same index (2):
// [ 1 2 3 ]
// [ 3 2 . ]
// [ 1 2 . ] ===> This is what we want
// The third part is the push index for the middle TensorList. We get it by
// slicing the middle push indices at the same index (2):
// [ 3 ]
// [ 2 ]
// [ 2 ] ===> This is what we want
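//
// As a rough sketch (shape notation only, not code), the fully pushed 3-level
// example above corresponds to an XLA tuple shape like:
//   (f32[3,3,4],   // buffer containing all the data
//    s32[3,3],     // push indices for the inner TensorLists
//    s32[3],       // push indices for the middle TensorLists
//    s32[])        // push index for the outer TensorList
// assuming a float element type; the buffer's element type is whatever the
// list holds, while push indices are always S32.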

namespace tensorflow {

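// Returns whether the input at `index` of `ctx` is represented as a
// TensorList expression rather than a plain tensor.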
bool IsTensorListInput(XlaOpKernelContext* ctx, int index) {
  return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList;
}

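// A TensorList is "initialized" once it has been materialized as an XLA
// tuple; before that it is a plain array placeholder (see
// BuildUninitializedTensorList below).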
Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) {
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  *is_initialized = list_shape.IsTuple();
  return Status::OK();
}

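// A nested TensorList carries one push-index part per nesting level, so its
// tuple has more than two elements.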
Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  *is_nested_list = (xla::ShapeUtil::TupleElementCount(list_shape) > 2);
  return Status::OK();
}

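// Assembles the two-part tuple (buffer, push index) of a non-nested
// TensorList.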
Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index,
                                xla::XlaOp* output_list) {
  TF_RET_CHECK(buffer.builder());
  *output_list = xla::Tuple(buffer.builder(), {buffer, push_index});
  return Status::OK();
}

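// Returns the shape of the buffer, i.e. tuple element 0.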
Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  *buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
  return Status::OK();
}

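// Returns the buffer itself, i.e. tuple element 0.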
Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  *buffer = xla::GetTupleElement(list, 0);
  return Status::OK();
}

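// Returns the push index of the outermost TensorList, stored as the last
// tuple element.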
Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  *push_index = xla::GetTupleElement(list, tuple_size - 1);
  return Status::OK();
}

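// Rebuilds the tuple with all parts unchanged except the outermost push
// index, which is replaced by `push_index`.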
Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
                              xla::XlaOp* result) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  std::vector<xla::XlaOp> result_parts;
  result_parts.reserve(tuple_size);
  for (int i = 0; i < tuple_size - 1; i++) {
    result_parts.push_back(xla::GetTupleElement(list, i));
  }
  result_parts.push_back(push_index);
  *result = xla::Tuple(list.builder(), result_parts);
  return Status::OK();
}

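// An uninitialized TensorList is represented as a plain 1-D S32 array of
// zeros whose size is the leading dimension; it is turned into a real tuple
// the first time the list is written to.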
xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
                                        int64 leading_dimension,
                                        bool leading_size_is_dynamic,
                                        xla::XlaOp leading_dim_size) {
  auto zero =
      xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
  auto broadcast = xla::Broadcast(zero, std::vector<int64>{leading_dimension});
  if (leading_size_is_dynamic) {
    return xla::SetDimensionSize(broadcast, leading_dim_size, 0);
  } else {
    return broadcast;
  }
}

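// Reads the leading dimension of the list: its static size, whether it is
// dynamic, and its dynamic size. Works on both initialized (tuple) and
// uninitialized (placeholder) lists.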
Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
                                  bool* leading_dim_is_dynamic,
                                  xla::XlaOp* leading_dim_dynamic_size) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
  if (is_initialized) {
    auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
    *leading_dim_is_dynamic = buffer_shape.is_dynamic_dimension(0);
    auto buffer = xla::GetTupleElement(list, 0);
    *leading_dim = buffer_shape.dimensions(0);
    *leading_dim_dynamic_size = xla::GetDimensionSize(buffer, 0);
  } else {
    *leading_dim_is_dynamic = list_shape.is_dynamic_dimension(0);
    *leading_dim = list_shape.dimensions(0);
    *leading_dim_dynamic_size = xla::GetDimensionSize(list, 0);
  }
  return Status::OK();
}

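// Computes the tuple shape of a nested TensorList from the tuple shape of one
// of its elements: every part of the element shape gains the leading
// dimension in front, and a trailing scalar S32 push index is appended.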
Status GetTensorListShapeFromElementTensorListShape(
    const xla::Shape& element_tensor_list_shape, int64 leading_dim,
    bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) {
  std::vector<xla::Shape> shapes;
  int tuple_size = xla::ShapeUtil::TupleElementCount(element_tensor_list_shape);
  for (int i = 0; i < tuple_size; i++) {
    const xla::Shape& shape =
        xla::ShapeUtil::GetTupleElementShape(element_tensor_list_shape, i);
    std::vector<int64> dimensions = xla::SpanToVector(shape.dimensions());
    dimensions.insert(dimensions.begin(), leading_dim);
    shapes.push_back(
        xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
    if (leading_dim_is_dynamic) {
      shapes.back().set_dynamic_dimension(0, true);
    }
  }
  shapes.push_back(
      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
  *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  return Status::OK();
}

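// Computes the tuple shape of a non-nested TensorList from its element shape:
// (buffer with the leading dimension prepended, scalar S32 push index).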
Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
                                          int64 leading_dim,
                                          bool leading_dim_is_dynamic,
                                          xla::Shape* tensor_list_shape) {
  if (!element_shape.IsArray()) {
    return errors::InvalidArgument(
        "GetTensorListShapeFromElementShape() only supports normal tensor "
        "shape. But element shape is ",
        element_shape.DebugString());
  }
  std::vector<xla::Shape> shapes;
  std::vector<int64> dimensions = xla::SpanToVector(element_shape.dimensions());
  dimensions.insert(dimensions.begin(), leading_dim);
  shapes.push_back(
      xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
  shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic);
  shapes.push_back(
      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
  *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  return Status::OK();
}

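// Builds a zero-filled TensorList with the given tuple shape. `dynamic_dims`
// supplies a dynamic size op for every dimension of every buffer part.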
Status CreateZerosTensorListWithShape(
    xla::XlaBuilder* b, const xla::Shape& list_shape,
    const std::vector<std::vector<xla::XlaOp>>& dynamic_dims,
    xla::XlaOp* list) {
  int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  std::vector<xla::XlaOp> elements;
  TF_RET_CHECK(dynamic_dims.size() == tuple_size - 1);
  for (int i = 0; i < tuple_size - 1; i++) {
    const xla::Shape& shape =
        xla::ShapeUtil::GetTupleElementShape(list_shape, i);
    xla::XlaOp zero =
        xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type()));
    xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions());
    TF_RET_CHECK(dynamic_dims[i].size() == shape.dimensions_size());
    for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
      zeros = xla::SetDimensionSize(zeros, dynamic_dims[i][dim], dim);
    }
    elements.push_back(zeros);
  }
  // List size (last item) has to be S32.
  TF_RET_CHECK(xla::ShapeUtil::GetTupleElementShape(list_shape, tuple_size - 1)
                   .element_type() == xla::S32);
  elements.push_back(xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32)));
  *list = xla::Tuple(b, elements);
  return Status::OK();
}

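// If `list` is already initialized, checks that its shape is compatible with
// `element` and returns it as-is; otherwise builds a compatible zeros list,
// taking dynamic dimension sizes from `element`.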
Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
                                          bool element_is_tensor_list,
                                          xla::XlaOp* initialized_list) {
  int64 leading_dim;
  xla::XlaOp leading_dim_dynamic_size;
  bool leading_dim_is_dynamic;
  TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(
      list, &leading_dim, &leading_dim_is_dynamic, &leading_dim_dynamic_size));

  xla::XlaBuilder* b = list.builder();
  xla::Shape list_shape;
  TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));

  if (element_is_tensor_list) {
    TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
        element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
  } else {
    TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
        element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
  }
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (is_initialized) {
    // Check that the shape of the initialized list is correct.
    TF_ASSIGN_OR_RETURN(xla::Shape original_list_shape, b->GetShape(list));
    if (!xla::ShapeUtil::Compatible(original_list_shape, list_shape)) {
      return errors::Internal(
          "Invalid TensorList shape: ", original_list_shape.DebugString(),
          ", expected: ", list_shape.DebugString());
    }
    *initialized_list = list;
    return Status::OK();
  } else {
    // Prepare the dynamic dimensions for the zeros tensor list. The dynamic
    // sizes are created by reading the dynamic dimension sizes of the
    // sub-elements.
    std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
    for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
      std::vector<xla::XlaOp> dynamic_dims;
      const xla::Shape& shape = list_shape.tuple_shapes(i);
      dynamic_dims.push_back(leading_dim_dynamic_size);
      xla::XlaOp sub_element;
      if (element_is_tensor_list) {
        sub_element = xla::GetTupleElement(element, i);
      } else {
        sub_element = element;
      }
      for (int64 dim = 0; dim < shape.dimensions_size() - 1; ++dim) {
        dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
      }
      list_dynamic_dims.push_back(dynamic_dims);
    }
    return CreateZerosTensorListWithShape(b, list_shape, list_dynamic_dims,
                                          initialized_list);
  }
}

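// Pushes `element` onto `list`: writes each part of the element into the
// corresponding buffer part at the current push index (via
// DynamicUpdateSlice), then increments the push index by 1.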
Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element,
                                 bool element_is_tensor_list,
                                 xla::XlaOp* result) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }

  xla::XlaBuilder* b = list.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
  int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1);

  std::vector<xla::XlaOp> result_parts;

  if (element_is_tensor_list) {
    TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
    int element_tuple_size = xla::ShapeUtil::TupleElementCount(element_shape);
    for (int i = 0; i < element_tuple_size; i++) {
      const xla::Shape& element_part_shape =
          xla::ShapeUtil::GetTupleElementShape(element_shape, i);
      xla::XlaOp element_part = xla::GetTupleElement(element, i);
      std::vector<int64> element_part_dims =
          xla::SpanToVector(element_part_shape.dimensions());
      element_part_dims.insert(element_part_dims.begin(), 1);
      element_part = xla::Reshape(element_part, element_part_dims);

      std::vector<xla::XlaOp> start_indices(
          element_part_shape.dimensions_size() + 1,
          xla::ConstantR0<int32>(b, 0));
      start_indices[0] = push_index;

      xla::XlaOp list_part = xla::GetTupleElement(list, i);
      xla::XlaOp updated_list_part =
          xla::DynamicUpdateSlice(list_part, element_part, start_indices);
      result_parts.push_back(updated_list_part);
    }
  } else {
    TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
    std::vector<int64> element_dims =
        xla::SpanToVector(element_shape.dimensions());
    element_dims.insert(element_dims.begin(), 1);
    xla::XlaOp update = xla::Reshape(element, element_dims);

    std::vector<xla::XlaOp> start_indices(element_shape.dimensions_size() + 1,
                                          xla::ConstantR0<int32>(b, 0));
    start_indices[0] = push_index;

    xla::XlaOp list_part = xla::GetTupleElement(list, 0);
    xla::XlaOp updated_list_part =
        xla::DynamicUpdateSlice(list_part, update, start_indices);
    result_parts.push_back(updated_list_part);
  }

  xla::XlaOp updated_push_index = push_index + xla::ConstantR0<int32>(b, 1);
  result_parts.push_back(updated_push_index);

  *result = xla::Tuple(b, result_parts);
  return Status::OK();
}

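// Pops the last element of `list`: decrements the push index, slices every
// buffer part at the new index to form the element, and returns the list with
// the decremented push index. The buffers themselves are left untouched.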
Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result,
                                xla::XlaOp* element_result,
                                bool* element_is_tensor_list) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }

  // If the TensorList is nested, the popped element will itself be a
  // TensorList.
  TF_RETURN_IF_ERROR(IsNestedTensorList(list, element_is_tensor_list));

  xla::XlaBuilder* b = list.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
  int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
  xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1);
  push_index = push_index - xla::ConstantR0<int32>(b, 1);

  std::vector<xla::XlaOp> list_result_parts, element_result_parts;
  for (int i = 0; i < list_tuple_size - 1; i++) {
    const xla::Shape& list_part_shape =
        xla::ShapeUtil::GetTupleElementShape(list_shape, i);
    std::vector<xla::XlaOp> start_indices(list_part_shape.dimensions_size(),
                                          xla::ConstantR0<int32>(b, 0));
    start_indices[0] = push_index;

    std::vector<int64> slice_shape =
        xla::SpanToVector(list_part_shape.dimensions());
    slice_shape[0] = 1LL;

    xla::XlaOp list_part = xla::GetTupleElement(list, i);
    xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);

    slice_shape.erase(slice_shape.begin());
    element_result_parts.push_back(xla::Reshape(read, slice_shape));
    list_result_parts.push_back(list_part);
  }
  list_result_parts.push_back(push_index);

  *list_result = xla::Tuple(b, list_result_parts);
  if (*element_is_tensor_list) {
    *element_result = xla::Tuple(b, element_result_parts);
  } else {
    *element_result = element_result_parts[0];
  }

  return Status::OK();
}

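// Overwrites the buffer row at `index` with `element`. Only non-nested
// TensorLists are supported; the push index (tuple element 1) is left
// unchanged.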
Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index,
                                xla::XlaOp element, xla::XlaOp* result) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  bool is_nested;
  TF_RETURN_IF_ERROR(IsNestedTensorList(list, &is_nested));
  if (is_nested) {
    return errors::Unimplemented(
        "ExecuteTensorListSetItem() only supports non-nested TensorList");
  }

  xla::XlaBuilder* b = list.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
  std::vector<int64> element_dims =
      xla::SpanToVector(element_shape.dimensions());
  element_dims.insert(element_dims.begin(), 1);
  xla::XlaOp update = xla::Reshape(element, element_dims);

  std::vector<xla::XlaOp> start_indices(element_shape.dimensions_size() + 1,
                                        xla::ConstantR0<int32>(b, 0));
  start_indices[0] = index;

  xla::XlaOp list_part = xla::GetTupleElement(list, 0);
  xla::XlaOp updated_list_part =
      xla::DynamicUpdateSlice(list_part, update, start_indices);

  std::vector<xla::XlaOp> result_parts;
  result_parts.push_back(updated_list_part);
  result_parts.push_back(xla::GetTupleElement(list, 1));
  *result = xla::Tuple(b, result_parts);
  return Status::OK();
}

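// Reads the buffer row at `index` of a non-nested TensorList, preserving the
// dynamic dimension sizes of the element.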
Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index,
                                xla::XlaOp* result) {
  bool is_initialized;
  TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
  if (!is_initialized) {
    return errors::InvalidArgument("TensorList is not initialized");
  }
  bool is_nested;
  TF_RETURN_IF_ERROR(IsNestedTensorList(list, &is_nested));
  if (is_nested) {
    return errors::Unimplemented(
        "ExecuteTensorListGetItem() only supports non-nested TensorList");
  }

  xla::XlaBuilder* b = list.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
  const xla::Shape& buffer_shape =
      xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
  std::vector<xla::XlaOp> start_indices(buffer_shape.dimensions_size(),
                                        xla::ConstantR0<int32>(b, 0));
  start_indices[0] = index;

  std::vector<int64> slice_shape = xla::SpanToVector(buffer_shape.dimensions());
  slice_shape[0] = 1LL;

  xla::XlaOp list_part = xla::GetTupleElement(list, 0);
  xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);
  // Propagate dynamic dimensions from the buffer to the sliced buffer, except
  // for the leading dimension (which is always static 1).
  for (int64 i = 1; i < buffer_shape.dimensions_size(); ++i) {
    if (buffer_shape.is_dynamic_dimension(i)) {
      auto buffer = xla::GetTupleElement(list, 0);
      auto gds = xla::GetDimensionSize(buffer, i);
      read = xla::SetDimensionSize(read, gds, i);
    }
  }
  slice_shape.erase(slice_shape.begin());
  *result = xla::Reshape(read, slice_shape);
  return Status::OK();
}

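// Wraps an existing tensor as a non-nested TensorList whose push index is the
// compile-time constant `push_index`.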
Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor,
                                   xla::XlaOp* result) {
  xla::XlaBuilder* b = tensor.builder();
  TF_ASSIGN_OR_RETURN(xla::Shape shape, b->GetShape(tensor));
  if (!shape.IsArray()) {
    return errors::InvalidArgument(
        "ExecuteTensorListFromTensor() only supports normal tensor. But input "
        "shape is ",
        shape.DebugString());
  }

  std::vector<xla::XlaOp> result_parts{tensor,
                                       xla::ConstantR0<int32>(b, push_index)};
  *result = xla::Tuple(b, result_parts);
  return Status::OK();
}

}  // namespace tensorflow