/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

template <typename T>
class SparseSliceGradOp : public OpKernel {
 public:
  explicit SparseSliceGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext *ctx) override {
    const Tensor *backprop_val_grad, *input_indices, *output_indices,
        *input_start;
    OP_REQUIRES_OK(ctx, ctx->input("backprop_val_grad", &backprop_val_grad));
    OP_REQUIRES_OK(ctx, ctx->input("input_indices", &input_indices));
    OP_REQUIRES_OK(ctx, ctx->input("input_start", &input_start));
    OP_REQUIRES_OK(ctx, ctx->input("output_indices", &output_indices));

    OP_REQUIRES(ctx,
                TensorShapeUtils::IsMatrix(input_indices->shape()) &&
                    TensorShapeUtils::IsMatrix(output_indices->shape()),
                errors::InvalidArgument(
                    "Input and output indices should be matrices "
                    "but received shapes: ",
                    input_indices->shape().DebugString(), " and ",
                    output_indices->shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(backprop_val_grad->shape()),
        errors::InvalidArgument(
            "Input backprop_val_grad should be a vector but received shape: ",
            backprop_val_grad->shape().DebugString()));
    OP_REQUIRES(
        ctx, input_indices->dim_size(1) == output_indices->dim_size(1),
        errors::InvalidArgument("The input and output should have the same "
                                "ndims: got: ",
                                input_indices->dim_size(1), " and ",
                                output_indices->dim_size(1)));
    OP_REQUIRES(
        ctx, output_indices->dim_size(0) <= input_indices->dim_size(0),
        errors::InvalidArgument("# rows of output_indices should not be "
                                "greater than # rows of input_indices, got ",
                                output_indices->dim_size(0), " and ",
                                input_indices->dim_size(0)));
    OP_REQUIRES(
        ctx, backprop_val_grad->NumElements() == output_indices->dim_size(0),
        errors::InvalidArgument("# elements of backprop_val_grad and # rows "
                                "of output_indices should match (#nnz of the "
                                "slice): got ",
                                backprop_val_grad->NumElements(), " and ",
                                output_indices->dim_size(0)));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_start->shape()),
                errors::InvalidArgument(
                    "The input_start should be a vector but received shape ",
                    input_start->shape().DebugString()));

    const int num_dims = input_indices->dim_size(1);
    OP_REQUIRES(ctx, num_dims == input_start->NumElements(),
                errors::InvalidArgument(
                    "Expected input_start to be a vector of length ", num_dims,
                    " but got length ", input_start->NumElements()));

    const int64 input_nnz = input_indices->dim_size(0);

    Tensor *val_grad;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({input_nnz}), &val_grad));

    T *val_grad_flat = val_grad->flat<T>().data();
    const T *backprop_val_grad_flat = backprop_val_grad->flat<T>().data();
    // Nonzeros of the input that fall outside the slice receive a zero
    // gradient, so start from an all-zero output buffer.
    memset(val_grad_flat, 0, sizeof(T) * input_nnz);

    // Fill in gradients at positions where the input and output indices
    // match. Output indices are slice-relative, so output row j corresponds
    // to input row i iff input_indices(i, :) == output_indices(j, :) +
    // input_start. Both index matrices share the same lexicographic ordering
    // (SparseSlice preserves the input's ordering), so a single merge-style
    // forward pass over the two matrices suffices.
    const auto input_indices_mat = input_indices->matrix<int64>();
    const auto output_indices_mat = output_indices->matrix<int64>();
    const auto input_start_flat = input_start->flat<int64>();
    int64 j = 0;
    for (int64 i = 0; i < input_nnz && j < backprop_val_grad->NumElements();
         ++i) {
      bool is_same = true;
      for (int d = 0; d < num_dims; ++d) {
        const int64 a = input_indices_mat(i, d);
        const int64 b = output_indices_mat(j, d);
        const int64 offset = input_start_flat(d);
        if (a != b + offset) {
          is_same = false;
          break;
        }
      }
      if (is_same) {
        val_grad_flat[i] = backprop_val_grad_flat[j];
        ++j;
      }
    }
    // Every output gradient must have been scattered to some input position;
    // anything left over means the two index matrices were inconsistent.
    OP_REQUIRES(
        ctx, backprop_val_grad->NumElements() == j,
        errors::Internal("Elements of backprop_val_grad aren't all propagated. "
                         "Num elements: ",
                         backprop_val_grad->NumElements(), ", used: ", j));
  }
};

#define REGISTER_KERNELS(type)                                              \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("SparseSliceGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseSliceGradOp<type>)

TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

}  // namespace tensorflow
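
// -----------------------------------------------------------------------------
// Illustrative sketch only (not part of the kernel above, and excluded from
// the build): the same merge-and-scatter pass restated over plain std::vector
// data, with a worked example. All names in this block are hypothetical.
#if 0
#include <cstdint>
#include <vector>

// Scatters slice-relative output gradients back onto the input's nonzeros.
// Rows of `input_indices` and `output_indices` must share the same
// lexicographic ordering, mirroring the kernel's assumption.
std::vector<float> SparseSliceGradSketch(
    const std::vector<std::vector<int64_t>> &input_indices,
    const std::vector<std::vector<int64_t>> &output_indices,
    const std::vector<int64_t> &start,
    const std::vector<float> &backprop_val_grad) {
  std::vector<float> val_grad(input_indices.size(), 0.0f);
  std::size_t j = 0;
  for (std::size_t i = 0;
       i < input_indices.size() && j < backprop_val_grad.size(); ++i) {
    bool is_same = true;
    for (std::size_t d = 0; d < start.size(); ++d) {
      // Shift the slice-relative coordinate back into input coordinates.
      if (input_indices[i][d] != output_indices[j][d] + start[d]) {
        is_same = false;
        break;
      }
    }
    if (is_same) val_grad[i] = backprop_val_grad[j++];
  }
  return val_grad;
}

// Example: input nonzeros at {0,0}, {1,1}, {2,2}; a slice starting at {1,0}
// keeps {1,1} and {2,2} (slice-relative {0,1} and {1,2}). With upstream
// gradients {g0, g1}, the result is {0, g0, g1}.
#endif
// -----------------------------------------------------------------------------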