• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <limits>
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/register_types.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 #include "tensorflow/core/util/util.h"
25 
26 namespace tensorflow {
27 
28 namespace {
29 
30 // For each slice in `(start, limit)` in `value_slices`, append
31 // `params_dense_values_in[start:limit] to `values_out`.  `value_size` indicates
32 // the number of scalars contained in each value params_dense_values_in[i].
33 template <typename VALUE_TYPE, typename SPLITS_TYPE>
WriteValueSlices(const Tensor & params_dense_values_in,const std::vector<std::pair<SPLITS_TYPE,SPLITS_TYPE>> & value_slices,SPLITS_TYPE value_size,Tensor * values_out)34 void WriteValueSlices(
35     const Tensor& params_dense_values_in,
36     const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
37     SPLITS_TYPE value_size, Tensor* values_out) {
38   const auto& params_dense_values =
39       params_dense_values_in.flat_outer_dims<VALUE_TYPE, 2>();
40   auto values = values_out->flat_outer_dims<VALUE_TYPE, 2>();
41   int out_pos = 0;
42   for (const auto& slice : value_slices) {
43     for (int i = slice.first; i < slice.second; ++i) {
44       for (int j = 0; j < value_size; ++j) {
45         values(out_pos, j) = params_dense_values(i, j);
46       }
47       ++out_pos;
48     }
49   }
50 }
51 
52 }  // namespace
53 
// CPU kernel implementing the RaggedGather op: gathers rows of a ragged
// tensor (encoded as the splits tensors `params_nested_splits` plus a dense
// `params_dense_values` tensor) selected by a dense `indices` tensor, and
// emits the gathered result as output nested splits plus a values tensor.
// INDEX_TYPE is the dtype of `indices`; SPLITS_TYPE is the dtype of the
// row-splits tensors (int32 or int64 per the kernel registrations below).
template <typename INDEX_TYPE, typename SPLITS_TYPE>
class RaggedGatherOpBase : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    // Get the input Tensors.

    OpInputList params_nested_splits_in;
    OP_REQUIRES_OK(context, context->input_list("params_nested_splits",
                                                &params_nested_splits_in));
    OP_REQUIRES(
        context, params_nested_splits_in.size() > 0,
        errors::InvalidArgument("params_nested_splits must be non empty"));

    // Op inputs are laid out as: [nested_splits..., dense_values, indices].
    const Tensor& params_dense_values_in =
        context->input(params_nested_splits_in.size());
    const Tensor& indices_in =
        context->input(params_nested_splits_in.size() + 1);

    OP_REQUIRES(context, params_nested_splits_in[0].dims() > 0,
                errors::InvalidArgument("Split tensors must not be scalars"));
    // A row-splits vector of length N+1 describes N rows, so the number of
    // gatherable rows is dim_size(0) - 1.
    SPLITS_TYPE num_params = params_nested_splits_in[0].dim_size(0) - 1;
    OP_REQUIRES_OK(context, ValidateIndices(indices_in, num_params));

    OP_REQUIRES(context, params_dense_values_in.dims() > 0,
                errors::InvalidArgument("params.rank must be nonzero"));
    SPLITS_TYPE num_params_dense_values = params_dense_values_in.dim_size(0);

    // Calculate the `splits`, and store the value slices that we need to
    // copy in `value_slices`.
    std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>> value_slices;
    SPLITS_TYPE num_values = 0;
    std::vector<std::vector<SPLITS_TYPE>> out_splits;
    OP_REQUIRES_OK(context, MakeSplits(indices_in, params_nested_splits_in,
                                       num_params_dense_values, &out_splits,
                                       &value_slices, &num_values));

    // Write the output tensors.  The values output index equals the number
    // of splits outputs (out_splits.size()).
    OP_REQUIRES_OK(context, WriteSplits(out_splits, context));
    OP_REQUIRES_OK(context,
                   WriteValues(params_dense_values_in, value_slices,
                               out_splits.size(), num_values, context));
  }

 private:
  // Eigen flat view over a const SPLITS_TYPE tensor.
  using ConstFlatType = typename TTypes<SPLITS_TYPE>::ConstFlat;

  // Check if any indices are out-of-bounds.
  // Returns InvalidArgument naming the offending element if any index falls
  // outside [0, num_params); OK otherwise.
  ::tensorflow::Status ValidateIndices(const Tensor& indices_in,
                                       SPLITS_TYPE num_params) {
    const auto& indices = indices_in.flat<INDEX_TYPE>();
    for (SPLITS_TYPE i = 0; i < indices.size(); ++i) {
      SPLITS_TYPE index = indices(i);
      if (index < 0 || index >= num_params) {
        return errors::InvalidArgument(
            "indices", SliceDebugString(indices_in.shape(), i), " = ", index,
            " is not in [0, ", num_params, ")");
      }
    }
    return ::tensorflow::Status::OK();
  }

  // Construct the `splits` output tensors, encoded using a nested vector.
  // Also find the slices of values that need to be copied, and store them
  // in `value_slices`.  The total number of values that will be copied (which
  // we need for allocating the output values tensor) is stored in `num_values`.
  ::tensorflow::Status MakeSplits(
      const Tensor& indices_in, const OpInputList& params_nested_splits_in,
      SPLITS_TYPE num_params_dense_values,
      std::vector<std::vector<SPLITS_TYPE>>* out_splits,
      std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>* value_slices,
      SPLITS_TYPE* num_values) {
    *num_values = 0;
    value_slices->clear();

    // One output splits tensor per dense dimension of `indices` (except the
    // last) plus one per input ragged dimension.  Each starts with split 0.
    int num_splits = indices_in.dims() - 1 + params_nested_splits_in.size();
    out_splits->assign(num_splits, {0});

    // Get Eigen tensors.
    const auto& indices = indices_in.flat<INDEX_TYPE>();
    std::vector<ConstFlatType> params_nested_splits;
    params_nested_splits.reserve(params_nested_splits_in.size());
    for (const auto& splits_in : params_nested_splits_in) {
      params_nested_splits.push_back(splits_in.flat<SPLITS_TYPE>());
    }

    TF_RETURN_IF_ERROR(
        ValidateSplits(params_nested_splits, num_params_dense_values));

    // Add `splits` that come from all but the last dimension of the dense
    // Tensor `indices`.  In particular, for each dimension D, we add a
    // splits tensor whose values are:
    //   range(reduce_prod(splits.shape[:D]) + 1) * splits.shape[D+1]
    // E.g., if indices.shape=[2, 3, 4] then we will add splits tensors:
    //   [0, 3, 6]                    # length=2+1, stride=3
    //   [0, 4, 8, 12, 16, 20, 24]    # length=2*3+1, stride=4
    // NOTE(review): `nrows` is an `int`; for very large `indices` shapes the
    // running product could overflow — confirm upstream shape limits.
    int nrows = 1;
    for (int dim = 0; dim < indices_in.dims() - 1; ++dim) {
      nrows *= indices_in.dim_size(dim);
      int row_length = indices_in.dim_size(dim + 1);
      for (int i = 1; i < nrows + 1; ++i) {
        out_splits->at(dim).push_back(i * row_length);
      }
    }

    // Add `splits` that come from `params_nested_splits`.  Starting with the
    // outermost ragged dimension (i.e., the first `splits` tensor), we work
    // our way in, finding the range of values that should be copied.  As we
    // go, we update the output `splits` for each dimension with the appropriate
    // values.  In particular, the *lengths* of the slices from `param_splits`
    // should be copied to generate corresponding slice lengths in the output
    // splits.  E.g., if we are copying a ragged row with length 4, then we
    // should add a new split point to out_splits that is 4 greater than the
    // previous split point in out_splits.
    for (int i = 0; i < indices.size(); ++i) {
      // NOTE(review): `start`/`limit` are `int` but hold INDEX_TYPE /
      // SPLITS_TYPE values, which may be int64 — possible narrowing for very
      // large params; confirm against the int64-splits registrations.
      int start = indices(i);
      int limit = indices(i) + 1;

      // Copy splits.
      for (int dim = 0; dim < params_nested_splits.size(); ++dim) {
        const auto& splits = params_nested_splits[dim];
        int out_dim = dim + indices_in.dims() - 1;
        if (out_dim >= 0) {
          // Shift the copied split points so they continue from the last
          // split already written for this output dimension.
          SPLITS_TYPE delta = out_splits->at(out_dim).back() - splits(start);
          for (int j = start; j < limit; ++j) {
            out_splits->at(out_dim).push_back(splits(j + 1) + delta);
          }
        }
        // Descend one ragged level: map row range -> value range.
        start = splits(start);
        limit = splits(limit);
      }
      if (limit != start) {
        value_slices->emplace_back(start, limit);
        *num_values += limit - start;
      }
    }
    return ::tensorflow::Status::OK();
  }

  // Validate that each splits tensor is a well-formed row-splits vector:
  // non-empty, non-negative, sorted, and not pointing past the next level's
  // size (the innermost level is bounded by num_params_dense_values).
  ::tensorflow::Status ValidateSplits(
      const std::vector<ConstFlatType>& params_nested_splits,
      SPLITS_TYPE num_params_dense_values) {
    // Validate
    for (int dim = 0; dim < params_nested_splits.size(); ++dim) {
      const auto& splits = params_nested_splits[dim];
      SPLITS_TYPE last_split = (dim == params_nested_splits.size() - 1)
                                   ? num_params_dense_values
                                   : params_nested_splits[dim + 1].size();
      if (splits.size() == 0) {
        return errors::InvalidArgument("Ragged splits may not be empty");
      }
      if (splits(0) < 0) {
        return errors::InvalidArgument("Ragged splits must be non-negative");
      }
      if (splits(splits.size() - 1) > last_split) {
        return errors::InvalidArgument(
            "Ragged splits must not point past values");
      }
      for (int i = 1; i < splits.size(); ++i) {
        if (splits(i - 1) > splits(i)) {
          return errors::InvalidArgument("Ragged splits must be sorted");
        }
      }
    }
    return ::tensorflow::Status::OK();
  }

  // Allocate and fill the "output_nested_splits" output list from the
  // nested-vector encoding produced by MakeSplits().
  ::tensorflow::Status WriteSplits(
      const std::vector<std::vector<SPLITS_TYPE>>& out_splits,
      OpKernelContext* context) {
    OpOutputList splits_out;
    TF_RETURN_IF_ERROR(
        context->output_list("output_nested_splits", &splits_out));
    for (int i = 0; i < out_splits.size(); ++i) {
      Tensor* splits;
      SPLITS_TYPE num_splits = out_splits[i].size();
      TF_RETURN_IF_ERROR(
          splits_out.allocate(i, TensorShape({num_splits}), &splits));
      auto splits_flat = splits->flat<SPLITS_TYPE>();
      std::copy_n(out_splits[i].data(), out_splits[i].size(),
                  splits_flat.data());
    }
    return ::tensorflow::Status::OK();
  }

  // Allocate the values output (same shape as params_dense_values except the
  // outer dimension, which becomes num_values) and copy the gathered slices
  // into it via the value-type-specific CallWriteValueSlices().
  ::tensorflow::Status WriteValues(
      const Tensor& params_dense_values_in,
      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
      int values_index, SPLITS_TYPE num_values,
      OpKernelContext* context) const {
    Tensor* values_out = nullptr;
    TensorShape values_shape = params_dense_values_in.shape();
    values_shape.set_dim(0, num_values);
    TF_RETURN_IF_ERROR(
        context->allocate_output(values_index, values_shape, &values_out));
    // Scalars per outer-dimension row; guard the division when the tensor
    // is empty.
    const SPLITS_TYPE num_elements = params_dense_values_in.NumElements();
    const SPLITS_TYPE value_size =
        num_elements == 0 ? 0
                          : (num_elements / params_dense_values_in.dim_size(0));
    CallWriteValueSlices(params_dense_values_in, value_slices, value_size,
                         values_out);
    return ::tensorflow::Status::OK();
  }

 protected:
  // Call WriteValueSlices() using the appropriate VALUE_TYPE template
  // parameter.  This pattern is used to reduce binary size.  In particular,
  // this allows us to have two instantiations of this class (one for each
  // index type), rather than 14 (one for each index type and value type),
  // which cuts the binary size of this op from ~300k to <90k.
  virtual void CallWriteValueSlices(
      const Tensor& params_dense_values_in,
      const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
      SPLITS_TYPE value_size, Tensor* values_out) const = 0;
};
270 
271 template <typename INDEX_TYPE, typename VALUE_TYPE, typename SPLITS_TYPE>
272 class RaggedGatherOp : public RaggedGatherOpBase<INDEX_TYPE, SPLITS_TYPE> {
273  public:
274   using RaggedGatherOpBase<INDEX_TYPE, SPLITS_TYPE>::RaggedGatherOpBase;
275 
276  private:
CallWriteValueSlices(const Tensor & params_dense_values_in,const std::vector<std::pair<SPLITS_TYPE,SPLITS_TYPE>> & value_slices,SPLITS_TYPE value_size,Tensor * values_out) const277   void CallWriteValueSlices(
278       const Tensor& params_dense_values_in,
279       const std::vector<std::pair<SPLITS_TYPE, SPLITS_TYPE>>& value_slices,
280       SPLITS_TYPE value_size, Tensor* values_out) const override {
281     WriteValueSlices<VALUE_TYPE>(params_dense_values_in, value_slices,
282                                  value_size, values_out);
283   }
284 };
285 
// Registers one RaggedGather CPU kernel for a single combination of
// (Tindices, Tvalues, Tsplits) dtypes.
#define REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(index_type, value_type, \
                                            splits_type)            \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("RaggedGather")                                          \
          .Device(DEVICE_CPU)                                       \
          .TypeConstraint<index_type>("Tindices")                   \
          .TypeConstraint<value_type>("Tvalues")                    \
          .TypeConstraint<splits_type>("Tsplits"),                  \
      RaggedGatherOp<index_type, value_type, splits_type>);
// Registers kernels for all four int32/int64 combinations of Tindices and
// Tsplits for the given value type.
#define REGISTER_CPU_KERNEL(value_type)                         \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int32) \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type, int32) \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int32, value_type, int64) \
  REGISTER_CPU_KERNEL_WITH_INDEX_TYPE(int64, value_type, int64)
// Instantiate kernels for every supported value dtype: all POD types,
// strings, and the quantized types.
TF_CALL_POD_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_tstring(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_quint16(REGISTER_CPU_KERNEL);
TF_CALL_qint16(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
#undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE
307 
308 }  // namespace tensorflow
309