1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include "tensorflow/core/kernels/sparse_fill_empty_rows_op.h" 19 20 #include <algorithm> 21 #include <numeric> 22 #include <unordered_map> 23 #include <utility> 24 #include <vector> 25 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/framework/register_types.h" 28 #include "tensorflow/core/framework/tensor.h" 29 #include "tensorflow/core/framework/tensor_util.h" 30 #include "tensorflow/core/framework/types.h" 31 #include "tensorflow/core/lib/gtl/inlined_vector.h" 32 #include "tensorflow/core/util/sparse/sparse_tensor.h" 33 34 namespace tensorflow { 35 36 using CPUDevice = Eigen::ThreadPoolDevice; 37 38 namespace functor { 39 40 template <typename T, typename Tindex> 41 struct SparseFillEmptyRows<CPUDevice, T, Tindex> { operator ()tensorflow::functor::SparseFillEmptyRows42 Status operator()(OpKernelContext* context, const Tensor& default_value_t, 43 const Tensor& indices_t, const Tensor& values_t, 44 const Tensor& dense_shape_t) { 45 const int kOutputIndicesOutput = 0; 46 const int kOutputValuesOutput = 1; 47 const int kEmptyRowIndicatorOutput = 2; 48 const int kReverseIndexMapOutput = 3; 49 50 const T& default_value = default_value_t.scalar<T>()(); 51 const auto indices = indices_t.matrix<Tindex>(); 52 const auto values = values_t.vec<T>(); 53 const auto dense_shape = dense_shape_t.vec<Tindex>(); 54 55 const Tindex N = indices_t.shape().dim_size(0); 56 const Tindex dense_rows = dense_shape(0); 57 58 bool* empty_row_indicator = nullptr; 59 if (context->output_required(kEmptyRowIndicatorOutput)) { 60 Tensor* empty_row_indicator_t = nullptr; 61 TF_RETURN_IF_ERROR(context->allocate_output(kEmptyRowIndicatorOutput, 62 TensorShape({dense_rows}), 63 &empty_row_indicator_t)); 64 empty_row_indicator = empty_row_indicator_t->vec<bool>().data(); 65 } 66 Tindex* reverse_index_map = nullptr; 67 if (context->output_required(kReverseIndexMapOutput)) { 68 Tensor* reverse_index_map_t = nullptr; 69 TF_RETURN_IF_ERROR(context->allocate_output( 70 kReverseIndexMapOutput, TensorShape({N}), &reverse_index_map_t)); 71 reverse_index_map = reverse_index_map_t->vec<Tindex>().data(); 72 } 73 74 int rank = indices_t.shape().dim_size(1); 75 76 if (dense_rows == 0) { 77 if (N != 0) { 78 return errors::InvalidArgument( 79 "Received SparseTensor with dense_shape[0] = 0 but " 80 "indices.shape[0] = ", 81 N); 82 } 83 Tensor* output_indices_t; 84 TensorShape output_indices_shape({0, rank}); 85 TF_RETURN_IF_ERROR(context->allocate_output( 86 kOutputIndicesOutput, output_indices_shape, &output_indices_t)); 87 Tensor* output_values_t; 88 TF_RETURN_IF_ERROR(context->allocate_output( 89 kOutputValuesOutput, TensorShape({0}), &output_values_t)); 90 91 // Exit early, nothing more to do. 92 return Status::OK(); 93 } 94 95 bool rows_are_ordered = true; 96 Tindex last_indices_row = 0; 97 std::vector<Tindex> csr_offset(dense_rows, 0); 98 for (int i = 0; i < N; ++i) { 99 const Tindex row = indices(i, 0); 100 if (row < 0 || row >= dense_rows) { 101 return errors::InvalidArgument("indices(", i, ", 0) is invalid: ", row, 102 " >= ", dense_rows); 103 } 104 ++csr_offset[row]; 105 rows_are_ordered = rows_are_ordered & (row >= last_indices_row); 106 last_indices_row = row; 107 } 108 bool all_rows_full = true; 109 for (int row = 0; row < dense_rows; ++row) { 110 // csr_offset here describes the number of elements in this dense row 111 bool row_empty = (csr_offset[row] == 0); 112 if (empty_row_indicator) { 113 empty_row_indicator[row] = row_empty; 114 } 115 all_rows_full = all_rows_full & !row_empty; 116 // In filled version, each row has at least one element. 117 csr_offset[row] = std::max(csr_offset[row], Tindex{1}); 118 // Update csr_offset to represent the number of elements up to and 119 // including dense_row + 1: 120 // csr_offset(0) == #{elements of row 0} 121 // csr_offset(1) == #{elements of row 1} + #{elements of row 0} 122 // .. 123 // csr_offset(i) == starting index for elements in row i + 1. 124 if (row > 0) { 125 csr_offset[row] += csr_offset[row - 1]; 126 } 127 } 128 129 if (all_rows_full && rows_are_ordered) { 130 context->set_output(kOutputIndicesOutput, indices_t); 131 context->set_output(kOutputValuesOutput, values_t); 132 if (reverse_index_map) { 133 for (Tindex i = 0; i < N; ++i) { 134 reverse_index_map[i] = i; 135 } 136 } 137 } else { 138 Tensor* output_indices_t; 139 const Tindex N_full = csr_offset[dense_rows - 1]; 140 TensorShape output_indices_shape({N_full, rank}); 141 TF_RETURN_IF_ERROR(context->allocate_output( 142 kOutputIndicesOutput, output_indices_shape, &output_indices_t)); 143 auto output_indices = output_indices_t->matrix<Tindex>(); 144 145 Tensor* output_values_t; 146 TF_RETURN_IF_ERROR(context->allocate_output( 147 kOutputValuesOutput, TensorShape({N_full}), &output_values_t)); 148 auto output_values = output_values_t->vec<T>(); 149 150 std::vector<Tindex> filled_count(dense_rows, 0); 151 152 // Fill in values for rows that are not missing 153 for (Tindex i = 0; i < N; ++i) { 154 const Tindex row = indices(i, 0); 155 Tindex& offset = filled_count[row]; 156 const Tindex output_i = ((row == 0) ? 0 : csr_offset[row - 1]) + offset; 157 offset++; // Increment the filled count for this row. 158 std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0)); 159 output_values(output_i) = values(i); 160 // We'll need this reverse index map to backprop correctly. 161 if (reverse_index_map) { 162 reverse_index_map[i] = output_i; 163 } 164 } 165 166 // Fill in values for rows that are missing 167 for (Tindex row = 0; row < dense_rows; ++row) { 168 const Tindex row_count = filled_count[row]; 169 if (row_count == 0) { // We haven't filled this row 170 const Tindex starting_index = (row == 0) ? 0 : csr_offset[row - 1]; 171 // Remaining index values were set to zero already. 172 // Just need to set the row index in the right location. 173 output_indices(starting_index, 0) = row; 174 for (Tindex col = 1; col < rank; ++col) { 175 output_indices(starting_index, col) = 0; 176 } 177 output_values(starting_index) = default_value; 178 } 179 } 180 } 181 182 return Status::OK(); 183 } 184 }; 185 186 } // namespace functor 187 188 template <typename Device, typename T, typename Tindex> 189 class SparseFillEmptyRowsOp : public OpKernel { 190 public: SparseFillEmptyRowsOp(OpKernelConstruction * context)191 explicit SparseFillEmptyRowsOp(OpKernelConstruction* context) 192 : OpKernel(context) {} 193 Compute(OpKernelContext * context)194 void Compute(OpKernelContext* context) override { 195 const int kIndicesInput = 0; 196 const int kValuesInput = 1; 197 const int kDenseShapeInput = 2; 198 const int kDefaultValueInput = 3; 199 200 const Tensor& indices_t = context->input(kIndicesInput); 201 const Tensor& values_t = context->input(kValuesInput); 202 const Tensor& dense_shape_t = context->input(kDenseShapeInput); 203 const Tensor& default_value_t = context->input(kDefaultValueInput); 204 205 OP_REQUIRES(context, TensorShapeUtils::IsVector(dense_shape_t.shape()), 206 errors::InvalidArgument("dense_shape must be a vector, saw: ", 207 dense_shape_t.shape().DebugString())); 208 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices_t.shape()), 209 errors::InvalidArgument("indices must be a matrix, saw: ", 210 indices_t.shape().DebugString())); 211 OP_REQUIRES(context, TensorShapeUtils::IsVector(values_t.shape()), 212 errors::InvalidArgument("values must be a vector, saw: ", 213 values_t.shape().DebugString())); 214 OP_REQUIRES(context, TensorShapeUtils::IsScalar(default_value_t.shape()), 215 errors::InvalidArgument("default_value must be a scalar, saw: ", 216 default_value_t.shape().DebugString())); 217 // TODO(ebrevdo): add shape checks between values, indices, 218 // dense_shape. Also add check that dense rank > 0. 219 220 OP_REQUIRES_OK(context, functor::SparseFillEmptyRows<Device, T, Tindex>()( 221 context, default_value_t, indices_t, values_t, 222 dense_shape_t)); 223 } 224 }; 225 226 #define REGISTER_KERNELS(D, T, Tindex) \ 227 REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows") \ 228 .Device(DEVICE_##D) \ 229 .HostMemory("dense_shape") \ 230 .TypeConstraint<T>("T"), \ 231 SparseFillEmptyRowsOp<D##Device, T, Tindex>) 232 233 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T, int64) 234 TF_CALL_ALL_TYPES(REGISTER_CPU_KERNELS); 235 #undef REGISTER_CPU_KERNELS 236 237 #undef REGISTER_KERNELS 238 239 namespace functor { 240 241 template <typename T, typename Tindex> 242 struct SparseFillEmptyRowsGrad<CPUDevice, T, Tindex> { operator ()tensorflow::functor::SparseFillEmptyRowsGrad243 Status operator()(OpKernelContext* context, 244 typename TTypes<Tindex>::ConstVec reverse_index_map, 245 typename TTypes<T>::ConstVec grad_values, 246 typename TTypes<T>::Vec d_values, 247 typename TTypes<T>::Scalar d_default_value) { 248 const CPUDevice& device = context->eigen_device<CPUDevice>(); 249 const Tindex N = reverse_index_map.dimension(0); 250 const Tindex N_full = grad_values.dimension(0); 251 252 T& d_default_value_scalar = d_default_value(); 253 d_default_value_scalar = T(); 254 255 Tensor visited_t; 256 TF_RETURN_IF_ERROR( 257 context->allocate_temp(DT_BOOL, TensorShape({N_full}), &visited_t)); 258 auto visited = visited_t.vec<bool>(); 259 visited.device(device) = visited.constant(false); 260 261 for (int i = 0; i < N; ++i) { 262 // Locate the index of the output of the forward prop associated 263 // with this location in the input of the forward prop. Copy 264 // the gradient into it. Mark it as visited. 265 int64 reverse_index = reverse_index_map(i); 266 if (reverse_index < 0 || reverse_index >= N_full) { 267 return errors::InvalidArgument( 268 "Elements in reverse index must be in [0, ", N_full, ") but got ", 269 reverse_index); 270 } 271 d_values(i) = grad_values(reverse_index); 272 visited(reverse_index) = true; 273 } 274 for (int j = 0; j < N_full; ++j) { 275 // The default value gradient gets the accumulated remainder of 276 // the backprop values (since the default value was used to fill 277 // in these slots in the forward calculation). 278 if (!visited(j)) { 279 d_default_value_scalar += grad_values(j); 280 } 281 } 282 return Status::OK(); 283 } 284 }; 285 286 } // namespace functor 287 288 template <typename Device, typename T, typename Tindex> 289 class SparseFillEmptyRowsGradOp : public OpKernel { 290 public: SparseFillEmptyRowsGradOp(OpKernelConstruction * context)291 explicit SparseFillEmptyRowsGradOp(OpKernelConstruction* context) 292 : OpKernel(context) {} 293 Compute(OpKernelContext * context)294 void Compute(OpKernelContext* context) override { 295 const Tensor* reverse_index_map_t; 296 const Tensor* grad_values_t; 297 OP_REQUIRES_OK(context, 298 context->input("reverse_index_map", &reverse_index_map_t)); 299 OP_REQUIRES_OK(context, context->input("grad_values", &grad_values_t)); 300 301 OP_REQUIRES( 302 context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()), 303 errors::InvalidArgument("reverse_index_map must be a vector, saw: ", 304 reverse_index_map_t->shape().DebugString())); 305 OP_REQUIRES(context, TensorShapeUtils::IsVector(grad_values_t->shape()), 306 errors::InvalidArgument("grad_values must be a vector, saw: ", 307 grad_values_t->shape().DebugString())); 308 309 const auto reverse_index_map = reverse_index_map_t->vec<Tindex>(); 310 const auto grad_values = grad_values_t->vec<T>(); 311 312 const Tindex N = reverse_index_map_t->shape().dim_size(0); 313 314 Tensor* d_values_t; 315 OP_REQUIRES_OK(context, context->allocate_output( 316 "d_values", TensorShape({N}), &d_values_t)); 317 auto d_values = d_values_t->vec<T>(); 318 Tensor* d_default_value_t; 319 OP_REQUIRES_OK(context, 320 context->allocate_output("d_default_value", TensorShape({}), 321 &d_default_value_t)); 322 auto d_default_value = d_default_value_t->scalar<T>(); 323 324 OP_REQUIRES_OK(context, 325 functor::SparseFillEmptyRowsGrad<Device, T, Tindex>()( 326 context, reverse_index_map, grad_values, d_values, 327 d_default_value)); 328 } 329 }; 330 331 #define REGISTER_KERNELS(D, T, Tindex) \ 332 REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRowsGrad") \ 333 .Device(DEVICE_##D) \ 334 .TypeConstraint<T>("T"), \ 335 SparseFillEmptyRowsGradOp<D##Device, T, Tindex>) 336 337 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T, int64) 338 TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS); 339 #undef REGISTER_CPU_KERNELS 340 341 #undef REGISTER_KERNELS 342 } // namespace tensorflow 343