• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 
18 #include "absl/types/optional.h"
19 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
20 #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h"
21 #include "tensorflow/compiler/tf2xla/shape_util.h"
22 #include "tensorflow/compiler/tf2xla/type_util.h"
23 #include "tensorflow/compiler/tf2xla/xla_context.h"
24 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
27 #include "tensorflow/compiler/xla/client/lib/slicing.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/core/framework/kernel_def_builder.h"
31 #include "tensorflow/core/framework/op_kernel.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 
34 namespace tensorflow {
35 
// Emits an xla::Gather into *gather_output implementing tf.gather semantics
// (indices_are_nd == false: each element of `indices` is a scalar index into
// dimension `axis` of `input`) or tf.gather_nd semantics (indices_are_nd ==
// true: the minor-most dimension of `indices` holds N-component index
// vectors).  Returns InvalidArgument if any gathered dimension of `input`
// has size zero while the indices are non-empty.
//
// NOTE(review): `index_type` is never read in this function body —
// presumably retained for signature compatibility with callers; confirm
// before removing.
Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
                 const xla::XlaOp& indices, const TensorShape& indices_shape,
                 int64_t axis, bool indices_are_nd, DataType dtype,
                 DataType index_type, xla::XlaBuilder* builder,
                 xla::XlaOp* gather_output) {
  // There is no deep reason why we need this precondition, but this is the only
  // combination that is used and tested today.
  CHECK(!indices_are_nd || axis == 0);

  // num_index_dims is the number of components in each index in the indices
  // tensor.
  //
  // num_indices is the total number of (n dimensional or scalar) indices in the
  // indices tensor.
  //
  // If the indices are N-dimensional, then the minor dimension of indices
  // should be of size N and correspond to the N indices.
  int64_t num_index_dims;
  int64_t num_indices = 1;
  if (indices_are_nd) {
    CHECK_GE(indices_shape.dims(), 1);
    num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
    for (int64_t i = 0, e = indices_shape.dims() - 1; i < e; i++) {
      num_indices *= indices_shape.dim_size(i);
    }
  } else {
    num_index_dims = 1;
    for (int64_t i = 0, e = indices_shape.dims(); i < e; i++) {
      num_indices *= indices_shape.dim_size(i);
    }
  }

  // Degenerate case: empty indices.  The output shape is still well defined
  // (input dims before `axis`, then the index batch dims, then input dims
  // after the gathered dims), so emit a zero-filled broadcast of that shape
  // rather than an actual gather.
  if (num_indices == 0) {
    TensorShape input_shape_pre_axis{input_shape};
    input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
    TensorShape input_shape_post_axis{input_shape};
    input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);

    // For N-d indices, drop the minor index-vector dimension: it is consumed
    // by the gather and does not appear in the output.
    TensorShape indices_shape_no_index_vectors{indices_shape};
    if (indices_are_nd) {
      indices_shape_no_index_vectors.RemoveLastDims(1);
    }

    TensorShape out_shape;
    out_shape.AppendShape(input_shape_pre_axis);
    out_shape.AppendShape(indices_shape_no_index_vectors);
    out_shape.AppendShape(input_shape_post_axis);

    *gather_output =
        xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes());
    return OkStatus();
  }

  // With non-empty indices, any zero-sized gathered dimension would make
  // every index out of bounds, so reject it up front.
  for (int64_t i = 0; i < num_index_dims; ++i) {
    if (input_shape.dim_size(axis + i) == 0) {
      return errors::InvalidArgument("Gather dimension ", axis + i,
                                     " is of size zero in tensor with shape ",
                                     input_shape.DebugString());
    }
  }

  // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a
  // tensor of shape [3,3].
  //
  //  operand = s32[3,3] parameter(0)
  //  indices = s32[2] parameter(1)
  //  gather = s32[3,2] gather(operand, indices),
  //       offset_dims={0},
  //       collapsed_slice_dims={1},
  //       start_index_map={1},
  //       index_vector_dim=1,
  //       slice_sizes={3, 1}
  //
  //
  // Example of an N-D gather pulling out slices of shape [1,1,2] out of a
  // tensor of shape [3,3,2].
  //
  //  operand = s32[3,3,2] parameter(0)
  //  indices = s32[2,2] parameter(1)
  //  gather = s32[2,2] gather(operand, indices),
  //       offset_dims={1},
  //       collapsed_slice_dims={0,1},
  //       start_index_map={0,1},
  //       index_vector_dim=0,
  //       slice_sizes={1,1,2}

  // Walk the operand dimensions in order; the gathered dims [axis,
  // axis + num_index_dims) are collapsed to size-1 slices, all other dims are
  // taken whole and become offset dims in the output.  Note: repeated proto
  // fields are order-sensitive, so the add_* calls below must stay in
  // ascending-dimension order.
  xla::GatherDimensionNumbers dim_numbers;
  std::vector<int64_t> slice_sizes;
  slice_sizes.reserve(input_shape.dims());
  for (int64_t i = 0; i < input_shape.dims(); i++) {
    int64_t window_bound;
    if (axis <= i && i < (axis + num_index_dims)) {
      dim_numbers.add_collapsed_slice_dims(i);
      window_bound = 1;
    } else {
      window_bound = input_shape.dim_size(i);
    }

    slice_sizes.push_back(window_bound);

    if (i < axis) {
      dim_numbers.add_offset_dims(i);
    } else if (i >= (axis + num_index_dims)) {
      // Dims after the gathered span shift right by the number of index
      // batch dimensions inserted into the output.
      int64_t indices_rank =
          indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
      dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
    }
  }

  // Scalar indices: the index vector dim is one past the last indices dim
  // (i.e. implicit, of size 1).  N-d indices: it is the minor-most dim.
  dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
                                                  : indices_shape.dims());
  for (int64_t i = axis; i < axis + num_index_dims; i++) {
    dim_numbers.add_start_index_map(i);
  }

  *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
  return OkStatus();
}
155 
XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext * context,const xla::XlaOp input,const TensorShape & input_shape,int batch_dims,xla::XlaOp * gather_output)156 Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context,
157                                     const xla::XlaOp input,
158                                     const TensorShape& input_shape,
159                                     int batch_dims, xla::XlaOp* gather_output) {
160   auto indices = context->Input(1);
161   auto indices_shape = context->InputShape(1);
162 
163   std::optional<int64_t> axis;
164   if (context->num_inputs() == 3) {
165     const TensorShape axis_shape = context->InputShape(2);
166     if (!TensorShapeUtils::IsScalar(axis_shape)) {
167       return errors::InvalidArgument("axis must be scalar");
168     }
169     DataType axis_type = context->input_type(2);
170     if (axis_type != DT_INT32 && axis_type != DT_INT64) {
171       return errors::InvalidArgument("axis must be int32 or int64");
172     }
173 
174     int64_t axis_input;
175     TF_RETURN_IF_ERROR(context->ConstantInputAsIntScalar(2, &axis_input));
176 
177     const auto params_dims = input_shape.dims();
178     if (-params_dims > axis_input || axis_input >= params_dims) {
179       // Check that params has rank of at least axis + 1.
180       const auto min_params_rank =
181           axis_input < 0 ? -axis_input : axis_input + 1;
182       return errors::InvalidArgument("Shape must be at least rank ",
183                                      min_params_rank, " but is rank ",
184                                      params_dims);
185     }
186     if (axis_input < 0) {
187       axis_input += params_dims;
188     }
189     axis = axis_input;
190   }
191 
192   if (batch_dims != 0) {
193     if (batch_dims < 0) {
194       batch_dims = indices_shape.dims() + batch_dims;
195     }
196 
197     axis = axis.value_or(batch_dims);
198 
199     if (batch_dims < -indices_shape.dims() ||
200         batch_dims > indices_shape.dims()) {
201       return errors::InvalidArgument(
202           "Expected batch_dims in the range [", -indices_shape.dims(), ", ",
203           indices_shape.dims(), "], but got ", batch_dims);
204     }
205 
206     if (batch_dims >= input_shape.dims()) {
207       return errors::InvalidArgument("batch_dims (", batch_dims,
208                                      ") must be less than rank(input) (",
209                                      input_shape.dims(), ").");
210     }
211 
212     if (*axis < batch_dims) {
213       return errors::InvalidArgument("batch_dims (", batch_dims,
214                                      ") must be less than or equal to ",
215                                      "axis (", *axis, ").");
216     }
217   }
218 
219   axis = axis.value_or(0);
220   DataType index_type = context->input_type(1);
221   if (index_type != DT_INT32 && index_type != DT_INT64) {
222     return errors::InvalidArgument("indices must be int32 or int64");
223   }
224 
225   xla::XlaOp gather;
226   if (batch_dims > 0) {
227     *gather_output = xla::TorchIndexSelect(input, indices, *axis, batch_dims);
228   } else {
229     // XlaGather() manages degenerate cases, like empty-indices, which are
230     // error conditions and caught above if batch_dims is not 0.
231     TF_RETURN_IF_ERROR(
232         XlaGather(input, input_shape, indices, indices_shape, *axis,
233                   /*indices_are_nd=*/false, context->expected_output_dtype(0),
234                   index_type, context->builder(), gather_output));
235   }
236   return OkStatus();
237 }
238 class GatherOp : public XlaOpKernel {
239  public:
GatherOp(OpKernelConstruction * context)240   explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
241     // Set batch_dims_ to 0 if the attribute does not exist.
242     if (context->HasAttr("batch_dims")) {
243       OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_));
244     } else {
245       batch_dims_ = 0;
246     }
247   }
248 
Compile(XlaOpKernelContext * context)249   void Compile(XlaOpKernelContext* context) override {
250     auto input = context->Input(0);
251     auto input_shape = context->InputShape(0);
252 
253     xla::XlaOp gather;
254     OP_REQUIRES_OK(context,
255                    XlaGatherWithBatchDimsOpImpl(context, input, input_shape,
256                                                 batch_dims_, &gather));
257     context->SetOutput(0, gather);
258   }
259 
260  private:
261   TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
262 
263   // The number of batch dimensions, as passed in the batch_dims attribute.
264   // It must be less than or equal to rank(indices).
265   int32 batch_dims_ = 0;
266 };
267 
// "Gather" is lowered via the MLIR-based kernel; "GatherV2" uses the kernel
// above and requires its "axis" input to be a compile-time constant.
REGISTER_XLA_OP(Name("Gather"), MlirXlaOpKernel);
REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstantInput("axis"), GatherOp);
270 
271 class GatherNdOp : public XlaOpKernel {
272  public:
GatherNdOp(OpKernelConstruction * context)273   explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
274 
Compile(XlaOpKernelContext * context)275   void Compile(XlaOpKernelContext* context) override {
276     DataType params_type = context->input_type(0);
277     DataType indices_type = context->input_type(1);
278 
279     TensorShape params_shape = context->InputShape(0);
280     TensorShape indices_shape = context->InputShape(1);
281     OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape),
282                 errors::InvalidArgument("params must be at least a vector"));
283     OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape),
284                 errors::InvalidArgument("indices must be at least a vector"));
285     const int64_t num_index_dims =
286         indices_shape.dim_size(indices_shape.dims() - 1);
287     OP_REQUIRES(
288         context, num_index_dims <= params_shape.dims(),
289         errors::InvalidArgument(
290             "index innermost dimension length must be <= params rank; saw: ",
291             indices_shape.dim_size(indices_shape.dims() - 1), " vs. ",
292             params_shape.dims()));
293 
294     xla::XlaBuilder* builder = context->builder();
295     auto params = context->Input(0);
296     auto indices = context->Input(1);
297     xla::XlaOp gather;
298     OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices,
299                                       indices_shape, /*axis=*/0,
300                                       /*indices_are_nd=*/true, params_type,
301                                       indices_type, builder, &gather));
302     context->SetOutput(0, gather);
303   }
304 };
305 
// GatherNd has no constant inputs; all shapes are taken from compile time.
REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp);
307 
308 }  // namespace tensorflow
309