1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
17 #include "tensorflow/compiler/tf2xla/shape_util.h"
18 #include "tensorflow/compiler/tf2xla/type_util.h"
19 #include "tensorflow/compiler/tf2xla/xla_context.h"
20 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/core/framework/kernel_def_builder.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26
27 namespace tensorflow {
28
XlaGather(const xla::XlaOp & input,const TensorShape & input_shape,const xla::XlaOp & indices,const TensorShape & indices_shape,int64 axis,bool indices_are_nd,DataType dtype,DataType index_type,xla::XlaBuilder * builder,xla::XlaOp * gather_output)29 Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
30 const xla::XlaOp& indices, const TensorShape& indices_shape,
31 int64 axis, bool indices_are_nd, DataType dtype,
32 DataType index_type, xla::XlaBuilder* builder,
33 xla::XlaOp* gather_output) {
34 // There is no deep reason why we need this precondition, but this is the only
35 // combination that is used and tested today.
36 CHECK(!indices_are_nd || axis == 0);
37
38 // num_index_dims is the number of components in each index in the indices
39 // tensor.
40 //
41 // num_indices is the total number of (n dimensional or scalar) indices in the
42 // indices tensor.
43 //
44 // If the indices are N-dimensional, then the minor dimension of indices
45 // should be of size N and correspond to the N indices.
46 int64 num_index_dims;
47 int64 num_indices = 1;
48 if (indices_are_nd) {
49 CHECK_GE(indices_shape.dims(), 1);
50 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
51 for (int64 i = 0, e = indices_shape.dims() - 1; i < e; i++) {
52 num_indices *= indices_shape.dim_size(i);
53 }
54 } else {
55 num_index_dims = 1;
56 for (int64 i = 0, e = indices_shape.dims(); i < e; i++) {
57 num_indices *= indices_shape.dim_size(i);
58 }
59 }
60
61 // Degenerate case: empty indices.
62 if (num_indices == 0) {
63 TensorShape input_shape_pre_axis{input_shape};
64 input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
65 TensorShape input_shape_post_axis{input_shape};
66 input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);
67
68 TensorShape indices_shape_no_index_vectors{indices_shape};
69 if (indices_are_nd) {
70 indices_shape_no_index_vectors.RemoveLastDims(1);
71 }
72
73 TensorShape out_shape;
74 out_shape.AppendShape(input_shape_pre_axis);
75 out_shape.AppendShape(indices_shape_no_index_vectors);
76 out_shape.AppendShape(input_shape_post_axis);
77
78 *gather_output =
79 xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes());
80 return Status::OK();
81 }
82
83 for (int64 i = 0; i < num_index_dims; ++i) {
84 if (input_shape.dim_size(axis + i) == 0) {
85 return errors::InvalidArgument("Gather dimension ", axis + i,
86 " is of size zero in tensor with shape ",
87 input_shape.DebugString());
88 }
89 }
90
91 // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a
92 // tensor of shape [3,3].
93 //
94 // operand = s32[3,3] parameter(0)
95 // indices = s32[2] parameter(1)
96 // gather = s32[3,2] gather(operand, indices),
97 // offset_dims={0},
98 // collapsed_slice_dims={1},
99 // start_index_map={1},
100 // index_vector_dim=1,
101 // slice_sizes={3, 1}
102 //
103 //
104 // Example of an N-D gather pulling out slices of shape [1,1,2] out of a
105 // tensor of shape [3,3,2].
106 //
107 // operand = s32[3,3,2] parameter(0)
108 // indices = s32[2,2] parameter(1)
109 // gather = s32[2,2] gather(operand, indices),
110 // offset_dims={1},
111 // collapsed_slice_dims={0,1},
112 // start_index_map={0,1},
113 // index_vector_dim=0,
114 // slice_sizes={1,1,2}
115
116 xla::GatherDimensionNumbers dim_numbers;
117 std::vector<int64> slice_sizes;
118 slice_sizes.reserve(input_shape.dims());
119 for (int64 i = 0; i < input_shape.dims(); i++) {
120 int64 window_bound;
121 if (axis <= i && i < (axis + num_index_dims)) {
122 dim_numbers.add_collapsed_slice_dims(i);
123 window_bound = 1;
124 } else {
125 window_bound = input_shape.dim_size(i);
126 }
127
128 slice_sizes.push_back(window_bound);
129
130 if (i < axis) {
131 dim_numbers.add_offset_dims(i);
132 } else if (i >= (axis + num_index_dims)) {
133 int64 indices_rank =
134 indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
135 dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
136 }
137 }
138
139 dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
140 : indices_shape.dims());
141 for (int64 i = axis; i < axis + num_index_dims; i++) {
142 dim_numbers.add_start_index_map(i);
143 }
144
145 *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
146 return Status::OK();
147 }
148
149 class GatherOp : public XlaOpKernel {
150 public:
GatherOp(OpKernelConstruction * context)151 explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
152
Compile(XlaOpKernelContext * context)153 void Compile(XlaOpKernelContext* context) override {
154 xla::XlaBuilder* builder = context->builder();
155 auto input = context->Input(0);
156 auto input_shape = context->InputShape(0);
157 auto indices = context->Input(1);
158 auto indices_shape = context->InputShape(1);
159 int64 axis = 0;
160 if (context->num_inputs() == 3) {
161 const TensorShape axis_shape = context->InputShape(2);
162 OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
163 errors::InvalidArgument("axis must be scalar"));
164 DataType axis_type = input_type(2);
165 OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
166 errors::InvalidArgument("axis must be int32 or int64"));
167
168 OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis));
169 const auto params_dims = input_shape.dims();
170 OP_REQUIRES(
171 context, -params_dims <= axis && axis < params_dims,
172 errors::InvalidArgument("Expected axis in the range [", -params_dims,
173 ", ", params_dims, "), but got ", axis));
174 if (axis < 0) {
175 axis += params_dims;
176 }
177 }
178
179 DataType index_type = input_type(1);
180 OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
181 errors::InvalidArgument("indices must be int32 or int64"));
182
183 xla::XlaOp gather;
184 OP_REQUIRES_OK(
185 context, XlaGather(input, input_shape, indices, indices_shape, axis,
186 /*indices_are_nd=*/false, input_type(0), index_type,
187 builder, &gather));
188 context->SetOutput(0, gather);
189 }
190
191 private:
192 TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
193 };
194
195 REGISTER_XLA_OP(Name("Gather"), GatherOp);
196 REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstantInput("axis"), GatherOp);
197
198 class GatherNdOp : public XlaOpKernel {
199 public:
GatherNdOp(OpKernelConstruction * context)200 explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
201
Compile(XlaOpKernelContext * context)202 void Compile(XlaOpKernelContext* context) override {
203 DataType params_type = context->input_type(0);
204 DataType indices_type = context->input_type(1);
205
206 TensorShape params_shape = context->InputShape(0);
207 TensorShape indices_shape = context->InputShape(1);
208 OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape),
209 errors::InvalidArgument("params must be at least a vector"));
210 OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape),
211 errors::InvalidArgument("indices must be at least a vector"));
212 const int64 num_index_dims =
213 indices_shape.dim_size(indices_shape.dims() - 1);
214 OP_REQUIRES(
215 context, num_index_dims <= params_shape.dims(),
216 errors::InvalidArgument(
217 "index innermost dimension length must be <= params rank; saw: ",
218 indices_shape.dim_size(indices_shape.dims() - 1), " vs. ",
219 params_shape.dims()));
220
221 xla::XlaBuilder* builder = context->builder();
222 auto params = context->Input(0);
223 auto indices = context->Input(1);
224 xla::XlaOp gather;
225 OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices,
226 indices_shape, /*axis=*/0,
227 /*indices_are_nd=*/true, params_type,
228 indices_type, builder, &gather));
229 context->SetOutput(0, gather);
230 }
231 };
232
233 REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp);
234
235 } // namespace tensorflow
236