/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

using errors::InvalidArgument;

class RaggedTensorToSparseOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    // Read the `rt_nested_splits` input & convert to Eigen tensors.
    OpInputList rt_nested_splits_in;
    OP_REQUIRES_OK(
        context, context->input_list("rt_nested_splits", &rt_nested_splits_in));
    const int64 rt_nested_splits_len = rt_nested_splits_in.size();
    DCHECK_GT(rt_nested_splits_len, 0);  // Enforced by REGISTER_OP.
    std::vector<TTypes<int64>::ConstFlat> rt_nested_splits;
    rt_nested_splits.reserve(rt_nested_splits_len);
    for (int i = 0; i < rt_nested_splits_len; ++i) {
      rt_nested_splits.push_back(rt_nested_splits_in[i].flat<int64>());
    }

    // Read the `rt_dense_values` input.
    const Tensor& rt_dense_values_in = context->input(rt_nested_splits_len);
    OP_REQUIRES_OK(context,
                   ValidateInputs(rt_nested_splits, rt_dense_values_in));

    // Assemble each value in `sparse_indices` using three parts:
    // - `index_prefix` is the index in dimensions up through the last ragged
    //   dimension.
    // - `index_middle` is the index in the last ragged dimension.
    // - `index_suffix` is the index in the dense value dimensions.
    std::vector<int64> index_prefix(rt_nested_splits_len);
    std::vector<std::vector<int64>> index_suffixes =
        MakeIndexSuffixes(rt_dense_values_in.shape());

    // Allocate the `sparse_indices` output tensor.
    const int64 nvals =
        (rt_nested_splits.back()(rt_nested_splits.back().size() - 1) *
         index_suffixes.size());
    const int64 indices_len = rt_nested_splits_len + rt_dense_values_in.dims();
    Tensor* sparse_indices_out = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, TensorShape({nvals, indices_len}),
                                          &sparse_indices_out));
    auto sparse_indices = sparse_indices_out->tensor<int64, 2>();

    // pos[i] is the current position in rt_nested_splits[i].  final_pos is a
    // reference to make it easier to refer to pos[-1].
    std::vector<int64> pos(rt_nested_splits_len);
    int64& final_pos = pos[rt_nested_splits_len - 1];

    // Each iteration through the loop, we increment pos[-1], and add indices
    // for all the values corresponding to
    // rt_nested_splits[-1][pos[-1]:pos[-1]+1].
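    //
    // Illustrative walk-through (splits values chosen just for this comment):
    // with rt_nested_splits = {[0, 2, 3], [0, 1, 1, 4]} and scalar values, the
    // iteration with final_pos = 2 first advances pos[0] to 1 (the outer row
    // at pos[0] = 0 is completed), giving index_prefix = [1, 0] and
    // slice_len = 4 - 1 = 3, so it emits the sparse indices
    // [1, 0, 0], [1, 0, 1], and [1, 0, 2].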
    int next_index = 0;
    int max_final_pos = rt_nested_splits.back().size() - 1;
    for (; final_pos < max_final_pos; ++final_pos) {
      // Update `pos` to skip over completed elements (i.e., elements where
      // we have already generated indices for all contained values).
      for (int dim = rt_nested_splits_len - 2; dim >= 0; --dim) {
        while (IsCompleted(pos, dim, rt_nested_splits)) {
          pos[dim] += 1;
        }
      }

      // Update index_prefix.
      for (int dim = 0; dim < index_prefix.size(); ++dim) {
        int start = dim > 0 ? rt_nested_splits[dim - 1](pos[dim - 1]) : 0;
        index_prefix[dim] = pos[dim] - start;
      }

      // Get length of the final-ragged-dimension slice.
      const auto& final_splits = rt_nested_splits[rt_nested_splits_len - 1];
      int64 slice_len = final_splits(final_pos + 1) - final_splits(final_pos);

      // Add sparse_indices for this slice.
      for (int64 i = 0; i < slice_len; ++i) {
        for (const auto& index_suffix : index_suffixes) {
          int dim = 0;
          for (int64 index : index_prefix) {  // index_prefix
            sparse_indices(next_index, dim++) = index;
          }
          sparse_indices(next_index, dim++) = i;  // index_middle
          for (int64 index : index_suffix) {  // index_suffix
            sparse_indices(next_index, dim++) = index;
          }
          DCHECK_EQ(dim, indices_len);
          ++next_index;
        }
      }
    }
    DCHECK_EQ(next_index, nvals);

    // Output the `sparse_values` Tensor.
    if (rt_dense_values_in.dims() == 1) {
      context->set_output(1, rt_dense_values_in);
    } else {
      Tensor sparse_values_out(rt_dense_values_in.dtype());
      bool shapes_match = sparse_values_out.CopyFrom(
          rt_dense_values_in, {rt_dense_values_in.NumElements()});
      DCHECK(shapes_match);
      context->set_output(1, sparse_values_out);
    }

    // Output the `sparse_dense_shape` Tensor.
    int64 ndims = rt_nested_splits_len + rt_dense_values_in.dims();
    Tensor* sparse_dense_shape_out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({ndims}),
                                                     &sparse_dense_shape_out));
    auto sparse_dense_shape = sparse_dense_shape_out->vec<int64>();
    sparse_dense_shape(0) = rt_nested_splits_in[0].dim_size(0) - 1;
    for (int dim = 0; dim < rt_nested_splits_len; ++dim) {
      const auto& splits = rt_nested_splits[dim];
      int64 max_width = 0;
      for (int i = 1; i < splits.size(); ++i) {
        max_width = std::max(max_width, splits(i) - splits(i - 1));
      }
      sparse_dense_shape(dim + 1) = max_width;
    }
    for (int dim = 1; dim < rt_dense_values_in.dims(); ++dim) {
      sparse_dense_shape(dim + rt_nested_splits_len) =
          rt_dense_values_in.dim_size(dim);
    }
  }

 private:
  // Validate `rt_nested_splits` to ensure we don't get any segfaults.
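  // For example (splits chosen for illustration only): nested splits
  // {[0, 2, 3], [0, 1, 1, 4]} with a values tensor whose outer dimension is 4
  // are consistent: the outer splits end in 3 and the inner splits have
  // 3 + 1 entries, and the inner splits end in 4, which matches
  // rt_dense_values_in.dim_size(0).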
  static ::tensorflow::Status ValidateInputs(
      std::vector<TTypes<int64>::ConstFlat> rt_nested_splits,
      const Tensor& rt_dense_values_in) {
    for (int i = 0; i < rt_nested_splits.size(); ++i) {
      if (rt_nested_splits[i].size() == 0) {
        return InvalidArgument("ragged splits may not be empty.");
      }
      if (rt_nested_splits[i](0) != 0) {
        return InvalidArgument("First value of ragged splits must be 0.");
      }
      if (i > 0) {
        int64 last_split =
            rt_nested_splits[i - 1](rt_nested_splits[i - 1].size() - 1);
        if (rt_nested_splits[i].size() != last_split + 1) {
          return InvalidArgument(
              "Final value of ragged splits must match the length "
              "of the corresponding ragged values.");
        }
      }
    }
    if (rt_dense_values_in.dim_size(0) !=
        rt_nested_splits.back()(rt_nested_splits.back().size() - 1)) {
      return InvalidArgument(
          "Final value of ragged splits must match the length "
          "of the corresponding ragged values.");
    }
    return ::tensorflow::Status::OK();
  }

  // Build a list of index suffixes that should be added for each ragged item,
  // to encode the indices of dense values in that ragged item.  This basically
  // just gives a row-major enumeration of all indices in the given tensor
  // shape, ignoring dim[0] (since that's the dimension that iterates over
  // values, and we want index suffixes for a single value).  Example:
  //   MakeIndexSuffixes(TensorShape({100, 3, 2}))
  //     --> {{0, 0}, {0, 1}, {1, 0}, {1, 1}, {2, 0}, {2, 1}}
  static std::vector<std::vector<int64>> MakeIndexSuffixes(
      const TensorShape& values_shape) {
    std::vector<std::vector<int64>> suffixes{{}};
    for (int dim = 1; dim < values_shape.dims(); ++dim) {
      std::vector<std::vector<int64>> new_suffixes;
      for (const auto& suffix : suffixes) {
        for (int i = 0; i < values_shape.dim_size(dim); ++i) {
          new_suffixes.push_back(suffix);
          new_suffixes.back().push_back(i);
        }
      }
      suffixes.swap(new_suffixes);
    }
    return suffixes;
  }

  // Returns true if the ragged element at pos[dim] is "completed".  A ragged
  // element is completed if we have already generated indices for all of its
  // values.
  static bool IsCompleted(
      const std::vector<int64>& pos, int dim,
      const std::vector<TTypes<int64>::ConstFlat>& rt_nested_splits) {
    int64 current_child = pos[dim + 1];
    int64 limit_child = rt_nested_splits[dim](pos[dim] + 1);
    return current_child >= limit_child;
  }
};

REGISTER_KERNEL_BUILDER(Name("RaggedTensorToSparse").Device(DEVICE_CPU),
                        RaggedTensorToSparseOp);

}  // namespace tensorflow