/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

template <typename T>
class SparseFillEmptyRowsOp : public OpKernel {
 public:
  explicit SparseFillEmptyRowsOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* indices_t;
    const Tensor* values_t;
    const Tensor* dense_shape_t;
    const Tensor* default_value_t;
    OP_REQUIRES_OK(context, context->input("indices", &indices_t));
    OP_REQUIRES_OK(context, context->input("values", &values_t));
    OP_REQUIRES_OK(context, context->input("dense_shape", &dense_shape_t));
    OP_REQUIRES_OK(context, context->input("default_value", &default_value_t));

    const CPUDevice& d = context->eigen_device<CPUDevice>();

    OP_REQUIRES(context, TensorShapeUtils::IsVector(dense_shape_t->shape()),
                errors::InvalidArgument("dense_shape must be a vector, saw: ",
                                        dense_shape_t->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices_t->shape()),
                errors::InvalidArgument("indices must be a matrix, saw: ",
                                        indices_t->shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(values_t->shape()),
                errors::InvalidArgument("values must be a vector, saw: ",
                                        values_t->shape().DebugString()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(default_value_t->shape()),
        errors::InvalidArgument("default_value must be a scalar, saw: ",
                                default_value_t->shape().DebugString()));
    // TODO(ebrevdo): add shape checks between values, indices,
    // dense_shape.  Also add check that dense rank > 0.
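
    // Summary of the steps below: count the entries in each dense row, turn
    // the counts into cumulative offsets, allocate outputs pre-filled with
    // zero indices and default_value, scatter the existing (index, value)
    // pairs into their slots (recording reverse_index_map for the gradient),
    // and finally write the row coordinate for each empty row's single
    // default entry.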

    const T& default_value = default_value_t->scalar<T>()();
    const auto indices = indices_t->matrix<int64>();
    const auto values = values_t->vec<T>();
    const auto dense_shape = dense_shape_t->vec<int64>();

    const int64 N = indices_t->shape().dim_size(0);
    const int64 dense_rows = dense_shape(0);

    Tensor* empty_row_indicator_t;
    OP_REQUIRES_OK(context, context->allocate_output("empty_row_indicator",
                                                     TensorShape({dense_rows}),
                                                     &empty_row_indicator_t));
    auto empty_row_indicator = empty_row_indicator_t->vec<bool>();
    Tensor* reverse_index_map_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("reverse_index_map",
                                            TensorShape({N}),
                                            &reverse_index_map_t));
    auto reverse_index_map = reverse_index_map_t->vec<int64>();

    int rank = indices_t->shape().dim_size(1);

    if (dense_rows == 0) {
      OP_REQUIRES(
          context, N == 0,
          errors::InvalidArgument("Received SparseTensor with dense_shape[0] = "
                                  "0 but indices.shape[0] = ",
                                  N));
      Tensor* output_indices_t;
      TensorShape output_indices_shape({0, rank});
      OP_REQUIRES_OK(context, context->allocate_output("output_indices",
                                                       output_indices_shape,
                                                       &output_indices_t));
      Tensor* output_values_t;
      OP_REQUIRES_OK(context,
                     context->allocate_output("output_values", TensorShape({0}),
                                              &output_values_t));

      // Exit early, nothing more to do.
      return;
    }

    Tensor scratch_t;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_INT64, TensorShape({dense_rows}),
                                          &scratch_t));
    auto scratch = scratch_t.vec<int64>();
    scratch.device(d) = scratch.constant(0);
    for (int i = 0; i < N; ++i) {
      const int64 row = indices(i, 0);
      OP_REQUIRES(context, row >= 0 && row < dense_rows,
                  errors::InvalidArgument("indices(", i, ", 0) is invalid: ",
                                          row, " is outside the valid range [0, ",
                                          dense_rows, ")"));
      ++scratch(row);
    }
    for (int row = 0; row < dense_rows; ++row) {
      // Scratch here describes the number of elements in this dense row.
      empty_row_indicator(row) = (scratch(row) == 0);
      // In the filled version, each row has at least one element.
      scratch(row) = std::max(scratch(row), 1LL);
      // Update scratch to represent the number of elements up to and
      // including dense_row + 1:
      //  scratch(0) == #{elements of row 0}
      //  scratch(1) == #{elements of row 1} + #{elements of row 0}
      //  ..
      //  scratch(i) == starting index for elements in row i + 1.
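      //
      // For example (hypothetical values, purely illustrative): with
      // dense_rows == 4 and input row indices [0, 0, 2], the per-row counts
      // are [2, 0, 1, 0]; clamping empty rows to 1 and accumulating gives
      // scratch == [2, 3, 4, 5], so the filled output has 5 entries and
      // row i's block starts at scratch(i - 1) (or at 0 for row 0).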
      if (row > 0) {
        scratch(row) += scratch(row - 1);
      }
    }
    Tensor* output_indices_t;
    const int64 N_full = scratch(dense_rows - 1);
    TensorShape output_indices_shape({N_full, rank});
    OP_REQUIRES_OK(context, context->allocate_output("output_indices",
                                                     output_indices_shape,
                                                     &output_indices_t));
    auto output_indices = output_indices_t->matrix<int64>();
    output_indices.device(d) = output_indices.constant(0);

    Tensor* output_values_t;
    OP_REQUIRES_OK(
        context, context->allocate_output(
                     "output_values", TensorShape({N_full}), &output_values_t));
    auto output_values = output_values_t->vec<T>();
    output_values.device(d) = output_values.constant(default_value);

    Tensor filled_count_t;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DT_INT64, TensorShape({dense_rows}),
                                          &filled_count_t));
    auto filled_count = filled_count_t.vec<int64>();
    filled_count.device(d) = filled_count.constant(0);

    // Fill in values for rows that are not missing.
    for (int64 i = 0; i < N; ++i) {
      const int64 row = indices(i, 0);
      int64& offset = filled_count(row);
      const int64 output_i = ((row == 0) ? 0 : scratch(row - 1)) + offset;
      offset++;  // Increment the filled count for this row.
      std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
      output_values(output_i) = values(i);
      // We'll need this reverse index map to backprop correctly.
      reverse_index_map(i) = output_i;
    }

    // Fill in values for rows that are missing.
    for (int64 row = 0; row < dense_rows; ++row) {
      const int64 row_count = filled_count(row);
      if (row_count == 0) {  // We haven't filled this row.
        const int64 starting_index = (row == 0) ? 0 : scratch(row - 1);
        // Remaining index values were set to zero already.
        // The value at this index was set to default_value already.
        // Just need to set the row index in the right location.
        output_indices(starting_index, 0) = row;
      }
    }
  }
};

#define REGISTER_KERNELS(type)                            \
  REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows")     \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T"), \
                          SparseFillEmptyRowsOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

template <typename T>
class SparseFillEmptyRowsGradOp : public OpKernel {
 public:
  explicit SparseFillEmptyRowsGradOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor* reverse_index_map_t;
    const Tensor* grad_values_t;
    OP_REQUIRES_OK(context,
                   context->input("reverse_index_map", &reverse_index_map_t));
    OP_REQUIRES_OK(context, context->input("grad_values", &grad_values_t));

    const CPUDevice& d = context->eigen_device<CPUDevice>();

    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
        errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
                                reverse_index_map_t->shape().DebugString()));

    const auto reverse_index_map = reverse_index_map_t->vec<int64>();
    const auto grad_values = grad_values_t->vec<T>();

    const int64 N = reverse_index_map_t->shape().dim_size(0);
    const int64 N_full = grad_values_t->shape().dim_size(0);

    Tensor* d_values_t;
    OP_REQUIRES_OK(context, context->allocate_output(
                                "d_values", TensorShape({N}), &d_values_t));
    auto d_values = d_values_t->vec<T>();
    Tensor* d_default_value_t;
    OP_REQUIRES_OK(context,
                   context->allocate_output("d_default_value", TensorShape({}),
                                            &d_default_value_t));
    T& d_default_value = d_default_value_t->scalar<T>()();
    d_default_value = T();

    Tensor visited_t;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DT_BOOL, TensorShape({N_full}), &visited_t));
    auto visited = visited_t.vec<bool>();
    visited.device(d) = visited.constant(false);

    for (int i = 0; i < N; ++i) {
      // Locate the index of the output of the forward prop associated
      // with this location in the input of the forward prop.  Copy
      // the gradient into it.  Mark it as visited.
      d_values(i) = grad_values(reverse_index_map(i));
      visited(reverse_index_map(i)) = true;
    }
    for (int j = 0; j < N_full; ++j) {
      // The default value gradient gets the accumulated remainder of
      // the backprop values (since the default value was used to fill
      // in these slots in the forward calculation).
      if (!visited(j)) {
        d_default_value += grad_values(j);
      }
    }
  }
};

#define REGISTER_KERNELS(type)                            \
  REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRowsGrad") \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T"), \
                          SparseFillEmptyRowsGradOp<type>)

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
}  // namespace tensorflow