/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
                 const xla::XlaOp& indices, const TensorShape& indices_shape,
                 int64 axis, bool indices_are_nd, DataType dtype,
                 DataType index_type, xla::XlaBuilder* builder,
                 xla::XlaOp* gather_output) {
  // There is no deep reason why we need this precondition, but this is the only
  // combination that is used and tested today.
  CHECK(!indices_are_nd || axis == 0);
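  // (In this file, only GatherNdOp below sets indices_are_nd, and it always
  // passes axis = 0.)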

  // num_index_dims is the number of components in each index in the indices
  // tensor.
  //
  // num_indices is the total number of (n dimensional or scalar) indices in the
  // indices tensor.
  //
  // If the indices are N-dimensional, then the minor dimension of indices
  // should be of size N and correspond to the N indices.
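  //
  // For example (illustrative shapes, not from any particular caller): with
  // indices_are_nd = true and indices_shape = [5, 7, 2], each index is a
  // 2-component vector, so num_index_dims = 2 and num_indices = 5 * 7 = 35;
  // with indices_are_nd = false and the same indices_shape, every element is
  // its own scalar index, so num_index_dims = 1 and num_indices = 70.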
  int64 num_index_dims;
  int64 num_indices = 1;
  if (indices_are_nd) {
    CHECK_GE(indices_shape.dims(), 1);
    num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
    for (int64 i = 0, e = indices_shape.dims() - 1; i < e; i++) {
      num_indices *= indices_shape.dim_size(i);
    }
  } else {
    num_index_dims = 1;
    for (int64 i = 0, e = indices_shape.dims(); i < e; i++) {
      num_indices *= indices_shape.dim_size(i);
    }
  }

  // Degenerate case: empty indices.
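  //
  // The result is a zero-filled tensor whose shape splices the (empty) index
  // batch dimensions in place of the gathered input dimensions. E.g., an N-D
  // gather of indices_shape = [0, 2] into input_shape = [3, 3, 2] with axis = 0
  // yields an all-zeros output of shape [0, 2].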
  if (num_indices == 0) {
    TensorShape input_shape_pre_axis{input_shape};
    input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
    TensorShape input_shape_post_axis{input_shape};
    input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);

    TensorShape indices_shape_no_index_vectors{indices_shape};
    if (indices_are_nd) {
      indices_shape_no_index_vectors.RemoveLastDims(1);
    }

    TensorShape out_shape;
    out_shape.AppendShape(input_shape_pre_axis);
    out_shape.AppendShape(indices_shape_no_index_vectors);
    out_shape.AppendShape(input_shape_post_axis);

    *gather_output =
        xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes());
    return Status::OK();
  }

  for (int64 i = 0; i < num_index_dims; ++i) {
    if (input_shape.dim_size(axis + i) == 0) {
      return errors::InvalidArgument("Gather dimension ", axis + i,
                                     " is of size zero in tensor with shape ",
                                     input_shape.DebugString());
    }
  }

  // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a
  // tensor of shape [3,3].
  //
  //  operand = s32[3,3] parameter(0)
  //  indices = s32[2] parameter(1)
  //  gather = s32[3,2] gather(operand, indices),
  //       offset_dims={0},
  //       collapsed_slice_dims={1},
  //       start_index_map={1},
  //       index_vector_dim=1,
  //       slice_sizes={3, 1}
  //
  //
  // Example of an N-D gather pulling out slices of shape [1,1,2] out of a
  // tensor of shape [3,3,2].
  //
  //  operand = s32[3,3,2] parameter(0)
  //  indices = s32[2,2] parameter(1)
  //  gather = s32[2,2] gather(operand, indices),
  //       offset_dims={1},
  //       collapsed_slice_dims={0,1},
  //       start_index_map={0,1},
  //       index_vector_dim=1,
  //       slice_sizes={1,1,2}

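  // The loop below reconstructs exactly these dimension numbers: input
  // dimensions in [axis, axis + num_index_dims) are collapsed with slice
  // size 1, and every other input dimension is kept whole as an offset
  // dimension of the output.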
  xla::GatherDimensionNumbers dim_numbers;
  std::vector<int64> slice_sizes;
  slice_sizes.reserve(input_shape.dims());
  for (int64 i = 0; i < input_shape.dims(); i++) {
    int64 window_bound;
    if (axis <= i && i < (axis + num_index_dims)) {
      dim_numbers.add_collapsed_slice_dims(i);
      window_bound = 1;
    } else {
      window_bound = input_shape.dim_size(i);
    }

    slice_sizes.push_back(window_bound);

    if (i < axis) {
      dim_numbers.add_offset_dims(i);
    } else if (i >= (axis + num_index_dims)) {
      int64 indices_rank =
          indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims();
      dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
    }
  }

  dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
                                                  : indices_shape.dims());
  for (int64 i = axis; i < axis + num_index_dims; i++) {
    dim_numbers.add_start_index_map(i);
  }

  *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
  return Status::OK();
}

class GatherOp : public XlaOpKernel {
 public:
  explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    xla::XlaBuilder* builder = context->builder();
    auto input = context->Input(0);
    auto input_shape = context->InputShape(0);
    auto indices = context->Input(1);
    auto indices_shape = context->InputShape(1);
    int64 axis = 0;
    if (context->num_inputs() == 3) {
      const TensorShape axis_shape = context->InputShape(2);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
                  errors::InvalidArgument("axis must be scalar"));
      DataType axis_type = input_type(2);
      OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
                  errors::InvalidArgument("axis must be int32 or int64"));

      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis));
      const auto params_dims = input_shape.dims();
      OP_REQUIRES(
          context, -params_dims <= axis && axis < params_dims,
          errors::InvalidArgument("Expected axis in the range [", -params_dims,
                                  ", ", params_dims, "), but got ", axis));
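      // Canonicalize a negative axis to its positive equivalent, e.g.
      // axis = -1 on rank-3 params becomes axis = 2.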
      if (axis < 0) {
        axis += params_dims;
      }
    }

    DataType index_type = input_type(1);
    OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
                errors::InvalidArgument("indices must be int32 or int64"));

    xla::XlaOp gather;
    OP_REQUIRES_OK(
        context, XlaGather(input, input_shape, indices, indices_shape, axis,
                           /*indices_are_nd=*/false, input_type(0), index_type,
                           builder, &gather));
    context->SetOutput(0, gather);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
};

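// The original Gather op always gathers along axis 0. GatherV2 carries the
// axis as a third input, which must be a compile-time constant because
// Compile() reads its value with ConstantInputAsIntScalar to build static
// gather dimension numbers.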
REGISTER_XLA_OP(Name("Gather"), GatherOp);
REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstantInput("axis"), GatherOp);

class GatherNdOp : public XlaOpKernel {
 public:
  explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    DataType params_type = context->input_type(0);
    DataType indices_type = context->input_type(1);

    TensorShape params_shape = context->InputShape(0);
    TensorShape indices_shape = context->InputShape(1);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape),
                errors::InvalidArgument("params must be at least a vector"));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape),
                errors::InvalidArgument("indices must be at least a vector"));
    const int64 num_index_dims =
        indices_shape.dim_size(indices_shape.dims() - 1);
    OP_REQUIRES(
        context, num_index_dims <= params_shape.dims(),
        errors::InvalidArgument(
            "index innermost dimension length must be <= params rank; saw: ",
            indices_shape.dim_size(indices_shape.dims() - 1), " vs. ",
            params_shape.dims()));

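    // E.g., indices of shape [4, 2] into params of shape [3, 3, 2] treats each
    // row as a 2-D coordinate into the first two params dimensions, gathering
    // 4 slices of shape [2] for an output of shape [4, 2].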
    xla::XlaBuilder* builder = context->builder();
    auto params = context->Input(0);
    auto indices = context->Input(1);
    xla::XlaOp gather;
    OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices,
                                      indices_shape, /*axis=*/0,
                                      /*indices_are_nd=*/true, params_type,
                                      indices_type, builder, &gather));
    context->SetOutput(0, gather);
  }
};

REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp);

}  // namespace tensorflow