/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

using errors::InvalidArgument;

template <typename SPLITS_TYPE>
class RaggedTensorToSparseOp : public OpKernel {
 public:
  using OpKernel::OpKernel;
  using ConstFlatSplits = typename TTypes<SPLITS_TYPE>::ConstFlat;

  void Compute(OpKernelContext* context) override {
    // Read the `rt_nested_splits` input & convert to Eigen tensors.
    OpInputList rt_nested_splits_in;
    OP_REQUIRES_OK(
        context,
        context->input_list("rt_nested_splits", &rt_nested_splits_in));
    const int rt_nested_splits_len = rt_nested_splits_in.size();
    DCHECK_GT(rt_nested_splits_len, 0);  // Enforced by REGISTER_OP.
    std::vector<ConstFlatSplits> rt_nested_splits;
    rt_nested_splits.reserve(rt_nested_splits_len);
    for (int i = 0; i < rt_nested_splits_len; ++i) {
      rt_nested_splits.push_back(rt_nested_splits_in[i].flat<SPLITS_TYPE>());
    }

    // Read the `rt_dense_values` input.
    const Tensor& rt_dense_values_in = context->input(rt_nested_splits_len);
    OP_REQUIRES_OK(context,
                   ValidateInputs(rt_nested_splits, rt_dense_values_in));

    // Assemble each value in `sparse_indices` using three parts:
    // - `index_prefix` is the index in dimensions up through the last ragged
    //   dimension.
    // - `index_middle` is the index in the last ragged dimension.
    // - `index_suffix` is the index in the dense value dimensions.
    std::vector<int64> index_prefix(rt_nested_splits_len);
    std::vector<std::vector<int64>> index_suffixes =
        MakeIndexSuffixes(rt_dense_values_in.shape());

    // Allocate the `sparse_indices` output tensor.
    const int64 nvals =
        (rt_nested_splits.back()(rt_nested_splits.back().size() - 1) *
         index_suffixes.size());
    const int64 indices_len = rt_nested_splits_len + rt_dense_values_in.dims();
    Tensor* sparse_indices_out = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, TensorShape({nvals, indices_len}),
                                          &sparse_indices_out));
    auto sparse_indices = sparse_indices_out->tensor<int64, 2>();

    // pos[i] is the current position in rt_nested_splits[i]. final_pos is a
    // reference to make it easier to refer to pos[-1].
    std::vector<int64> pos(rt_nested_splits_len);
    int64& final_pos = pos[rt_nested_splits_len - 1];

    // Each iteration through the loop, we increment pos[-1], and add indices
    // for all the values corresponding to
    // rt_nested_splits[-1][pos[-1]:pos[-1]+1].
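    //
    // Illustrative example (not part of the original source): with a single
    // splits tensor [0, 2, 3] and dense values of shape [3, 2],
    // index_suffixes is {{0}, {1}} and nvals is 6, so the loop below emits
    // the rows [0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0],
    // [1, 0, 1] into sparse_indices.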
    int next_index = 0;
    int max_final_pos = rt_nested_splits.back().size() - 1;
    for (; final_pos < max_final_pos; ++final_pos) {
      // Update `pos` to skip over completed elements (i.e., elements where
      // we have already generated indices for all contained values).
      for (int dim = rt_nested_splits_len - 2; dim >= 0; --dim) {
        while (IsCompleted(pos, dim, rt_nested_splits)) {
          pos[dim] += 1;
        }
      }

      // Update index_prefix.
      for (int dim = 0; dim < index_prefix.size(); ++dim) {
        int start = dim > 0 ? rt_nested_splits[dim - 1](pos[dim - 1]) : 0;
        index_prefix[dim] = pos[dim] - start;
      }

      // Get length of the final-ragged-dimension slice.
      const auto& final_splits = rt_nested_splits[rt_nested_splits_len - 1];
      int64 slice_len = final_splits(final_pos + 1) - final_splits(final_pos);

      // Add sparse_indices for this slice.
      for (int64 i = 0; i < slice_len; ++i) {
        for (const auto& index_suffix : index_suffixes) {
          int dim = 0;
          for (int64 index : index_prefix) {  // index_prefix
            sparse_indices(next_index, dim++) = index;
          }
          sparse_indices(next_index, dim++) = i;  // index_middle
          for (int64 index : index_suffix) {  // index_suffix
            sparse_indices(next_index, dim++) = index;
          }
          DCHECK_EQ(dim, indices_len);
          ++next_index;
        }
      }
    }
    DCHECK_EQ(next_index, nvals);

    // Output the `sparse_values` Tensor.
    if (rt_dense_values_in.dims() == 1) {
      context->set_output(1, rt_dense_values_in);
    } else {
      Tensor sparse_values_out(rt_dense_values_in.dtype());
      bool shapes_match = sparse_values_out.CopyFrom(
          rt_dense_values_in, {rt_dense_values_in.NumElements()});
      DCHECK(shapes_match);
      context->set_output(1, sparse_values_out);
    }

    // Output the `sparse_dense_shape` Tensor.
    int64 ndims = rt_nested_splits_len + rt_dense_values_in.dims();
    Tensor* sparse_dense_shape_out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({ndims}),
                                                     &sparse_dense_shape_out));
    auto sparse_dense_shape = sparse_dense_shape_out->vec<int64>();
    sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
    for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
      const auto& splits = rt_nested_splits[dim];
      SPLITS_TYPE max_width = 0;
      for (int i = 1; i < splits.size(); ++i) {
        max_width = std::max(max_width, splits(i) - splits(i - 1));
      }
      sparse_dense_shape(dim + 1) = max_width;
    }
    for (int dim = 1; dim < rt_dense_values_in.dims(); ++dim) {
      sparse_dense_shape(dim + rt_nested_splits_len) =
          rt_dense_values_in.dim_size(dim);
    }
  }

 private:
  // Validate `rt_nested_splits` to ensure we don't get any segfaults.
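  //
  // Illustrative example (not part of the original source): splits [0, 2, 3]
  // paired with a values tensor whose first dimension is 5 would be rejected,
  // since the final split value (3) must equal the number of ragged values.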
  static ::tensorflow::Status ValidateInputs(
      std::vector<ConstFlatSplits> rt_nested_splits,
      const Tensor& rt_dense_values_in) {
    for (int i = 0; i < rt_nested_splits.size(); ++i) {
      if (rt_nested_splits[i].size() == 0) {
        return InvalidArgument("ragged splits may not be empty.");
      }
      if (rt_nested_splits[i](0) != 0) {
        return InvalidArgument("First value of ragged splits must be 0.");
      }
      if (i > 0) {
        SPLITS_TYPE last_split =
            rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
        if (rt_nested_splits[i].size() != last_split + 1) {
          return InvalidArgument(
              "Final value of ragged splits must match the length "
              "of the corresponding ragged values.");
        }
      }
    }
    if (rt_dense_values_in.dim_size(0) !=
        rt_nested_splits.back()(rt_nested_splits.back().size() - 1)) {
      return InvalidArgument(
          "Final value of ragged splits must match the length "
          "of the corresponding ragged values.");
    }
    return ::tensorflow::Status::OK();
  }

  // Build a list of index suffixes that should be added for each ragged item,
  // to encode the indices of dense values in that ragged item. This basically
  // just gives a row-major enumeration of all indices in the given tensor
  // shape, ignoring dim[0] (since that's the dimension that iterates over
  // values, and we want index suffixes for a single value). Example:
  //   MakeIndexSuffixes(TensorShape({100, 3, 2}))
  //     --> {{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 0}, {2, 1}}
  static std::vector<std::vector<int64>> MakeIndexSuffixes(
      const TensorShape& values_shape) {
    std::vector<std::vector<int64>> suffixes{{}};
    for (int dim = 1; dim < values_shape.dims(); ++dim) {
      std::vector<std::vector<int64>> new_suffixes;
      for (const auto& suffix : suffixes) {
        for (int i = 0; i < values_shape.dim_size(dim); ++i) {
          new_suffixes.push_back(suffix);
          new_suffixes.back().push_back(i);
        }
      }
      suffixes.swap(new_suffixes);
    }
    return suffixes;
  }

  // Returns true if the ragged element at pos[dim] is "completed". A ragged
  // element is completed if we have already generated indices for all of its
  // values.
  static bool IsCompleted(
      const std::vector<int64>& pos, int dim,
      const std::vector<ConstFlatSplits>& rt_nested_splits) {
    int64 current_child = pos[dim + 1];
    int64 limit_child = rt_nested_splits[dim](pos[dim] + 1);
    return current_child >= limit_child;
  }
};

REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int32>("Tsplits"),
                        RaggedTensorToSparseOp<int32>);

REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("Tsplits"),
                        RaggedTensorToSparseOp<int64>);

}  // namespace tensorflow