/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

using errors::InvalidArgument;

// Kernel for the "RaggedTensorToSparse" op: converts a ragged tensor
// (a list of nested row-splits vectors plus a flat dense-values tensor)
// into SparseTensor components.
//
// Inputs:  "rt_nested_splits" (list of 1-D splits tensors, outermost first),
//          followed by the dense values tensor at input index
//          rt_nested_splits_len.
// Outputs: 0 = sparse_indices (int64, shape [nvals, ndims]),
//          1 = sparse_values  (the dense values, flattened to 1-D),
//          2 = sparse_dense_shape (int64 vector of length ndims).
template <typename SPLITS_TYPE>
class RaggedTensorToSparseOp : public OpKernel {
 public:
  using OpKernel::OpKernel;
  // Eigen flat view over a (const) splits tensor.
  using ConstFlatSplits = typename TTypes<SPLITS_TYPE>::ConstFlat;

  void Compute(OpKernelContext* context) override {
    // Read the `rt_nested_splits` input & convert to Eigen tensors.
    OpInputList rt_nested_splits_in;
    OP_REQUIRES_OK(
        context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
    const int rt_nested_splits_len = rt_nested_splits_in.size();
    OP_REQUIRES(context, rt_nested_splits_len > 0,
                errors::InvalidArgument("rt_nested_splits must be non empty"));
    std::vector<ConstFlatSplits> rt_nested_splits;
    rt_nested_splits.reserve(rt_nested_splits_len);
    for (int i = 0; i < rt_nested_splits_len; ++i) {
      rt_nested_splits.push_back(rt_nested_splits_in[i].flat<SPLITS_TYPE>());
    }

    // Read the `rt_dense_values` input.
    const Tensor& rt_dense_values_in = context->input(rt_nested_splits_len);
    // Bounds-check the splits before any of the indexing below.
    OP_REQUIRES_OK(context,
                   ValidateInputs(rt_nested_splits, rt_dense_values_in));

    // Assemble each value in `sparse_indices` using three parts:
    // - `index_prefix` is the index in dimensions up through the last ragged
    //   dimension.
    // - `index_middle` is the index in the last ragged dimension.
    // - `index_suffix` is the index in the dense value dimensions.
    std::vector<int64_t> index_prefix(rt_nested_splits_len);
    std::vector<std::vector<int64_t>> index_suffixes =
        MakeIndexSuffixes(rt_dense_values_in.shape());

    // Allocate the `sparse_indices` output tensor.
    // The final value of the innermost splits vector is the number of ragged
    // values; each value contributes one index row per dense index suffix.
    const int64_t nvals =
        (rt_nested_splits.back()(rt_nested_splits.back().size() - 1) *
         index_suffixes.size());
    const int64_t indices_len =
        rt_nested_splits_len + rt_dense_values_in.dims();
    Tensor* sparse_indices_out = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, TensorShape({nvals, indices_len}),
                                          &sparse_indices_out));
    auto sparse_indices = sparse_indices_out->tensor<int64_t, 2>();

    // pos[i] is the current position in rt_nested_splits[i].  final_pos is a
    // reference to make it easier to refer to pos[-1].
    std::vector<int64_t> pos(rt_nested_splits_len);
    int64_t& final_pos = pos[rt_nested_splits_len - 1];

    // Each iteration through the loop, we increment pos[-1], and add indices
    // for all the values corresponding to
    // rt_nested_splits[-1][pos[-1]:pos[-1]+1].
    int next_index = 0;
    int max_final_pos = rt_nested_splits.back().size() - 1;
    for (; final_pos < max_final_pos; ++final_pos) {
      // Update `pos` to skip over completed elements (i.e., elements where
      // we have already generated indices for all contained values).
      // Scanned innermost-first so an advance in dim can complete dim-1.
      for (int dim = rt_nested_splits_len - 2; dim >= 0; --dim) {
        while (IsCompleted(pos, dim, rt_nested_splits)) {
          pos[dim] += 1;
        }
      }

      // Update index_prefix.
      // Each coordinate is the offset of pos[dim] within its parent's row
      // (pos[dim] minus the row start recorded in the splits one level up).
      for (int dim = 0; dim < index_prefix.size(); ++dim) {
        int start = dim > 0 ? rt_nested_splits[dim - 1](pos[dim - 1]) : 0;
        index_prefix[dim] = pos[dim] - start;
      }

      // Get length of the final-ragged-dimension slice.
      const auto& final_splits = rt_nested_splits[rt_nested_splits_len - 1];
      int64_t slice_len = final_splits(final_pos + 1) - final_splits(final_pos);

      // Add sparse_indices for this slice.
      for (int64_t i = 0; i < slice_len; ++i) {
        for (const auto& index_suffix : index_suffixes) {
          int dim = 0;
          for (int64_t index : index_prefix) {  // index_prefix
            sparse_indices(next_index, dim++) = index;
          }
          sparse_indices(next_index, dim++) = i;  // index_middle
          for (int64_t index : index_suffix) {  // index_suffix
            sparse_indices(next_index, dim++) = index;
          }
          DCHECK_EQ(dim, indices_len);
          ++next_index;
        }
      }
    }
    DCHECK_EQ(next_index, nvals);

