• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <limits>
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/op_kernel.h"
21 #include "tensorflow/core/framework/register_types.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 
25 namespace tensorflow {
26 
27 using errors::InvalidArgument;
28 
29 class RaggedTensorToSparseOp : public OpKernel {
30  public:
31   using OpKernel::OpKernel;
32 
Compute(OpKernelContext * context)33   void Compute(OpKernelContext* context) override {
34     // Read the `rt_nested_splits` input & convert to Eigen tensors.
35     OpInputList rt_nested_splits_in;
36     OP_REQUIRES_OK(
37         context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
38     const int64 rt_nested_splits_len = rt_nested_splits_in.size();
39     DCHECK_GT(rt_nested_splits_len, 0);  // Enforced by REGISTER_OP.
40     std::vector<TTypes<int64>::ConstFlat> rt_nested_splits;
41     rt_nested_splits.reserve(rt_nested_splits_len);
42     for (int i = 0; i < rt_nested_splits_len; ++i) {
43       rt_nested_splits.push_back(rt_nested_splits_in[i].flat<int64>());
44     }
45 
46     // Read the `rt_dense_values` input.
47     const Tensor& rt_dense_values_in = context->input(rt_nested_splits_len);
48     OP_REQUIRES_OK(context,
49                    ValidateInputs(rt_nested_splits, rt_dense_values_in));
50 
51     // Assemble each value in `sparse_indices` using three parts:
52     // - `index_prefix` is the index in dimensions up through the last ragged
53     //   dimension.
54     // - `index_middle` is the index in the last ragged dimension.
55     // - `index_suffix` is the index in the dense value dimensions.
56     std::vector<int64> index_prefix(rt_nested_splits_len);
57     std::vector<std::vector<int64>> index_suffixes =
58         MakeIndexSuffixes(rt_dense_values_in.shape());
59 
60     // Allocate the `sparse_indices` output tensor.
61     const int64 nvals =
62         (rt_nested_splits.back()(rt_nested_splits.back().size() - 1) *
63          index_suffixes.size());
64     const int64 indices_len = rt_nested_splits_len + rt_dense_values_in.dims();
65     Tensor* sparse_indices_out = nullptr;
66     OP_REQUIRES_OK(
67         context, context->allocate_output(0, TensorShape({nvals, indices_len}),
68                                           &sparse_indices_out));
69     auto sparse_indices = sparse_indices_out->tensor<int64, 2>();
70 
71     // pos[i] is the current position in rt_nested_splits[i].  final_pos is a
72     // reference to make it easier to refer to pos[-1].
73     std::vector<int64> pos(rt_nested_splits_len);
74     int64& final_pos = pos[rt_nested_splits_len - 1];
75 
76     // Each iteration through the loop, we increment pos[-1], and add indices
77     // for all the values corresponding to
78     // rt_nested_splits[-1][pos[-1]:pos[-1]+1].
79     int next_index = 0;
80     int max_final_pos = rt_nested_splits.back().size() - 1;
81     for (; final_pos < max_final_pos; ++final_pos) {
82       // Update `pos` to skip over completed elements (i.e., elements where
83       // we have already generated indices for all contained values).
84       for (int dim = rt_nested_splits_len - 2; dim >= 0; --dim) {
85         while (IsCompleted(pos, dim, rt_nested_splits)) {
86           pos[dim] += 1;
87         }
88       }
89 
90       // Update index_prefix.
91       for (int dim = 0; dim < index_prefix.size(); ++dim) {
92         int start = dim > 0 ? rt_nested_splits[dim - 1](pos[dim - 1]) : 0;
93         index_prefix[dim] = pos[dim] - start;
94       }
95 
96       // Get length of the final-ragged-dimension slice.
97       const auto& final_splits = rt_nested_splits[rt_nested_splits_len - 1];
98       int64 slice_len = final_splits(final_pos + 1) - final_splits(final_pos);
99 
100       // Add sparse_indices for this slice.
101       for (int64 i = 0; i < slice_len; ++i) {
102         for (const auto& index_suffix : index_suffixes) {
103           int dim = 0;
104           for (int64 index : index_prefix) {  // index_prefix
105             sparse_indices(next_index, dim++) = index;
106           }
107           sparse_indices(next_index, dim++) = i;  // index_middle
108           for (int64 index : index_suffix) {      // index_suffix
109             sparse_indices(next_index, dim++) = index;
110           }
111           DCHECK_EQ(dim, indices_len);
112           ++next_index;
113         }
114       }
115     }
116     DCHECK_EQ(next_index, nvals);
117 
118     // Output the `sparse_values` Tensor.
119     if (rt_dense_values_in.dims() == 1) {
120       context->set_output(1, rt_dense_values_in);
121     } else {
122       Tensor sparse_values_out(rt_dense_values_in.dtype());
123       bool shapes_match = sparse_values_out.CopyFrom(
124           rt_dense_values_in, {rt_dense_values_in.NumElements()});
125       DCHECK(shapes_match);
126       context->set_output(1, sparse_values_out);
127     }
128 
129     // Output the `sparse_dense_shape` Tensor.
130     int64 ndims = rt_nested_splits_len + rt_dense_values_in.dims();
131     Tensor* sparse_dense_shape_out = nullptr;
132     OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({ndims}),
133                                                      &sparse_dense_shape_out));
134     auto sparse_dense_shape = sparse_dense_shape_out->vec<int64>();
135     sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
136     for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
137       const auto& splits = rt_nested_splits[dim];
138       int64 max_width = 0;
139       for (int i = 1; i < splits.size(); ++i) {
140         max_width = std::max(max_width, splits(i) - splits(i - 1));
141       }
142       sparse_dense_shape(dim + 1) = max_width;
143     }
144     for (int dim = 1; dim < rt_dense_values_in.dims(); ++dim) {
145       sparse_dense_shape(dim + rt_nested_splits_len) =
146           rt_dense_values_in.dim_size(dim);
147     }
148   }
149 
150  private:
151   // Validate `rt_nested_splits` to ensure we don't get any segfaults.
ValidateInputs(std::vector<TTypes<int64>::ConstFlat> rt_nested_splits,const Tensor & rt_dense_values_in)152   static ::tensorflow::Status ValidateInputs(
153       std::vector<TTypes<int64>::ConstFlat> rt_nested_splits,
154       const Tensor& rt_dense_values_in) {
155     for (int i = 0; i < rt_nested_splits.size(); ++i) {
156       if (rt_nested_splits[i].size() == 0) {
157         return InvalidArgument("ragged splits may not be empty.");
158       }
159       if (rt_nested_splits[i](0) != 0) {
160         return InvalidArgument("First value of ragged splits must be 0.");
161       }
162       if (i > 0) {
163         int64 last_split =
164             rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
165         if (rt_nested_splits[i].size() != last_split + 1) {
166           return InvalidArgument(
167               "Final value of ragged splits must match the length "
168               "the corresponding ragged values.");
169         }
170       }
171     }
172     if (rt_dense_values_in.dim_size(0) !=
173         rt_nested_splits.back()(rt_nested_splits.back().size() - 1)) {
174       return InvalidArgument(
175           "Final value of ragged splits must match the length "
176           "the corresponding ragged values.");
177     }
178     return ::tensorflow::Status::OK();
179   }
180 
181   // Build a list of index suffixes that should be added for each ragged item,
182   // to encode the indices of dense values in that ragged item.  This basically
183   // just gives a row-major enumeration of all indices in the given tensor
184   // shape, ignoring dim[0] (since that's the dimension that iterates over
185   // values, and we want index suffixes for a single value).  Example:
186   // MakeIndexSuffixes(TensorShape({100, 3, 2})
187   //   --> {{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 0}, {2, 1}}
MakeIndexSuffixes(const TensorShape & values_shape)188   static std::vector<std::vector<int64>> MakeIndexSuffixes(
189       const TensorShape& values_shape) {
190     std::vector<std::vector<int64>> suffixes{{}};
191     for (int dim = 1; dim < values_shape.dims(); ++dim) {
192       std::vector<std::vector<int64>> new_suffixes;
193       for (const auto& suffix : suffixes) {
194         for (int i = 0; i < values_shape.dim_size(dim); ++i) {
195           new_suffixes.push_back(suffix);
196           new_suffixes.back().push_back(i);
197         }
198       }
199       suffixes.swap(new_suffixes);
200     }
201     return suffixes;
202   }
203 
204   // Returns true if the ragged element at pos[dim] is "completed".  A ragged
205   // element is completed if we have already generated indices for all of its
206   // values.
IsCompleted(const std::vector<int64> & pos,int dim,const std::vector<TTypes<int64>::ConstFlat> & rt_nested_splits)207   static bool IsCompleted(
208       const std::vector<int64>& pos, int dim,
209       const std::vector<TTypes<int64>::ConstFlat>& rt_nested_splits) {
210     int64 current_child = pos[dim + 1];
211     int64 limit_child = rt_nested_splits[dim](pos[dim] + 1);
212     return current_child >= limit_child;
213   }
214 };
215 
216 REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse").Device(DEVICE_CPU),
217                         RaggedTensorToSparseOp);
218 
219 }  // namespace tensorflow
220