1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
17 #include "tensorflow/compiler/tf2xla/shape_util.h"
18 #include "tensorflow/compiler/xla/literal_util.h"
19 #include "tensorflow/compiler/xla/shape.h"
20 #include "tensorflow/compiler/xla/status_macros.h"
21 #include "tensorflow/compiler/xla/statusor.h"
22 #include "tensorflow/core/framework/tensor_shape.h"
23 #include "tensorflow/core/lib/core/errors.h"
24
25 namespace tensorflow {
26
IsTensorListInput(XlaOpKernelContext * ctx,int index)27 bool IsTensorListInput(XlaOpKernelContext* ctx, int index) {
28 return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList;
29 }
30
BuildTensorList(const xla::XlaOp & buffer,const xla::XlaOp & push_index,xla::XlaOp * output_list)31 Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
32 xla::XlaOp* output_list) {
33 TF_RET_CHECK(buffer.builder());
34 *output_list = xla::Tuple(buffer.builder(), {buffer, push_index});
35 return Status::OK();
36 }
37
GetTensorListBuffer(const xla::XlaOp & op,xla::XlaOp * buffer)38 Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) {
39 TF_RET_CHECK(op.builder());
40 *buffer = xla::GetTupleElement(op, 0);
41 return Status::OK();
42 }
43
GetTensorListPushIndex(const xla::XlaOp & op,xla::XlaOp * push_index)44 Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index) {
45 TF_RET_CHECK(op.builder());
46 *push_index = xla::GetTupleElement(op, 1);
47 return Status::OK();
48 }
49
GetTensorListBufferShape(const xla::XlaOp & op,TensorShape * buffer_shape)50 Status GetTensorListBufferShape(const xla::XlaOp& op,
51 TensorShape* buffer_shape) {
52 TF_RET_CHECK(op.builder());
53 TensorShape shape;
54 TF_ASSIGN_OR_RETURN(const xla::Shape& list_tuple_shape,
55 op.builder()->GetShape(op));
56 return GetTensorListBufferShape(list_tuple_shape, buffer_shape);
57 }
58
GetTensorListBufferShape(const xla::Shape & list_shape,TensorShape * buffer_shape)59 Status GetTensorListBufferShape(const xla::Shape& list_shape,
60 TensorShape* buffer_shape) {
61 TF_RET_CHECK(list_shape.IsTuple());
62 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
63 xla::ShapeUtil::GetTupleElementShape(list_shape, 0), buffer_shape));
64 return Status::OK();
65 }
66
IsTensorListInitialized(const xla::XlaOp & op,bool * is_initialized)67 Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized) {
68 TensorShape list_shape;
69 TF_RETURN_IF_ERROR(GetTensorListBufferShape(op, &list_shape));
70 *is_initialized = !(list_shape.dims() == 2 && list_shape.dim_size(1) == 0);
71 return Status::OK();
72 }
73
InitializeTensorList(const xla::XlaOp & uninitialized_list,const TensorShape & buffer_shape,xla::XlaOp * output_list)74 Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
75 const TensorShape& buffer_shape,
76 xla::XlaOp* output_list) {
77 TensorShape input_buffer_shape;
78 TF_RETURN_IF_ERROR(
79 GetTensorListBufferShape(uninitialized_list, &input_buffer_shape));
80 if (input_buffer_shape.dim_size(0) != buffer_shape.dim_size(0)) {
81 return errors::InvalidArgument(
82 "Number of elements in input list does not match buffer size. ",
83 "input list size: ", input_buffer_shape.dim_size(0),
84 "buffer size: ", buffer_shape.dim_size(0));
85 }
86 xla::XlaBuilder* builder = uninitialized_list.builder();
87 xla::XlaOp input_buffer;
88 TF_RETURN_IF_ERROR(GetTensorListBuffer(uninitialized_list, &input_buffer));
89 TF_ASSIGN_OR_RETURN(const xla::Shape& input_buffer_xla_shape,
90 builder->GetShape(input_buffer));
91 auto new_buffer = xla::Broadcast(
92 xla::ConstantLiteral(builder, xla::LiteralUtil::Zero(
93 input_buffer_xla_shape.element_type())),
94 buffer_shape.dim_sizes());
95 xla::XlaOp push_index;
96 TF_RETURN_IF_ERROR(GetTensorListPushIndex(uninitialized_list, &push_index));
97 return BuildTensorList(new_buffer, push_index, output_list);
98 }
99
100 } // namespace tensorflow
101