• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_kernel.h"
17 #include "tensorflow/core/framework/register_types.h"
18 #include "tensorflow/core/framework/tensor.h"
19 #include "tensorflow/core/framework/tensor_util.h"
20 #include "tensorflow/core/framework/types.h"
21 
22 namespace tensorflow {
23 
24 template <typename T>
25 class SparseSliceGradOp : public OpKernel {
26  public:
SparseSliceGradOp(OpKernelConstruction * ctx)27   explicit SparseSliceGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}
28 
Compute(OpKernelContext * ctx)29   void Compute(OpKernelContext *ctx) override {
30     const Tensor *backprop_val_grad, *input_indices, *output_indices, *input_start;
31     OP_REQUIRES_OK(ctx, ctx->input("backprop_val_grad", &backprop_val_grad));
32     OP_REQUIRES_OK(ctx, ctx->input("input_indices", &input_indices));
33     OP_REQUIRES_OK(ctx, ctx->input("input_start", &input_start));
34     OP_REQUIRES_OK(ctx, ctx->input("output_indices", &output_indices));
35 
36     OP_REQUIRES(ctx,
37                 TensorShapeUtils::IsMatrix(input_indices->shape()) &&
38                     TensorShapeUtils::IsMatrix(output_indices->shape()),
39                 errors::InvalidArgument(
40                     "Input and output indices should be matrices "
41                     "but received shapes: ",
42                     input_indices->shape().DebugString(), " and ",
43                     output_indices->shape().DebugString()));
44     OP_REQUIRES(
45         ctx, TensorShapeUtils::IsVector(backprop_val_grad->shape()),
46         errors::InvalidArgument(
47             "Input backprop_val_grad should be a vector but received shape: ",
48             backprop_val_grad->shape().DebugString()));
49     OP_REQUIRES(
50         ctx,
51         input_indices->dim_size(1) == output_indices->dim_size(1),
52         errors::InvalidArgument("The input and output should have the same "
53                                 "ndims: got: ", input_indices->dim_size(1), " and ",
54                                 output_indices->dim_size(1)));
55     OP_REQUIRES(
56         ctx, output_indices->dim_size(0) <= input_indices->dim_size(0),
57         errors::InvalidArgument("# rows of output_indices should be not greater "
58                                 "than of input_indices, got ",
59                                 output_indices->dim_size(0), " and ",
60                                 input_indices->dim_size(0)));
61     OP_REQUIRES(
62         ctx, backprop_val_grad->NumElements() == output_indices->dim_size(0),
63         errors::InvalidArgument("# elements of backprop_val_grad and # rows of "
64                                 "output_indices should match (#nnz of sum): got ",
65                                 backprop_val_grad->NumElements(), " and ",
66                                 output_indices->dim_size(0)));
67     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_start->shape()),
68                 errors::InvalidArgument(
69                     "The input_start should be a vector but received shape ",
70                     input_start->shape().DebugString()));
71 
72     const int num_dims = input_indices->dim_size(1);
73     OP_REQUIRES(ctx, num_dims == input_start->NumElements(),
74                 errors::InvalidArgument(
75                     "Expected input_start to be a vector of length ", num_dims,
76                     " but got length ", input_start->NumElements()));
77 
78     const int64 input_nnz = input_indices->dim_size(0);
79 
80     Tensor *val_grad;
81     OP_REQUIRES_OK(ctx,
82                    ctx->allocate_output(0, TensorShape({input_nnz}), &val_grad));
83 
84     T *val_grad_flat = val_grad->flat<T>().data();
85     const T *backprop_val_grad_flat = backprop_val_grad->flat<T>().data();
86     memset(val_grad_flat, 0, sizeof(T) * input_nnz);
87 
88     // Fill gradients for position where indices of input and output are same.
89     const auto input_indices_mat = input_indices->matrix<int64>();
90     const auto output_indices_mat = output_indices->matrix<int64>();
91     const auto input_start_flat = input_start->flat<int64>();
92     int64 j = 0;
93     for (int64 i = 0; i < input_nnz && j < backprop_val_grad->NumElements();
94          ++i) {
95       bool is_same = true;
96       for (int d = 0; d < num_dims; ++d) {
97         const int64 a = input_indices_mat(i, d);
98         const int64 b = output_indices_mat(j, d);
99         const int64 offset = input_start_flat(d);
100         if (a != b + offset) {
101           is_same = false;
102           break;
103         }
104       }
105       if (is_same) {
106         val_grad_flat[i] = backprop_val_grad_flat[j];
107         ++j;
108       }
109     }
110     OP_REQUIRES(
111         ctx, backprop_val_grad->NumElements() == j,
112         errors::Internal("Elements of backprop_val_grad aren't all propagated. "
113                          "Num elements:", backprop_val_grad->NumElements(),
114                          ", used: ", j));
115   }
116 };
117 
118 #define REGISTER_KERNELS(type)                                              \
119   REGISTER_KERNEL_BUILDER(                                                  \
120       Name("SparseSliceGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
121       SparseSliceGradOp<type>)
122 
123 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
124 #undef REGISTER_KERNELS
125 }  // namespace tensorflow
126