/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

using errors::InvalidArgument;

template <typename SPLITS_TYPE>
class RaggedTensorToSparseOp : public OpKernel {
 public:
  using OpKernel::OpKernel;
  using ConstFlatSplits = typename TTypes<SPLITS_TYPE>::ConstFlat;

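  // Example (illustrative, with schematic values a..e): for a single ragged
  // dimension,
  //   rt_nested_splits = [[0, 3, 3, 5]]
  //   rt_dense_values  = [a, b, c, d, e]
  // convert to
  //   sparse_indices     = [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1]]
  //   sparse_values      = [a, b, c, d, e]
  //   sparse_dense_shape = [3, 3]
  // (three rows, with the widest row holding three values).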
  void Compute(OpKernelContext* context) override {
    // Read the `rt_nested_splits` input & convert to Eigen tensors.
    OpInputList rt_nested_splits_in;
    OP_REQUIRES_OK(
        context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
    const int rt_nested_splits_len = rt_nested_splits_in.size();
    DCHECK_GT(rt_nested_splits_len, 0);  // Enforced by REGISTER_OP.
    std::vector<ConstFlatSplits> rt_nested_splits;
    rt_nested_splits.reserve(rt_nested_splits_len);
    for (int i = 0; i < rt_nested_splits_len; ++i) {
      rt_nested_splits.push_back(rt_nested_splits_in[i].flat<SPLITS_TYPE>());
    }

    // Read the `rt_dense_values` input.
    const Tensor& rt_dense_values_in = context->input(rt_nested_splits_len);
    OP_REQUIRES_OK(context,
                   ValidateInputs(rt_nested_splits, rt_dense_values_in));

    // Assemble each value in `sparse_indices` using three parts:
    // - `index_prefix` is the index in dimensions up through the last ragged
    //   dimension.
    // - `index_middle` is the index in the last ragged dimension.
    // - `index_suffix` is the index in the dense value dimensions.
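    // For example (illustrative), with two ragged dimensions and dense values
    // of shape [N, 2], each row of `sparse_indices` has four entries: two from
    // `index_prefix`, one `index_middle`, and one from `index_suffix`.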
    std::vector<int64> index_prefix(rt_nested_splits_len);
    std::vector<std::vector<int64>> index_suffixes =
        MakeIndexSuffixes(rt_dense_values_in.shape());

    // Allocate the `sparse_indices` output tensor.
    const int64 nvals =
        (rt_nested_splits.back()(rt_nested_splits.back().size() - 1) *
         index_suffixes.size());
    const int64 indices_len = rt_nested_splits_len + rt_dense_values_in.dims();
    Tensor* sparse_indices_out = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, TensorShape({nvals, indices_len}),
                                          &sparse_indices_out));
    auto sparse_indices = sparse_indices_out->tensor<int64, 2>();

    // pos[i] is the current position in rt_nested_splits[i].  final_pos is a
    // reference to make it easier to refer to pos[-1].
    std::vector<int64> pos(rt_nested_splits_len);
    int64& final_pos = pos[rt_nested_splits_len - 1];

    // Each iteration through the loop, we increment pos[-1], and add indices
    // for all the values corresponding to
    // rt_nested_splits[-1][pos[-1]:pos[-1]+1].
    int next_index = 0;
    int max_final_pos = rt_nested_splits.back().size() - 1;
    for (; final_pos < max_final_pos; ++final_pos) {
      // Update `pos` to skip over completed elements (i.e., elements where
      // we have already generated indices for all contained values).
      for (int dim = rt_nested_splits_len - 2; dim >= 0; --dim) {
        while (IsCompleted(pos, dim, rt_nested_splits)) {
          pos[dim] += 1;
        }
      }

      // Update index_prefix.
      for (int dim = 0; dim < index_prefix.size(); ++dim) {
        int start = dim > 0 ? rt_nested_splits[dim - 1](pos[dim - 1]) : 0;
        index_prefix[dim] = pos[dim] - start;
      }

      // Get length of the final-ragged-dimension slice.
      const auto& final_splits = rt_nested_splits[rt_nested_splits_len - 1];
      int64 slice_len = final_splits(final_pos + 1) - final_splits(final_pos);

      // Add sparse_indices for this slice.
      for (int64 i = 0; i < slice_len; ++i) {
        for (const auto& index_suffix : index_suffixes) {
          int dim = 0;
          for (int64 index : index_prefix) {  // index_prefix
            sparse_indices(next_index, dim++) = index;
          }
          sparse_indices(next_index, dim++) = i;  // index_middle
          for (int64 index : index_suffix) {      // index_suffix
            sparse_indices(next_index, dim++) = index;
          }
          DCHECK_EQ(dim, indices_len);
          ++next_index;
        }
      }
    }
    DCHECK_EQ(next_index, nvals);

    // Output the `sparse_values` Tensor.
    if (rt_dense_values_in.dims() == 1) {
      context->set_output(1, rt_dense_values_in);
    } else {
      Tensor sparse_values_out(rt_dense_values_in.dtype());
      bool shapes_match = sparse_values_out.CopyFrom(
          rt_dense_values_in, {rt_dense_values_in.NumElements()});
      DCHECK(shapes_match);
      context->set_output(1, sparse_values_out);
    }

    // Output the `sparse_dense_shape` Tensor.
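    // For example (illustrative), splits [[0, 3, 3, 5]] over values of shape
    // [5, 2] yield a dense shape of [3, 3, 2]: three rows, a maximum row width
    // of three, and a trailing dense dimension of two.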
    int64 ndims = rt_nested_splits_len + rt_dense_values_in.dims();
    Tensor* sparse_dense_shape_out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({ndims}),
                                                     &sparse_dense_shape_out));
    auto sparse_dense_shape = sparse_dense_shape_out->vec<int64>();
    sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
    for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
      const auto& splits = rt_nested_splits[dim];
      SPLITS_TYPE max_width = 0;
      for (int i = 1; i < splits.size(); ++i) {
        max_width = std::max(max_width, splits(i) - splits(i - 1));
      }
      sparse_dense_shape(dim + 1) = max_width;
    }
    for (int dim = 1; dim < rt_dense_values_in.dims(); ++dim) {
      sparse_dense_shape(dim + rt_nested_splits_len) =
          rt_dense_values_in.dim_size(dim);
    }
  }

 private:
  // Validate `rt_nested_splits` to ensure we don't get any segfaults.
  static ::tensorflow::Status ValidateInputs(
      std::vector<ConstFlatSplits> rt_nested_splits,
      const Tensor& rt_dense_values_in) {
    for (int i = 0; i < rt_nested_splits.size(); ++i) {
      if (rt_nested_splits[i].size() == 0) {
        return InvalidArgument("ragged splits may not be empty.");
      }
      if (rt_nested_splits[i](0) != 0) {
        return InvalidArgument("First value of ragged splits must be 0.");
      }
      if (i > 0) {
        SPLITS_TYPE last_split =
            rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
        if (rt_nested_splits[i].size() != last_split + 1) {
          return InvalidArgument(
              "Final value of ragged splits must match the length of "
              "the corresponding ragged values.");
        }
      }
    }
    if (rt_dense_values_in.dim_size(0) !=
        rt_nested_splits.back()(rt_nested_splits.back().size() - 1)) {
      return InvalidArgument(
          "Final value of ragged splits must match the length of "
          "the corresponding ragged values.");
    }
    return ::tensorflow::Status::OK();
  }

  // Build a list of index suffixes that should be added for each ragged item,
  // to encode the indices of dense values in that ragged item.  This basically
  // just gives a row-major enumeration of all indices in the given tensor
  // shape, ignoring dim[0] (since that's the dimension that iterates over
  // values, and we want index suffixes for a single value).  Example:
  // MakeIndexSuffixes(TensorShape({100, 3, 2}))
  //   --> {{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 0}, {2, 1}}
  static std::vector<std::vector<int64>> MakeIndexSuffixes(
      const TensorShape& values_shape) {
    std::vector<std::vector<int64>> suffixes{{}};
    for (int dim = 1; dim < values_shape.dims(); ++dim) {
      std::vector<std::vector<int64>> new_suffixes;
      for (const auto& suffix : suffixes) {
        for (int i = 0; i < values_shape.dim_size(dim); ++i) {
          new_suffixes.push_back(suffix);
          new_suffixes.back().push_back(i);
        }
      }
      suffixes.swap(new_suffixes);
    }
    return suffixes;
  }

  // Returns true if the ragged element at pos[dim] is "completed".  A ragged
  // element is completed if we have already generated indices for all of its
  // values.
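  // For example (illustrative), if rt_nested_splits[dim] = [0, 2, 2, 5] and
  // pos[dim] = 0, then the element is completed once pos[dim + 1] >= 2.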
  static bool IsCompleted(
      const std::vector<int64>& pos, int dim,
      const std::vector<ConstFlatSplits>& rt_nested_splits) {
    int64 current_child = pos[dim + 1];
    int64 limit_child = rt_nested_splits[dim](pos[dim] + 1);
    return current_child >= limit_child;
  }
};

REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("Tsplits"),
                        RaggedTensorToSparseOp<int32>);

REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("Tsplits"),
                        RaggedTensorToSparseOp<int64>);

}  // namespace tensorflow