    // Output the `sparse_values` Tensor.
    // Rank-1 values are forwarded untouched; higher-rank values are reshaped
    // to one dimension of NumElements() entries.
    if (rt_dense_values_in.dims() == 1) {
      context->set_output(1, rt_dense_values_in);
    } else {
      Tensor sparse_values_out(rt_dense_values_in.dtype());
      bool shapes_match = sparse_values_out.CopyFrom(
          rt_dense_values_in, {rt_dense_values_in.NumElements()});
      DCHECK(shapes_match);
      context->set_output(1, sparse_values_out);
    }

    // Output the `sparse_dense_shape` Tensor.
    int64_t ndims = rt_nested_splits_len + rt_dense_values_in.dims();
    Tensor* sparse_dense_shape_out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({ndims}),
                                                     &sparse_dense_shape_out));
    auto sparse_dense_shape = sparse_dense_shape_out->vec<int64_t>();
    // dim 0: number of rows = len(outermost splits) - 1.
    sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
    // Each ragged dimension's dense extent is its widest row.
    for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
      const auto& splits = rt_nested_splits[dim];
      SPLITS_TYPE max_width = 0;
      for (int i = 1; i < splits.size(); ++i) {
        max_width = std::max(max_width, splits(i) - splits(i - 1));
      }
      sparse_dense_shape(dim + 1) = max_width;
    }
    // Trailing (uniform) dimensions come straight from the values tensor.
    for (int dim = 1; dim < rt_dense_values_in.dims(); ++dim) {
      sparse_dense_shape(dim + rt_nested_splits_len) =
          rt_dense_values_in.dim_size(dim);
    }
  }

 private:
  // Validate `rt_nested_splits` to ensure we don't get any segfaults.
ValidateInputs(std::vector<ConstFlatSplits> rt_nested_splits,const Tensor & rt_dense_values_in)157 static ::tensorflow::Status ValidateInputs( 158 std::vector<ConstFlatSplits> rt_nested_splits, 159 const Tensor& rt_dense_values_in) { 160 for (int i = 0; i < rt_nested_splits.size(); ++i) { 161 if (rt_nested_splits[i].size() == 0) { 162 return InvalidArgument("ragged splits may not be empty."); 163 } 164 if (rt_nested_splits[i](0) != 0) { 165 return InvalidArgument("First value of ragged splits must be 0."); 166 } 167 for (int j = 1; j < rt_nested_splits[i].size(); ++j) { 168 if (rt_nested_splits[i](j) < rt_nested_splits[i](j - 1)) { 169 return InvalidArgument( 170 "Ragged splits should be non decreasing, but we got ", 171 rt_nested_splits[i](j - 1), " followed by ", 172 rt_nested_splits[i](j)); 173 } 174 } 175 if (i > 0) { 176 SPLITS_TYPE last_split = 177 rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1); 178 if (rt_nested_splits[i].size() != last_split + 1) { 179 return InvalidArgument( 180 "Final value of ragged splits must match the length " 181 "the corresponding ragged values."); 182 } 183 } 184 } 185 if (rt_dense_values_in.dim_size(0) != 186 rt_nested_splits.back()(rt_nested_splits.back().size() - 1)) { 187 return InvalidArgument( 188 "Final value of ragged splits must match the length " 189 "the corresponding ragged values."); 190 } 191 return OkStatus(); 192 } 193 194 // Build a list of index suffixes that should be added for each ragged item, 195 // to encode the indices of dense values in that ragged item. This basically 196 // just gives a row-major enumeration of all indices in the given tensor 197 // shape, ignoring dim[0] (since that's the dimension that iterates over 198 // values, and we want index suffixes for a single value). 
Example: 199 // MakeIndexSuffixes(TensorShape({100, 3, 2}) 200 // --> {{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 0}, {2, 1}} MakeIndexSuffixes(const TensorShape & values_shape)201 static std::vector<std::vector<int64_t>> MakeIndexSuffixes( 202 const TensorShape& values_shape) { 203 std::vector<std::vector<int64_t>> suffixes{{}}; 204 for (int dim = 1; dim < values_shape.dims(); ++dim) { 205 std::vector<std::vector<int64_t>> new_suffixes; 206 for (const auto& suffix : suffixes) { 207 for (int i = 0; i < values_shape.dim_size(dim); ++i) { 208 new_suffixes.push_back(suffix); 209 new_suffixes.back().push_back(i); 210 } 211 } 212 suffixes.swap(new_suffixes); 213 } 214 return suffixes; 215 } 216 217 // Returns true if the ragged element at pos[dim] is "completed". A ragged 218 // element is completed if we have already generated indices for all of its 219 // values. IsCompleted(const std::vector<int64_t> & pos,int dim,const std::vector<ConstFlatSplits> & rt_nested_splits)220 static bool IsCompleted( 221 const std::vector<int64_t>& pos, int dim, 222 const std::vector<ConstFlatSplits>& rt_nested_splits) { 223 int64_t current_child = pos[dim + 1]; 224 int64_t limit_child = rt_nested_splits[dim](pos[dim] + 1); 225 return current_child >= limit_child; 226 } 227 }; 228 229 REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse") 230 .Device(DEVICE_CPU) 231 .TypeConstraint<int32>("Tsplits"), 232 RaggedTensorToSparseOp<int32>); 233 234 REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse") 235 .Device(DEVICE_CPU) 236 .TypeConstraint<int64_t>("Tsplits"), 237 RaggedTensorToSparseOp<int64_t>); 238 239 } // namespace tensorflow 240