• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include <algorithm>
19 #include <numeric>
20 #include <unordered_map>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_util.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/gtl/inlined_vector.h"
30 #include "tensorflow/core/util/sparse/sparse_tensor.h"
31 
32 namespace tensorflow {
33 
34 using CPUDevice = Eigen::ThreadPoolDevice;
35 
36 template <typename T>
37 class SparseFillEmptyRowsOp : public OpKernel {
38  public:
SparseFillEmptyRowsOp(OpKernelConstruction * context)39   explicit SparseFillEmptyRowsOp(OpKernelConstruction* context)
40       : OpKernel(context) {}
41 
Compute(OpKernelContext * context)42   void Compute(OpKernelContext* context) override {
43     const Tensor* indices_t;
44     const Tensor* values_t;
45     const Tensor* dense_shape_t;
46     const Tensor* default_value_t;
47     OP_REQUIRES_OK(context, context->input("indices", &indices_t));
48     OP_REQUIRES_OK(context, context->input("values", &values_t));
49     OP_REQUIRES_OK(context, context->input("dense_shape", &dense_shape_t));
50     OP_REQUIRES_OK(context, context->input("default_value", &default_value_t));
51 
52     const CPUDevice& d = context->eigen_device<CPUDevice>();
53 
54     OP_REQUIRES(context, TensorShapeUtils::IsVector(dense_shape_t->shape()),
55                 errors::InvalidArgument("dense_shape must be a vector, saw: ",
56                                         dense_shape_t->shape().DebugString()));
57     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices_t->shape()),
58                 errors::InvalidArgument("indices must be a matrix, saw: ",
59                                         indices_t->shape().DebugString()));
60     OP_REQUIRES(context, TensorShapeUtils::IsVector(values_t->shape()),
61                 errors::InvalidArgument("values must be a vector, saw: ",
62                                         values_t->shape().DebugString()));
63     OP_REQUIRES(
64         context, TensorShapeUtils::IsScalar(default_value_t->shape()),
65         errors::InvalidArgument("default_value must be a scalar, saw: ",
66                                 default_value_t->shape().DebugString()));
67     // TODO(ebrevdo): add shape checks between values, indices,
68     // dense_shape.  Also add check that dense rank > 0.
69 
70     const T& default_value = default_value_t->scalar<T>()();
71     const auto indices = indices_t->matrix<int64>();
72     const auto values = values_t->vec<T>();
73     const auto dense_shape = dense_shape_t->vec<int64>();
74 
75     const int64 N = indices_t->shape().dim_size(0);
76     const int64 dense_rows = dense_shape(0);
77 
78     Tensor* empty_row_indicator_t;
79     OP_REQUIRES_OK(context, context->allocate_output("empty_row_indicator",
80                                                      TensorShape({dense_rows}),
81                                                      &empty_row_indicator_t));
82     auto empty_row_indicator = empty_row_indicator_t->vec<bool>();
83     Tensor* reverse_index_map_t;
84     OP_REQUIRES_OK(
85         context, context->allocate_output("reverse_index_map", TensorShape({N}),
86                                           &reverse_index_map_t));
87     auto reverse_index_map = reverse_index_map_t->vec<int64>();
88 
89     int rank = indices_t->shape().dim_size(1);
90 
91     if (dense_rows == 0) {
92       OP_REQUIRES(
93           context, N == 0,
94           errors::InvalidArgument("Received SparseTensor with dense_shape[0] = "
95                                   "0 but indices.shape[0] = ",
96                                   N));
97       Tensor* output_indices_t;
98       TensorShape output_indices_shape({0, rank});
99       OP_REQUIRES_OK(context, context->allocate_output("output_indices",
100                                                        output_indices_shape,
101                                                        &output_indices_t));
102       Tensor* output_values_t;
103       OP_REQUIRES_OK(context,
104                      context->allocate_output("output_values", TensorShape({0}),
105                                               &output_values_t));
106 
107       // Exit early, nothing more to do.
108       return;
109     }
110 
111     Tensor scratch_t;
112     OP_REQUIRES_OK(context,
113                    context->allocate_temp(DT_INT64, TensorShape({dense_rows}),
114                                           &scratch_t));
115     auto scratch = scratch_t.vec<int64>();
116     scratch.device(d) = scratch.constant(0);
117     for (int i = 0; i < N; ++i) {
118       const int64 row = indices(i, 0);
119       OP_REQUIRES(context, row >= 0 && row < dense_rows,
120                   errors::InvalidArgument("indices(", i, ", 0) is invalid: ",
121                                           row, " >= ", dense_rows));
122       ++scratch(indices(i, 0));
123     }
124     for (int row = 0; row < dense_rows; ++row) {
125       // Scratch here describes the number of elements in this dense row
126       empty_row_indicator(row) = (scratch(row) == 0);
127       // In filled version, each row has at least one element.
128       scratch(row) = std::max(scratch(row), 1LL);
129       // Update scratch to represent the number of elements up to and
130       // including dense_row + 1:
131       //  scratch(0) == #{elements of row 0}
132       //  scratch(1) == #{elements of row 1} + #{elements of row 0}
133       //  ..
134       //  scratch(i) == starting index for elements in row i + 1.
135       if (row > 0) {
136         scratch(row) += scratch(row - 1);
137       }
138     }
139     Tensor* output_indices_t;
140     const int64 N_full = scratch(dense_rows - 1);
141     TensorShape output_indices_shape({N_full, rank});
142     OP_REQUIRES_OK(context, context->allocate_output("output_indices",
143                                                      output_indices_shape,
144                                                      &output_indices_t));
145     auto output_indices = output_indices_t->matrix<int64>();
146     output_indices.device(d) = output_indices.constant(0);
147 
148     Tensor* output_values_t;
149     OP_REQUIRES_OK(
150         context, context->allocate_output(
151                      "output_values", TensorShape({N_full}), &output_values_t));
152     auto output_values = output_values_t->vec<T>();
153     output_values.device(d) = output_values.constant(default_value);
154 
155     Tensor filled_count_t;
156     OP_REQUIRES_OK(context,
157                    context->allocate_temp(DT_INT64, TensorShape({dense_rows}),
158                                           &filled_count_t));
159     auto filled_count = filled_count_t.vec<int64>();
160     filled_count.device(d) = filled_count.constant(0);
161 
162     // Fill in values for rows that are not missing
163     for (int64 i = 0; i < N; ++i) {
164       const int64 row = indices(i, 0);
165       int64& offset = filled_count(row);
166       const int64 output_i = ((row == 0) ? 0 : scratch(row - 1)) + offset;
167       offset++;  // Increment the filled count for this row.
168       std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
169       output_values(output_i) = values(i);
170       // We'll need this reverse index map to backprop correctly.
171       reverse_index_map(i) = output_i;
172     }
173 
174     // Fill in values for rows that are missing
175     for (int64 row = 0; row < dense_rows; ++row) {
176       const int64 row_count = filled_count(row);
177       if (row_count == 0) {  // We haven't filled this row
178         const int64 starting_index = (row == 0) ? 0 : scratch(row - 1);
179         // Remaining index values were set to zero already.
180         // The value at this index was set to default_value already.
181         // Just need to set the row index in the right location.
182         output_indices(starting_index, 0) = row;
183       }
184     }
185   }
186 };
187 
188 #define REGISTER_KERNELS(type)                            \
189   REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows")     \
190                               .Device(DEVICE_CPU)         \
191                               .TypeConstraint<type>("T"), \
192                           SparseFillEmptyRowsOp<type>)
193 
194 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
195 #undef REGISTER_KERNELS
196 
197 template <typename T>
198 class SparseFillEmptyRowsGradOp : public OpKernel {
199  public:
SparseFillEmptyRowsGradOp(OpKernelConstruction * context)200   explicit SparseFillEmptyRowsGradOp(OpKernelConstruction* context)
201       : OpKernel(context) {}
202 
Compute(OpKernelContext * context)203   void Compute(OpKernelContext* context) override {
204     const Tensor* reverse_index_map_t;
205     const Tensor* grad_values_t;
206     OP_REQUIRES_OK(context,
207                    context->input("reverse_index_map", &reverse_index_map_t));
208     OP_REQUIRES_OK(context, context->input("grad_values", &grad_values_t));
209 
210     const CPUDevice& d = context->eigen_device<CPUDevice>();
211 
212     OP_REQUIRES(
213         context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
214         errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
215                                 reverse_index_map_t->shape().DebugString()));
216 
217     const auto reverse_index_map = reverse_index_map_t->vec<int64>();
218     const auto grad_values = grad_values_t->vec<T>();
219 
220     const int64 N = reverse_index_map_t->shape().dim_size(0);
221     const int64 N_full = grad_values_t->shape().dim_size(0);
222 
223     Tensor* d_values_t;
224     OP_REQUIRES_OK(context, context->allocate_output(
225                                 "d_values", TensorShape({N}), &d_values_t));
226     auto d_values = d_values_t->vec<T>();
227     Tensor* d_default_value_t;
228     OP_REQUIRES_OK(context,
229                    context->allocate_output("d_default_value", TensorShape({}),
230                                             &d_default_value_t));
231     T& d_default_value = d_default_value_t->scalar<T>()();
232     d_default_value = T();
233 
234     Tensor visited_t;
235     OP_REQUIRES_OK(context, context->allocate_temp(
236                                 DT_BOOL, TensorShape({N_full}), &visited_t));
237     auto visited = visited_t.vec<bool>();
238     visited.device(d) = visited.constant(false);
239 
240     for (int i = 0; i < N; ++i) {
241       // Locate the index of the output of the forward prop associated
242       // with this location in the input of the forward prop.  Copy
243       // the gradient into it.  Mark it as visited.
244       d_values(i) = grad_values(reverse_index_map(i));
245       visited(reverse_index_map(i)) = true;
246     }
247     for (int j = 0; j < N_full; ++j) {
248       // The default value gradient gets the accumulated remainder of
249       // the backprop values (since the default value was used to fill
250       // in these slots in the forward calculation).
251       if (!visited(j)) {
252         d_default_value += grad_values(j);
253       }
254     }
255   }
256 };
257 
258 #define REGISTER_KERNELS(type)                            \
259   REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRowsGrad") \
260                               .Device(DEVICE_CPU)         \
261                               .TypeConstraint<type>("T"), \
262                           SparseFillEmptyRowsGradOp<type>)
263 
264 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
265 #undef REGISTER_KERNELS
266 }  // namespace tensorflow
267