/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <optional>
#include <vector>

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

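// Emits an XLA Gather implementing the semantics of tf.gather (when
// indices_are_nd is false) and tf.gather_nd (when indices_are_nd is true and
// axis is 0), writing the resulting op to *gather_output.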
Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
                 const xla::XlaOp& indices, const TensorShape& indices_shape,
                 int64_t axis, bool indices_are_nd, DataType dtype,
                 DataType index_type, xla::XlaBuilder* builder,
                 xla::XlaOp* gather_output) {
  // There is no deep reason why we need this precondition, but this is the
  // only combination that is used and tested today.
  CHECK(!indices_are_nd || axis == 0);

  // num_index_dims is the number of components in each index in the indices
  // tensor.
  //
  // num_indices is the total number of (n-dimensional or scalar) indices in
  // the indices tensor.
  //
  // If the indices are N-dimensional, then the minor dimension of indices
  // should be of size N and correspond to the N indices.
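  //
  // For example, with indices_are_nd=true and indices of shape [2, 2],
  // num_index_dims is 2 and num_indices is 2; with indices_are_nd=false and
  // the same shape, num_index_dims is 1 and num_indices is 4.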
  int64_t num_index_dims;
  int64_t num_indices = 1;
  if (indices_are_nd) {
    CHECK_GE(indices_shape.dims(), 1);
    num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
    for (int64_t i = 0, e = indices_shape.dims() - 1; i < e; i++) {
      num_indices *= indices_shape.dim_size(i);
    }
  } else {
    num_index_dims = 1;
    for (int64_t i = 0, e = indices_shape.dims(); i < e; i++) {
      num_indices *= indices_shape.dim_size(i);
    }
  }

  // Degenerate case: empty indices.
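  // The gather contributes no data, so emit a zero-filled tensor of the
  // correct output shape directly: e.g. input_shape=[4,5,6], axis=1, and
  // indices of shape [0] yield an output of shape [4,0,6].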
  if (num_indices == 0) {
    TensorShape input_shape_pre_axis{input_shape};
    input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
    TensorShape input_shape_post_axis{input_shape};
    input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);

    TensorShape indices_shape_no_index_vectors{indices_shape};
    if (indices_are_nd) {
      indices_shape_no_index_vectors.RemoveLastDims(1);
    }

    TensorShape out_shape;
    out_shape.AppendShape(input_shape_pre_axis);
    out_shape.AppendShape(indices_shape_no_index_vectors);
    out_shape.AppendShape(input_shape_post_axis);

    *gather_output = xla::Broadcast(XlaHelpers::Zero(builder, dtype),
                                    out_shape.dim_sizes());
    return OkStatus();
  }

  for (int64_t i = 0; i < num_index_dims; ++i) {
    if (input_shape.dim_size(axis + i) == 0) {
      return errors::InvalidArgument("Gather dimension ", axis + i,
                                     " is of size zero in tensor with shape ",
                                     input_shape.DebugString());
    }
  }

  // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a
  // tensor of shape [3,3].
  //
  //   operand = s32[3,3] parameter(0)
  //   indices = s32[2] parameter(1)
  //   gather = s32[3,2] gather(operand, indices),
  //       offset_dims={0},
  //       collapsed_slice_dims={1},
  //       start_index_map={1},
  //       index_vector_dim=1,
  //       slice_sizes={3, 1}
  //
  //
  // Example of an N-D gather pulling out slices of shape [1,1,2] out of a
  // tensor of shape [3,3,2].
  //
  //   operand = s32[3,3,2] parameter(0)
  //   indices = s32[2,2] parameter(1)
  //   gather = s32[2,2] gather(operand, indices),
  //       offset_dims={1},
  //       collapsed_slice_dims={0,1},
  //       start_index_map={0,1},
  //       index_vector_dim=0,
  //       slice_sizes={1,1,2}

  xla::GatherDimensionNumbers dim_numbers;
  std::vector<int64_t> slice_sizes;
  slice_sizes.reserve(input_shape.dims());
  for (int64_t i = 0; i < input_shape.dims(); i++) {
    int64_t window_bound;
    if (axis <= i && i < (axis + num_index_dims)) {
      dim_numbers.add_collapsed_slice_dims(i);
      window_bound = 1;
    } else {
      window_bound = input_shape.dim_size(i);
    }

    slice_sizes.push_back(window_bound);

    if (i < axis) {
      dim_numbers.add_offset_dims(i);
    } else if (i >= (axis + num_index_dims)) {
      int64_t indices_rank =
          indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
      dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
    }
  }

  dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
                                                  : indices_shape.dims());
  for (int64_t i = axis; i < axis + num_index_dims; i++) {
    dim_numbers.add_start_index_map(i);
  }

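  // For the 1-D example above (input [3,3], axis=1, num_index_dims=1,
  // indices of rank 1), the loop produces offset_dims={0},
  // collapsed_slice_dims={1}, and slice_sizes={3,1}; together with
  // start_index_map={1} and index_vector_dim=1 this matches the HLO shown
  // in the comment above.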
  *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
  return OkStatus();
}

Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
                                    const xla::XlaOp input,
                                    const TensorShape& input_shape,
                                    int batch_dims,
                                    xla::XlaOp* gather_output) {
  auto indices = context->Input(1);
  auto indices_shape = context->InputShape(1);

  std::optional<int64_t> axis;
  if (context->num_inputs() == 3) {
    const TensorShape axis_shape = context->InputShape(2);
    if (!TensorShapeUtils::IsScalar(axis_shape)) {
      return errors::InvalidArgument("axis must be scalar");
    }
    DataType axis_type = context->input_type(2);
    if (axis_type != DT_INT32 && axis_type != DT_INT64) {
      return errors::InvalidArgument("axis must be int32 or int64");
    }

    int64_t axis_input;
    TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input));

    const auto params_dims = input_shape.dims();
    if (-params_dims > axis_input || axis_input >= params_dims) {
      // Check that params has rank of at least axis + 1.
      const auto min_params_rank =
          axis_input < 0 ? -axis_input : axis_input + 1;
      return errors::InvalidArgument("Shape must be at least rank ",
                                     min_params_rank, " but is rank ",
                                     params_dims);
    }
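    // A negative axis counts from the end: e.g. axis=-1 on rank-3 params
    // normalizes to axis=2.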
    if (axis_input < 0) {
      axis_input += params_dims;
    }
    axis = axis_input;
  }

  if (batch_dims != 0) {
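    // Likewise, a negative batch_dims counts from the end of indices'
    // dimensions: e.g. batch_dims=-1 with rank-2 indices normalizes to 1.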
    if (batch_dims < 0) {
      batch_dims = indices_shape.dims() + batch_dims;
    }

    axis = axis.value_or(batch_dims);

    if (batch_dims < -indices_shape.dims() ||
        batch_dims > indices_shape.dims()) {
      return errors::InvalidArgument(
          "Expected batch_dims in the range [", -indices_shape.dims(), ", ",
          indices_shape.dims(), "], but got ", batch_dims);
    }

    if (batch_dims >= input_shape.dims()) {
      return errors::InvalidArgument("batch_dims (", batch_dims,
                                     ") must be less than rank(input) (",
                                     input_shape.dims(), ").");
    }

    if (*axis < batch_dims) {
      return errors::InvalidArgument("batch_dims (", batch_dims,
                                     ") must be less than or equal to ",
                                     "axis (", *axis, ").");
    }
  }

  axis = axis.value_or(0);
  DataType index_type = context->input_type(1);
  if (index_type != DT_INT32 && index_type != DT_INT64) {
    return errors::InvalidArgument("indices must be int32 or int64");
  }

  if (batch_dims > 0) {
    // Lower to TorchIndexSelect, which treats the leading batch_dims
    // dimensions of `input` and `indices` as aligned batch dimensions and
    // gathers along `axis` within each batch.
    *gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims);
  } else {
    // XlaGather() handles degenerate cases such as empty indices, which are
    // error conditions caught above when batch_dims is not 0.
    TF_RETURN_IF_ERROR(
        XlaGather(input, input_shape, indices, indices_shape, *axis,
                  /*indices_are_nd=*/false, context->expected_output_dtype(0),
                  index_type, context->builder(), gather_output));
  }
  return OkStatus();
}

class GatherOp : public XlaOpKernel {
 public:
  explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
    // Set batch_dims_ to 0 if the attribute does not exist.
    if (context->HasAttr("batch_dims")) {
      OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_));
    } else {
      batch_dims_ = 0;
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    auto input = context->Input(0);
    auto input_shape = context->InputShape(0);

    xla::XlaOp gather;
    OP_REQUIRES_OK(context,
                   XlaGatherWithBatchDimsOpImpl(context, input, input_shape,
                                                batch_dims_, &gather));
    context->SetOutput(0, gather);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);

  // The number of batch dimensions, as passed in the batch_dims attribute.
  // It must be less than or equal to rank(indices).
  int32 batch_dims_ = 0;
};

REGISTER_XLA_OP(Name("Gather"), MlirXlaOpKernel);
REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstantInput("axis"), GatherOp);

class GatherNdOp : public XlaOpKernel {
 public:
  explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    DataType params_type = context->input_type(0);
    DataType indices_type = context->input_type(1);

    TensorShape params_shape = context->InputShape(0);
    TensorShape indices_shape = context->InputShape(1);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape),
                errors::InvalidArgument("params must be at least a vector"));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape),
                errors::InvalidArgument("indices must be at least a vector"));
    const int64_t num_index_dims =
        indices_shape.dim_size(indices_shape.dims() - 1);
    OP_REQUIRES(
        context, num_index_dims <= params_shape.dims(),
        errors::InvalidArgument(
            "index innermost dimension length must be <= params rank; saw: ",
            num_index_dims, " vs. ", params_shape.dims()));
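
    // For example, params of shape [3,3,2] with indices of shape [2,2]
    // (num_index_dims=2) gathers two [2]-element slices and yields an output
    // of shape [2,2]; see the N-D gather example in XlaGather() above.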
    xla::XlaBuilder* builder = context->builder();
    auto params = context->Input(0);
    auto indices = context->Input(1);
    xla::XlaOp gather;
    OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices,
                                      indices_shape, /*axis=*/0,
                                      /*indices_are_nd=*/true, params_type,
                                      indices_type, builder, &gather));
    context->SetOutput(0, gather);
  }
};

REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp);

}  // namespace tensorflow