• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <limits>
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/register_types.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 #include "tensorflow/core/platform/errors.h"
25 
26 namespace tensorflow {
27 
28 using errors::InvalidArgument;
29 
30 template <typename SPLITS_TYPE>
31 class RaggedTensorToSparseOp : public OpKernel {
32  public:
33   using OpKernel::OpKernel;
34   using ConstFlatSplits = typename TTypes<SPLITS_TYPE>::ConstFlat;
35 
Compute(OpKernelContext * context)36   void Compute(OpKernelContext* context) override {
37     // Read the `rt_nested_splits` input & convert to Eigen tensors.
38     OpInputList rt_nested_splits_in;
39     OP_REQUIRES_OK(
40         context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
41     const int rt_nested_splits_len = rt_nested_splits_in.size();
42     OP_REQUIRES(context, rt_nested_splits_len > 0,
43                 errors::InvalidArgument("rt_nested_splits must be non empty"));
44     std::vector<ConstFlatSplits> rt_nested_splits;
45     rt_nested_splits.reserve(rt_nested_splits_len);
46     for (int i = 0; i < rt_nested_splits_len; ++i) {
47       rt_nested_splits.push_back(rt_nested_splits_in[i].flat<SPLITS_TYPE>());
48     }
49 
50     // Read the `rt_dense_values` input.
51     const Tensor& rt_dense_values_in = context->input(rt_nested_splits_len);
52     OP_REQUIRES_OK(context,
53                    ValidateInputs(rt_nested_splits, rt_dense_values_in));
54 
55     // Assemble each value in `sparse_indices` using three parts:
56     // - `index_prefix` is the index in dimensions up through the last ragged
57     //   dimension.
58     // - `index_middle` is the index in the last ragged dimension.
59     // - `index_suffix` is the index in the dense value dimensions.
60     std::vector<int64_t> index_prefix(rt_nested_splits_len);
61     std::vector<std::vector<int64_t>> index_suffixes =
62         MakeIndexSuffixes(rt_dense_values_in.shape());
63 
64     // Allocate the `sparse_indices` output tensor.
65     const int64_t nvals =
66         (rt_nested_splits.back()(rt_nested_splits.back().size() - 1) *
67          index_suffixes.size());
68     const int64_t indices_len =
69         rt_nested_splits_len + rt_dense_values_in.dims();
70     Tensor* sparse_indices_out = nullptr;
71     OP_REQUIRES_OK(
72         context, context->allocate_output(0, TensorShape({nvals, indices_len}),
73                                           &sparse_indices_out));
74     auto sparse_indices = sparse_indices_out->tensor<int64_t, 2>();
75 
76     // pos[i] is the current position in rt_nested_splits[i].  final_pos is a
77     // reference to make it easier to refer to pos[-1].
78     std::vector<int64_t> pos(rt_nested_splits_len);
79     int64_t& final_pos = pos[rt_nested_splits_len - 1];
80 
81     // Each iteration through the loop, we increment pos[-1], and add indices
82     // for all the values corresponding to
83     // rt_nested_splits[-1][pos[-1]:pos[-1]+1].
84     int next_index = 0;
85     int max_final_pos = rt_nested_splits.back().size() - 1;
86     for (; final_pos < max_final_pos; ++final_pos) {
87       // Update `pos` to skip over completed elements (i.e., elements where
88       // we have already generated indices for all contained values).
89       for (int dim = rt_nested_splits_len - 2; dim >= 0; --dim) {
90         while (IsCompleted(pos, dim, rt_nested_splits)) {
91           pos[dim] += 1;
92         }
93       }
94 
95       // Update index_prefix.
96       for (int dim = 0; dim < index_prefix.size(); ++dim) {
97         int start = dim > 0 ? rt_nested_splits[dim - 1](pos[dim - 1]) : 0;
98         index_prefix[dim] = pos[dim] - start;
99       }
100 
101       // Get length of the final-ragged-dimension slice.
102       const auto& final_splits = rt_nested_splits[rt_nested_splits_len - 1];
103       int64_t slice_len = final_splits(final_pos + 1) - final_splits(final_pos);
104 
105       // Add sparse_indices for this slice.
106       for (int64_t i = 0; i < slice_len; ++i) {
107         for (const auto& index_suffix : index_suffixes) {
108           int dim = 0;
109           for (int64_t index : index_prefix) {  // index_prefix
110             sparse_indices(next_index, dim++) = index;
111           }
112           sparse_indices(next_index, dim++) = i;  // index_middle
113           for (int64_t index : index_suffix) {    // index_suffix
114             sparse_indices(next_index, dim++) = index;
115           }
116           DCHECK_EQ(dim, indices_len);
117           ++next_index;
118         }
119       }
120     }
121     DCHECK_EQ(next_index, nvals);
122 
123     // Output the `sparse_values` Tensor.
124     if (rt_dense_values_in.dims() == 1) {
125       context->set_output(1, rt_dense_values_in);
126     } else {
127       Tensor sparse_values_out(rt_dense_values_in.dtype());
128       bool shapes_match = sparse_values_out.CopyFrom(
129           rt_dense_values_in, {rt_dense_values_in.NumElements()});
130       DCHECK(shapes_match);
131       context->set_output(1, sparse_values_out);
132     }
133 
134     // Output the `sparse_dense_shape` Tensor.
135     int64_t ndims = rt_nested_splits_len + rt_dense_values_in.dims();
136     Tensor* sparse_dense_shape_out = nullptr;
137     OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({ndims}),
138                                                      &sparse_dense_shape_out));
139     auto sparse_dense_shape = sparse_dense_shape_out->vec<int64_t>();
140     sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
141     for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
142       const auto& splits = rt_nested_splits[dim];
143       SPLITS_TYPE max_width = 0;
144       for (int i = 1; i < splits.size(); ++i) {
145         max_width = std::max(max_width, splits(i) - splits(i - 1));
146       }
147       sparse_dense_shape(dim + 1) = max_width;
148     }
149     for (int dim = 1; dim < rt_dense_values_in.dims(); ++dim) {
150       sparse_dense_shape(dim + rt_nested_splits_len) =
151           rt_dense_values_in.dim_size(dim);
152     }
153   }
154 
155  private:
156   // Validate `rt_nested_splits` to ensure we don't get any segfaults.
ValidateInputs(std::vector<ConstFlatSplits> rt_nested_splits,const Tensor & rt_dense_values_in)157   static ::tensorflow::Status ValidateInputs(
158       std::vector<ConstFlatSplits> rt_nested_splits,
159       const Tensor& rt_dense_values_in) {
160     for (int i = 0; i < rt_nested_splits.size(); ++i) {
161       if (rt_nested_splits[i].size() == 0) {
162         return InvalidArgument("ragged splits may not be empty.");
163       }
164       if (rt_nested_splits[i](0) != 0) {
165         return InvalidArgument("First value of ragged splits must be 0.");
166       }
167       for (int j = 1; j < rt_nested_splits[i].size(); ++j) {
168         if (rt_nested_splits[i](j) < rt_nested_splits[i](j - 1)) {
169           return InvalidArgument(
170               "Ragged splits should be non decreasing, but we got ",
171               rt_nested_splits[i](j - 1), " followed by ",
172               rt_nested_splits[i](j));
173         }
174       }
175       if (i > 0) {
176         SPLITS_TYPE last_split =
177             rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
178         if (rt_nested_splits[i].size() != last_split + 1) {
179           return InvalidArgument(
180               "Final value of ragged splits must match the length "
181               "the corresponding ragged values.");
182         }
183       }
184     }
185     if (rt_dense_values_in.dim_size(0) !=
186         rt_nested_splits.back()(rt_nested_splits.back().size() - 1)) {
187       return InvalidArgument(
188           "Final value of ragged splits must match the length "
189           "the corresponding ragged values.");
190     }
191     return OkStatus();
192   }
193 
194   // Build a list of index suffixes that should be added for each ragged item,
195   // to encode the indices of dense values in that ragged item.  This basically
196   // just gives a row-major enumeration of all indices in the given tensor
197   // shape, ignoring dim[0] (since that's the dimension that iterates over
198   // values, and we want index suffixes for a single value).  Example:
199   // MakeIndexSuffixes(TensorShape({100, 3, 2})
200   //   --> {{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 0}, {2, 1}}
MakeIndexSuffixes(const TensorShape & values_shape)201   static std::vector<std::vector<int64_t>> MakeIndexSuffixes(
202       const TensorShape& values_shape) {
203     std::vector<std::vector<int64_t>> suffixes{{}};
204     for (int dim = 1; dim < values_shape.dims(); ++dim) {
205       std::vector<std::vector<int64_t>> new_suffixes;
206       for (const auto& suffix : suffixes) {
207         for (int i = 0; i < values_shape.dim_size(dim); ++i) {
208           new_suffixes.push_back(suffix);
209           new_suffixes.back().push_back(i);
210         }
211       }
212       suffixes.swap(new_suffixes);
213     }
214     return suffixes;
215   }
216 
217   // Returns true if the ragged element at pos[dim] is "completed".  A ragged
218   // element is completed if we have already generated indices for all of its
219   // values.
IsCompleted(const std::vector<int64_t> & pos,int dim,const std::vector<ConstFlatSplits> & rt_nested_splits)220   static bool IsCompleted(
221       const std::vector<int64_t>& pos, int dim,
222       const std::vector<ConstFlatSplits>& rt_nested_splits) {
223     int64_t current_child = pos[dim + 1];
224     int64_t limit_child = rt_nested_splits[dim](pos[dim] + 1);
225     return current_child >= limit_child;
226   }
227 };
228 
229 REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
230                             .Device(DEVICE_CPU)
231                             .TypeConstraint<int32>("Tsplits"),
232                         RaggedTensorToSparseOp<int32>);
233 
234 REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
235                             .Device(DEVICE_CPU)
236                             .TypeConstraint<int64_t>("Tsplits"),
237                         RaggedTensorToSparseOp<int64_t>);
238 
239 }  // namespace tensorflow
240