1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
17
18 #include "tensorflow/compiler/tf2xla/shape_util.h"
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/literal_util.h"
21 #include "tensorflow/compiler/xla/shape.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/lib/core/errors.h"
29
30 // TensorList is represented by a tuple.
31 // - The first part of the tuple is a buffer containing all the tensors,
32 // - The following parts are push indices for all nested levels of
33 // TensorLists. The last part is push index for the outermost TensorList.
34 //
// TensorList, as its name suggests, is conceptually a list of tensors. In the
// actual representation of a non-nested TensorList, the buffer shape is
// [tensor_list_size, element_shape...]. We will call tensor_list_size the
// "leading dimension" below. Notice that the leading dimension must be known
// at compile time, since it's part of the buffer shape (a dynamic leading
// dimension still has a static size in the shape; its runtime size is attached
// with SetDimensionSize).
40 //
41 // Example: consider a 3-level nested TensorList whose element type is scalar.
42 // Assume inner TensorList has leading dimension 4, middle TensorList has 3,
43 // and outer TensorList has 3.
44 // Assume that lower cased letter means there is data in that position, and "."
45 // means there is no data in that position.
46 // First element of outer TensorList:
47 // [ a . . . ]
48 // [ b c . . ]
49 // [ d e f . ]
50 // Second element of outer TensorList:
51 // [ g h i . ]
52 // [ j k . . ]
53 // [ . . . . ]
54 // Third element: not pushed yet.
55 //
56 // The first part of the tuple is an array of shape [3, 3, 4] containing data.
57 // The second part is an array of shape [3, 3], each element is push index
58 // for the inner TensorList. In this case, its values are:
59 // [ 1 2 3 ]
60 // [ 3 2 . ]
61 // [ . . . ]
62 // The third part is an array of shape [3], each element is push index for
63 // the middle TensorList. In this case, its values are:
64 // [ 3 ]
65 // [ 2 ]
66 // [ . ]
// The fourth (and last) part is a scalar. It's the push index for the outer
// TensorList. In this case, its value is 2.
69 //
70 // Now imagine we need to push the following element to the outer TensorList:
71 // [ l . . . ]
72 // [ m n . . ]
73 // [ . . . . ]
74 // This element is represented by a tuple of 3 parts:
75 // First part is all data.
76 // Second part is push indices for the inner TensorList, which is [ 1 2 . ].
77 // Third part is push index for the middle TensorList, which is 2.
78 // Now let's do the push.
79 // First, we append its data to outer TensorList's data.
80 // Then we start to deal with push indices. Similar to data, we append push
81 // indices for each level of TensorList.
82 // For the inner TensorList: append push indices for the pushed element.
83 // [ 1 2 3 ] [ 1 2 3 ]
84 // [ 3 2 . ] + = [ 3 2 . ]
85 // [ . . . ] [ 1 2 . ] [ 1 2 . ]
86 // For the middle TensorList: append push indices for the pushed element.
87 // [ 3 ] [ 3 ]
88 // [ 2 ] + = [ 2 ]
89 // [ . ] [ 2 ] [ 2 ]
90 // For the outer TensorList: just add 1.
91 // 2 + 1 = 3
92 //
// Popping an element from the outer TensorList follows a similar process.
// First, the push index for the outer TensorList is decremented (3 - 1 = 2);
// that decremented value is both the position that is read and the new push
// index. Every part is then sliced at that position.
// The first part is data. We get it by slicing the data at position 2.
// The second part is push indices for the inner TensorList. We get it by
// slicing the push indices for the inner TensorList at position 2.
// [ 1 2 3 ]
// [ 3 2 . ]
// [ 1 2 . ] ===> This is what we want
// The third part is the push index for the middle TensorList. We get it by
// slicing the push indices for the middle TensorList at position 2.
// [ 3 ]
// [ 2 ]
// [ 2 ] ===> This is what we want
108
109 namespace tensorflow {
110
IsTensorListInput(XlaOpKernelContext * ctx,int index)111 bool IsTensorListInput(XlaOpKernelContext* ctx, int index) {
112 return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList;
113 }
114
IsTensorListInitialized(xla::XlaOp list,bool * is_initialized)115 Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) {
116 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
117 *is_initialized = list_shape.IsTuple();
118 return Status::OK();
119 }
120
IsNestedTensorList(xla::XlaOp list,bool * is_nested_list)121 Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) {
122 bool is_initialized;
123 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
124 if (!is_initialized) {
125 return errors::InvalidArgument("TensorList is not initialized");
126 }
127 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
128 *is_nested_list = (xla::ShapeUtil::TupleElementCount(list_shape) > 2);
129 return Status::OK();
130 }
131
BuildNonNestedTensorList(xla::XlaOp buffer,xla::XlaOp push_index,xla::XlaOp * output_list)132 Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index,
133 xla::XlaOp* output_list) {
134 TF_RET_CHECK(buffer.builder());
135 *output_list = xla::Tuple(buffer.builder(), {buffer, push_index});
136 return Status::OK();
137 }
138
GetTensorListBufferShape(xla::XlaOp list,xla::Shape * buffer_shape)139 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) {
140 bool is_initialized;
141 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
142 if (!is_initialized) {
143 return errors::InvalidArgument("TensorList is not initialized");
144 }
145 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
146 *buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
147 return Status::OK();
148 }
149
GetTensorListBuffer(xla::XlaOp list,xla::XlaOp * buffer)150 Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) {
151 bool is_initialized;
152 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
153 if (!is_initialized) {
154 return errors::InvalidArgument("TensorList is not initialized");
155 }
156 *buffer = xla::GetTupleElement(list, 0);
157 return Status::OK();
158 }
159
GetTensorListPushIndex(xla::XlaOp list,xla::XlaOp * push_index)160 Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) {
161 bool is_initialized;
162 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
163 if (!is_initialized) {
164 return errors::InvalidArgument("TensorList is not initialized");
165 }
166 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
167 int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
168 *push_index = xla::GetTupleElement(list, tuple_size - 1);
169 return Status::OK();
170 }
171
SetTensorListPushIndex(xla::XlaOp list,xla::XlaOp push_index,xla::XlaOp * result)172 Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
173 xla::XlaOp* result) {
174 bool is_initialized;
175 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
176 if (!is_initialized) {
177 return errors::InvalidArgument("TensorList is not initialized");
178 }
179 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
180 int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
181 std::vector<xla::XlaOp> result_parts;
182 result_parts.reserve(tuple_size);
183 for (int i = 0; i < tuple_size - 1; i++) {
184 result_parts.push_back(xla::GetTupleElement(list, i));
185 }
186 result_parts.push_back(push_index);
187 *result = xla::Tuple(list.builder(), result_parts);
188 return Status::OK();
189 }
190
BuildUninitializedTensorList(xla::XlaBuilder * b,int64 leading_dimension,bool leading_size_is_dynamic,xla::XlaOp leading_dim_size)191 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
192 int64 leading_dimension,
193 bool leading_size_is_dynamic,
194 xla::XlaOp leading_dim_size) {
195 auto zero =
196 xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::PrimitiveType::S32));
197 auto broadcast = xla::Broadcast(zero, std::vector<int64>{leading_dimension});
198 if (leading_size_is_dynamic) {
199 return xla::SetDimensionSize(broadcast, leading_dim_size, 0);
200 } else {
201 return broadcast;
202 }
203 }
204
GetLeadingDimForTensorList(xla::XlaOp list,int64 * leading_dim,bool * leading_dim_is_dynamic,xla::XlaOp * leading_dim_dynamic_size)205 Status GetLeadingDimForTensorList(xla::XlaOp list, int64* leading_dim,
206 bool* leading_dim_is_dynamic,
207 xla::XlaOp* leading_dim_dynamic_size) {
208 bool is_initialized;
209 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
210 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list));
211 if (is_initialized) {
212 auto buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
213 *leading_dim_is_dynamic = buffer_shape.is_dynamic_dimension(0);
214 auto buffer = xla::GetTupleElement(list, 0);
215 *leading_dim = buffer_shape.dimensions(0);
216 *leading_dim_dynamic_size = xla::GetDimensionSize(buffer, 0);
217 } else {
218 *leading_dim_is_dynamic = list_shape.is_dynamic_dimension(0);
219 *leading_dim = list_shape.dimensions(0);
220 *leading_dim_dynamic_size = xla::GetDimensionSize(list, 0);
221 }
222 return Status::OK();
223 }
224
GetTensorListShapeFromElementTensorListShape(const xla::Shape & element_tensor_list_shape,int64 leading_dim,bool leading_dim_is_dynamic,xla::Shape * tensor_list_shape)225 Status GetTensorListShapeFromElementTensorListShape(
226 const xla::Shape& element_tensor_list_shape, int64 leading_dim,
227 bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) {
228 std::vector<xla::Shape> shapes;
229 int tuple_size = xla::ShapeUtil::TupleElementCount(element_tensor_list_shape);
230 for (int i = 0; i < tuple_size; i++) {
231 const xla::Shape& shape =
232 xla::ShapeUtil::GetTupleElementShape(element_tensor_list_shape, i);
233 std::vector<int64> dimensions = xla::SpanToVector(shape.dimensions());
234 dimensions.insert(dimensions.begin(), leading_dim);
235 shapes.push_back(
236 xla::ShapeUtil::MakeShape(shape.element_type(), dimensions));
237 if (leading_dim_is_dynamic) {
238 shapes.back().set_dynamic_dimension(0, true);
239 }
240 }
241 shapes.push_back(
242 xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
243 *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
244 return Status::OK();
245 }
246
GetTensorListShapeFromElementShape(const xla::Shape & element_shape,int64 leading_dim,bool leading_dim_is_dynamic,xla::Shape * tensor_list_shape)247 Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape,
248 int64 leading_dim,
249 bool leading_dim_is_dynamic,
250 xla::Shape* tensor_list_shape) {
251 if (!element_shape.IsArray()) {
252 return errors::InvalidArgument(
253 "GetTensorListShapeFromElementShape() only supports normal tensor "
254 "shape. But element shape is ",
255 element_shape.DebugString());
256 }
257 std::vector<xla::Shape> shapes;
258 std::vector<int64> dimensions = xla::SpanToVector(element_shape.dimensions());
259 dimensions.insert(dimensions.begin(), leading_dim);
260 shapes.push_back(
261 xla::ShapeUtil::MakeShape(element_shape.element_type(), dimensions));
262 shapes.back().set_dynamic_dimension(0, leading_dim_is_dynamic);
263 shapes.push_back(
264 xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, std::vector<int64>{}));
265 *tensor_list_shape = xla::ShapeUtil::MakeTupleShape(shapes);
266 return Status::OK();
267 }
268
CreateZerosTensorListWithShape(xla::XlaBuilder * b,const xla::Shape & list_shape,const std::vector<std::vector<xla::XlaOp>> & dynamic_dims,xla::XlaOp * list)269 Status CreateZerosTensorListWithShape(
270 xla::XlaBuilder* b, const xla::Shape& list_shape,
271 const std::vector<std::vector<xla::XlaOp>>& dynamic_dims,
272 xla::XlaOp* list) {
273 int tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
274 std::vector<xla::XlaOp> elements;
275 TF_RET_CHECK(dynamic_dims.size() == tuple_size - 1);
276 for (int i = 0; i < tuple_size - 1; i++) {
277 const xla::Shape& shape =
278 xla::ShapeUtil::GetTupleElementShape(list_shape, i);
279 xla::XlaOp zero =
280 xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type()));
281 xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions());
282 TF_RET_CHECK(dynamic_dims[i].size() == shape.dimensions_size());
283 for (int64 dim = 0; dim < shape.dimensions_size(); ++dim) {
284 zeros = xla::SetDimensionSize(zeros, dynamic_dims[i][dim], dim);
285 }
286 elements.push_back(zeros);
287 }
288 // List size (last item) has to be S32.
289 TF_RET_CHECK(xla::ShapeUtil::GetTupleElementShape(list_shape, tuple_size - 1)
290 .element_type() == xla::S32);
291 elements.push_back(xla::ConstantLiteral(b, xla::LiteralUtil::Zero(xla::S32)));
292 *list = xla::Tuple(b, elements);
293 return Status::OK();
294 }
295
GetInitializedTensorListForElement(xla::XlaOp list,xla::XlaOp element,bool element_is_tensor_list,xla::XlaOp * initialized_list)296 Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element,
297 bool element_is_tensor_list,
298 xla::XlaOp* initialized_list) {
299 int64 leading_dim;
300 xla::XlaOp leading_dim_dynamic_size;
301 bool leading_dim_is_dynamic;
302 TF_RETURN_IF_ERROR(GetLeadingDimForTensorList(
303 list, &leading_dim, &leading_dim_is_dynamic, &leading_dim_dynamic_size));
304
305 xla::XlaBuilder* b = list.builder();
306 xla::Shape list_shape;
307 TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
308
309 if (element_is_tensor_list) {
310 TF_RETURN_IF_ERROR(GetTensorListShapeFromElementTensorListShape(
311 element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
312 } else {
313 TF_RETURN_IF_ERROR(GetTensorListShapeFromElementShape(
314 element_shape, leading_dim, leading_dim_is_dynamic, &list_shape));
315 }
316 bool is_initialized;
317 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
318 if (is_initialized) {
319 // Check shape of initialized list is correct.
320 TF_ASSIGN_OR_RETURN(xla::Shape original_list_shape, b->GetShape(list));
321 if (!xla::ShapeUtil::Compatible(original_list_shape, list_shape)) {
322 return errors::Internal(
323 "Invalid TensorList shape: ", original_list_shape.DebugString(),
324 ", expected: ", list_shape.DebugString());
325 }
326 *initialized_list = list;
327 return Status::OK();
328 } else {
329 // Prepare dynamic dimension dimensions for zero tensor list. The dynamic
330 // sizes are created by reading the dynamic dimension size of sub-elements.
331 std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
332 for (int64 i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
333 std::vector<xla::XlaOp> dynamic_dims;
334 const xla::Shape& shape = list_shape.tuple_shapes(i);
335 dynamic_dims.push_back(leading_dim_dynamic_size);
336 xla::XlaOp sub_element;
337 if (element_is_tensor_list) {
338 sub_element = xla::GetTupleElement(element, i);
339 } else {
340 sub_element = element;
341 }
342 for (int64 dim = 0; dim < shape.dimensions_size() - 1; ++dim) {
343 dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
344 }
345 list_dynamic_dims.push_back(dynamic_dims);
346 }
347 return CreateZerosTensorListWithShape(b, list_shape, list_dynamic_dims,
348 initialized_list);
349 }
350 }
351
ExecuteTensorListPushBack(xla::XlaOp list,xla::XlaOp element,bool element_is_tensor_list,xla::XlaOp * result)352 Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element,
353 bool element_is_tensor_list,
354 xla::XlaOp* result) {
355 bool is_initialized;
356 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
357 if (!is_initialized) {
358 return errors::InvalidArgument("TensorList is not initialized");
359 }
360
361 xla::XlaBuilder* b = list.builder();
362 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
363 int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
364 xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1);
365
366 std::vector<xla::XlaOp> result_parts;
367
368 if (element_is_tensor_list) {
369 TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
370 int element_tuple_size = xla::ShapeUtil::TupleElementCount(element_shape);
371 for (int i = 0; i < element_tuple_size; i++) {
372 const xla::Shape& element_part_shape =
373 xla::ShapeUtil::GetTupleElementShape(element_shape, i);
374 xla::XlaOp element_part = xla::GetTupleElement(element, i);
375 std::vector<int64> element_part_dims =
376 xla::SpanToVector(element_part_shape.dimensions());
377 element_part_dims.insert(element_part_dims.begin(), 1);
378 element_part = xla::Reshape(element_part, element_part_dims);
379
380 std::vector<xla::XlaOp> start_indices(
381 element_part_shape.dimensions_size() + 1,
382 xla::ConstantR0<int32>(b, 0));
383 start_indices[0] = push_index;
384
385 xla::XlaOp list_part = xla::GetTupleElement(list, i);
386 xla::XlaOp updated_list_part =
387 xla::DynamicUpdateSlice(list_part, element_part, start_indices);
388 result_parts.push_back(updated_list_part);
389 }
390 } else {
391 TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
392 std::vector<int64> element_dims =
393 xla::SpanToVector(element_shape.dimensions());
394 element_dims.insert(element_dims.begin(), 1);
395 xla::XlaOp update = xla::Reshape(element, element_dims);
396
397 std::vector<xla::XlaOp> start_indices(element_shape.dimensions_size() + 1,
398 xla::ConstantR0<int32>(b, 0));
399 start_indices[0] = push_index;
400
401 xla::XlaOp list_part = xla::GetTupleElement(list, 0);
402 xla::XlaOp updated_list_part =
403 xla::DynamicUpdateSlice(list_part, update, start_indices);
404 result_parts.push_back(updated_list_part);
405 }
406
407 xla::XlaOp updated_push_index = push_index + xla::ConstantR0<int32>(b, 1);
408 result_parts.push_back(updated_push_index);
409
410 *result = xla::Tuple(b, result_parts);
411 return Status::OK();
412 }
413
ExecuteTensorListPopBack(xla::XlaOp list,xla::XlaOp * list_result,xla::XlaOp * element_result,bool * element_is_tensor_list)414 Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result,
415 xla::XlaOp* element_result,
416 bool* element_is_tensor_list) {
417 bool is_initialized;
418 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
419 if (!is_initialized) {
420 return errors::InvalidArgument("TensorList is not initialized");
421 }
422
423 // If the TensorList is a nested TensorList, element will be TensorList.
424 TF_RETURN_IF_ERROR(IsNestedTensorList(list, element_is_tensor_list));
425
426 xla::XlaBuilder* b = list.builder();
427 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
428 int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape);
429 xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1);
430 push_index = push_index - xla::ConstantR0<int32>(b, 1);
431
432 std::vector<xla::XlaOp> list_result_parts, element_result_parts;
433 for (int i = 0; i < list_tuple_size - 1; i++) {
434 const xla::Shape& list_part_shape =
435 xla::ShapeUtil::GetTupleElementShape(list_shape, i);
436 std::vector<xla::XlaOp> start_indices(list_part_shape.dimensions_size(),
437 xla::ConstantR0<int32>(b, 0));
438 start_indices[0] = push_index;
439
440 std::vector<int64> slice_shape =
441 xla::SpanToVector(list_part_shape.dimensions());
442 slice_shape[0] = 1LL;
443
444 xla::XlaOp list_part = xla::GetTupleElement(list, i);
445 xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);
446
447 slice_shape.erase(slice_shape.begin());
448 element_result_parts.push_back(xla::Reshape(read, slice_shape));
449 list_result_parts.push_back(list_part);
450 }
451 list_result_parts.push_back(push_index);
452
453 *list_result = xla::Tuple(b, list_result_parts);
454 if (*element_is_tensor_list) {
455 *element_result = xla::Tuple(b, element_result_parts);
456 } else {
457 *element_result = element_result_parts[0];
458 }
459
460 return Status::OK();
461 }
462
ExecuteTensorListSetItem(xla::XlaOp list,xla::XlaOp index,xla::XlaOp element,xla::XlaOp * result)463 Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index,
464 xla::XlaOp element, xla::XlaOp* result) {
465 bool is_initialized;
466 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
467 if (!is_initialized) {
468 return errors::InvalidArgument("TensorList is not initialized");
469 }
470 bool is_nested;
471 TF_RETURN_IF_ERROR(IsNestedTensorList(list, &is_nested));
472 if (is_nested) {
473 return errors::Unimplemented(
474 "ExecuteTensorListSetItem() only supports non-nested TensorList");
475 }
476
477 xla::XlaBuilder* b = list.builder();
478 TF_ASSIGN_OR_RETURN(xla::Shape element_shape, b->GetShape(element));
479 std::vector<int64> element_dims =
480 xla::SpanToVector(element_shape.dimensions());
481 element_dims.insert(element_dims.begin(), 1);
482 xla::XlaOp update = xla::Reshape(element, element_dims);
483
484 std::vector<xla::XlaOp> start_indices(element_shape.dimensions_size() + 1,
485 xla::ConstantR0<int32>(b, 0));
486 start_indices[0] = index;
487
488 xla::XlaOp list_part = xla::GetTupleElement(list, 0);
489 xla::XlaOp updated_list_part =
490 xla::DynamicUpdateSlice(list_part, update, start_indices);
491
492 std::vector<xla::XlaOp> result_parts;
493 result_parts.push_back(updated_list_part);
494 result_parts.push_back(xla::GetTupleElement(list, 1));
495 *result = xla::Tuple(b, result_parts);
496 return Status::OK();
497 }
498
ExecuteTensorListGetItem(xla::XlaOp list,xla::XlaOp index,xla::XlaOp * result)499 Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index,
500 xla::XlaOp* result) {
501 bool is_initialized;
502 TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized));
503 if (!is_initialized) {
504 return errors::InvalidArgument("TensorList is not initialized");
505 }
506 bool is_nested;
507 TF_RETURN_IF_ERROR(IsNestedTensorList(list, &is_nested));
508 if (is_nested) {
509 return errors::Unimplemented(
510 "ExecuteTensorListGetItem() only supports non-nested TensorList");
511 }
512
513 xla::XlaBuilder* b = list.builder();
514 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list));
515 const xla::Shape& buffer_shape =
516 xla::ShapeUtil::GetTupleElementShape(list_shape, 0);
517 std::vector<xla::XlaOp> start_indices(buffer_shape.dimensions_size(),
518 xla::ConstantR0<int32>(b, 0));
519 start_indices[0] = index;
520
521 std::vector<int64> slice_shape = xla::SpanToVector(buffer_shape.dimensions());
522 slice_shape[0] = 1LL;
523
524 xla::XlaOp list_part = xla::GetTupleElement(list, 0);
525 xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape);
526 // Propagate dynamic dimensions from buffer to the sliced buffer, except for
527 // leading dimension (which is always static 1).
528 for (int64 i = 1; i < buffer_shape.dimensions_size(); ++i) {
529 if (buffer_shape.is_dynamic_dimension(i)) {
530 auto buffer = xla::GetTupleElement(list, 0);
531 auto gds = xla::GetDimensionSize(buffer, i);
532 read = xla::SetDimensionSize(read, gds, i);
533 }
534 }
535 slice_shape.erase(slice_shape.begin());
536 *result = xla::Reshape(read, slice_shape);
537 return Status::OK();
538 }
539
ExecuteTensorListFromTensor(int push_index,xla::XlaOp tensor,xla::XlaOp * result)540 Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor,
541 xla::XlaOp* result) {
542 xla::XlaBuilder* b = tensor.builder();
543 TF_ASSIGN_OR_RETURN(xla::Shape shape, b->GetShape(tensor));
544 if (!shape.IsArray()) {
545 return errors::InvalidArgument(
546 "ExecuteTensorListFromTensor() only supports normal tensor. But input "
547 "shape is ",
548 shape.DebugString());
549 }
550
551 std::vector<xla::XlaOp> result_parts{tensor,
552 xla::ConstantR0<int32>(b, push_index)};
553 *result = xla::Tuple(b, result_parts);
554 return Status::OK();
555 }
556
557 } // namespace tensorflow
558