• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/kernels/sparse_fill_empty_rows_op.h"
19 
20 #include <algorithm>
21 #include <numeric>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/register_types.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_util.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/util/sparse/sparse_tensor.h"
33 
34 namespace tensorflow {
35 
36 using CPUDevice = Eigen::ThreadPoolDevice;
37 
38 namespace functor {
39 
40 template <typename T, typename Tindex>
41 struct SparseFillEmptyRows<CPUDevice, T, Tindex> {
operator ()tensorflow::functor::SparseFillEmptyRows42   Status operator()(OpKernelContext* context, const Tensor& default_value_t,
43                     const Tensor& indices_t, const Tensor& values_t,
44                     const Tensor& dense_shape_t) {
45     const int kOutputIndicesOutput = 0;
46     const int kOutputValuesOutput = 1;
47     const int kEmptyRowIndicatorOutput = 2;
48     const int kReverseIndexMapOutput = 3;
49 
50     const T& default_value = default_value_t.scalar<T>()();
51     const auto indices = indices_t.matrix<Tindex>();
52     const auto values = values_t.vec<T>();
53     const auto dense_shape = dense_shape_t.vec<Tindex>();
54 
55     const Tindex N = indices_t.shape().dim_size(0);
56     const Tindex dense_rows = dense_shape(0);
57 
58     bool* empty_row_indicator = nullptr;
59     if (context->output_required(kEmptyRowIndicatorOutput)) {
60       Tensor* empty_row_indicator_t = nullptr;
61       TF_RETURN_IF_ERROR(context->allocate_output(kEmptyRowIndicatorOutput,
62                                                   TensorShape({dense_rows}),
63                                                   &empty_row_indicator_t));
64       empty_row_indicator = empty_row_indicator_t->vec<bool>().data();
65     }
66     Tindex* reverse_index_map = nullptr;
67     if (context->output_required(kReverseIndexMapOutput)) {
68       Tensor* reverse_index_map_t = nullptr;
69       TF_RETURN_IF_ERROR(context->allocate_output(
70           kReverseIndexMapOutput, TensorShape({N}), &reverse_index_map_t));
71       reverse_index_map = reverse_index_map_t->vec<Tindex>().data();
72     }
73 
74     int rank = indices_t.shape().dim_size(1);
75 
76     if (dense_rows == 0) {
77       if (N != 0) {
78         return errors::InvalidArgument(
79             "Received SparseTensor with dense_shape[0] = 0 but "
80             "indices.shape[0] = ",
81             N);
82       }
83       Tensor* output_indices_t;
84       TensorShape output_indices_shape({0, rank});
85       TF_RETURN_IF_ERROR(context->allocate_output(
86           kOutputIndicesOutput, output_indices_shape, &output_indices_t));
87       Tensor* output_values_t;
88       TF_RETURN_IF_ERROR(context->allocate_output(
89           kOutputValuesOutput, TensorShape({0}), &output_values_t));
90 
91       // Exit early, nothing more to do.
92       return Status::OK();
93     }
94 
95     bool rows_are_ordered = true;
96     Tindex last_indices_row = 0;
97     std::vector<Tindex> csr_offset(dense_rows, 0);
98     for (int i = 0; i < N; ++i) {
99       const Tindex row = indices(i, 0);
100       if (row < 0 || row >= dense_rows) {
101         return errors::InvalidArgument("indices(", i, ", 0) is invalid: ", row,
102                                        " >= ", dense_rows);
103       }
104       ++csr_offset[row];
105       rows_are_ordered = rows_are_ordered & (row >= last_indices_row);
106       last_indices_row = row;
107     }
108     bool all_rows_full = true;
109     for (int row = 0; row < dense_rows; ++row) {
110       // csr_offset here describes the number of elements in this dense row
111       bool row_empty = (csr_offset[row] == 0);
112       if (empty_row_indicator) {
113         empty_row_indicator[row] = row_empty;
114       }
115       all_rows_full = all_rows_full & !row_empty;
116       // In filled version, each row has at least one element.
117       csr_offset[row] = std::max(csr_offset[row], Tindex{1});
118       // Update csr_offset to represent the number of elements up to and
119       // including dense_row + 1:
120       //  csr_offset(0) == #{elements of row 0}
121       //  csr_offset(1) == #{elements of row 1} + #{elements of row 0}
122       //  ..
123       //  csr_offset(i) == starting index for elements in row i + 1.
124       if (row > 0) {
125         csr_offset[row] += csr_offset[row - 1];
126       }
127     }
128 
129     if (all_rows_full && rows_are_ordered) {
130       context->set_output(kOutputIndicesOutput, indices_t);
131       context->set_output(kOutputValuesOutput, values_t);
132       if (reverse_index_map) {
133         for (Tindex i = 0; i < N; ++i) {
134           reverse_index_map[i] = i;
135         }
136       }
137     } else {
138       Tensor* output_indices_t;
139       const Tindex N_full = csr_offset[dense_rows - 1];
140       TensorShape output_indices_shape({N_full, rank});
141       TF_RETURN_IF_ERROR(context->allocate_output(
142           kOutputIndicesOutput, output_indices_shape, &output_indices_t));
143       auto output_indices = output_indices_t->matrix<Tindex>();
144 
145       Tensor* output_values_t;
146       TF_RETURN_IF_ERROR(context->allocate_output(
147           kOutputValuesOutput, TensorShape({N_full}), &output_values_t));
148       auto output_values = output_values_t->vec<T>();
149 
150       std::vector<Tindex> filled_count(dense_rows, 0);
151 
152       // Fill in values for rows that are not missing
153       for (Tindex i = 0; i < N; ++i) {
154         const Tindex row = indices(i, 0);
155         Tindex& offset = filled_count[row];
156         const Tindex output_i = ((row == 0) ? 0 : csr_offset[row - 1]) + offset;
157         offset++;  // Increment the filled count for this row.
158         std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
159         output_values(output_i) = values(i);
160         // We'll need this reverse index map to backprop correctly.
161         if (reverse_index_map) {
162           reverse_index_map[i] = output_i;
163         }
164       }
165 
166       // Fill in values for rows that are missing
167       for (Tindex row = 0; row < dense_rows; ++row) {
168         const Tindex row_count = filled_count[row];
169         if (row_count == 0) {  // We haven't filled this row
170           const Tindex starting_index = (row == 0) ? 0 : csr_offset[row - 1];
171           // Remaining index values were set to zero already.
172           // Just need to set the row index in the right location.
173           output_indices(starting_index, 0) = row;
174           for (Tindex col = 1; col < rank; ++col) {
175             output_indices(starting_index, col) = 0;
176           }
177           output_values(starting_index) = default_value;
178         }
179       }
180     }
181 
182     return Status::OK();
183   }
184 };
185 
186 }  // namespace functor
187 
188 template <typename Device, typename T, typename Tindex>
189 class SparseFillEmptyRowsOp : public OpKernel {
190  public:
SparseFillEmptyRowsOp(OpKernelConstruction * context)191   explicit SparseFillEmptyRowsOp(OpKernelConstruction* context)
192       : OpKernel(context) {}
193 
Compute(OpKernelContext * context)194   void Compute(OpKernelContext* context) override {
195     const int kIndicesInput = 0;
196     const int kValuesInput = 1;
197     const int kDenseShapeInput = 2;
198     const int kDefaultValueInput = 3;
199 
200     const Tensor& indices_t = context->input(kIndicesInput);
201     const Tensor& values_t = context->input(kValuesInput);
202     const Tensor& dense_shape_t = context->input(kDenseShapeInput);
203     const Tensor& default_value_t = context->input(kDefaultValueInput);
204 
205     OP_REQUIRES(context, TensorShapeUtils::IsVector(dense_shape_t.shape()),
206                 errors::InvalidArgument("dense_shape must be a vector, saw: ",
207                                         dense_shape_t.shape().DebugString()));
208     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices_t.shape()),
209                 errors::InvalidArgument("indices must be a matrix, saw: ",
210                                         indices_t.shape().DebugString()));
211     OP_REQUIRES(context, TensorShapeUtils::IsVector(values_t.shape()),
212                 errors::InvalidArgument("values must be a vector, saw: ",
213                                         values_t.shape().DebugString()));
214     OP_REQUIRES(context, TensorShapeUtils::IsScalar(default_value_t.shape()),
215                 errors::InvalidArgument("default_value must be a scalar, saw: ",
216                                         default_value_t.shape().DebugString()));
217     // TODO(ebrevdo): add shape checks between values, indices,
218     // dense_shape.  Also add check that dense rank > 0.
219 
220     OP_REQUIRES_OK(context, functor::SparseFillEmptyRows<Device, T, Tindex>()(
221                                 context, default_value_t, indices_t, values_t,
222                                 dense_shape_t));
223   }
224 };
225 
226 #define REGISTER_KERNELS(D, T, Tindex)                   \
227   REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRows")    \
228                               .Device(DEVICE_##D)        \
229                               .HostMemory("dense_shape") \
230                               .TypeConstraint<T>("T"),   \
231                           SparseFillEmptyRowsOp<D##Device, T, Tindex>)
232 
233 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T, int64)
234 TF_CALL_ALL_TYPES(REGISTER_CPU_KERNELS);
235 #undef REGISTER_CPU_KERNELS
236 
237 #undef REGISTER_KERNELS
238 
239 namespace functor {
240 
241 template <typename T, typename Tindex>
242 struct SparseFillEmptyRowsGrad<CPUDevice, T, Tindex> {
operator ()tensorflow::functor::SparseFillEmptyRowsGrad243   Status operator()(OpKernelContext* context,
244                     typename TTypes<Tindex>::ConstVec reverse_index_map,
245                     typename TTypes<T>::ConstVec grad_values,
246                     typename TTypes<T>::Vec d_values,
247                     typename TTypes<T>::Scalar d_default_value) {
248     const CPUDevice& device = context->eigen_device<CPUDevice>();
249     const Tindex N = reverse_index_map.dimension(0);
250     const Tindex N_full = grad_values.dimension(0);
251 
252     T& d_default_value_scalar = d_default_value();
253     d_default_value_scalar = T();
254 
255     Tensor visited_t;
256     TF_RETURN_IF_ERROR(
257         context->allocate_temp(DT_BOOL, TensorShape({N_full}), &visited_t));
258     auto visited = visited_t.vec<bool>();
259     visited.device(device) = visited.constant(false);
260 
261     for (int i = 0; i < N; ++i) {
262       // Locate the index of the output of the forward prop associated
263       // with this location in the input of the forward prop.  Copy
264       // the gradient into it.  Mark it as visited.
265       int64 reverse_index = reverse_index_map(i);
266       if (reverse_index < 0 || reverse_index >= N_full) {
267         return errors::InvalidArgument(
268             "Elements in reverse index must be in [0, ", N_full, ") but got ",
269             reverse_index);
270       }
271       d_values(i) = grad_values(reverse_index);
272       visited(reverse_index) = true;
273     }
274     for (int j = 0; j < N_full; ++j) {
275       // The default value gradient gets the accumulated remainder of
276       // the backprop values (since the default value was used to fill
277       // in these slots in the forward calculation).
278       if (!visited(j)) {
279         d_default_value_scalar += grad_values(j);
280       }
281     }
282     return Status::OK();
283   }
284 };
285 
286 }  // namespace functor
287 
288 template <typename Device, typename T, typename Tindex>
289 class SparseFillEmptyRowsGradOp : public OpKernel {
290  public:
SparseFillEmptyRowsGradOp(OpKernelConstruction * context)291   explicit SparseFillEmptyRowsGradOp(OpKernelConstruction* context)
292       : OpKernel(context) {}
293 
Compute(OpKernelContext * context)294   void Compute(OpKernelContext* context) override {
295     const Tensor* reverse_index_map_t;
296     const Tensor* grad_values_t;
297     OP_REQUIRES_OK(context,
298                    context->input("reverse_index_map", &reverse_index_map_t));
299     OP_REQUIRES_OK(context, context->input("grad_values", &grad_values_t));
300 
301     OP_REQUIRES(
302         context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
303         errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
304                                 reverse_index_map_t->shape().DebugString()));
305     OP_REQUIRES(context, TensorShapeUtils::IsVector(grad_values_t->shape()),
306                 errors::InvalidArgument("grad_values must be a vector, saw: ",
307                                         grad_values_t->shape().DebugString()));
308 
309     const auto reverse_index_map = reverse_index_map_t->vec<Tindex>();
310     const auto grad_values = grad_values_t->vec<T>();
311 
312     const Tindex N = reverse_index_map_t->shape().dim_size(0);
313 
314     Tensor* d_values_t;
315     OP_REQUIRES_OK(context, context->allocate_output(
316                                 "d_values", TensorShape({N}), &d_values_t));
317     auto d_values = d_values_t->vec<T>();
318     Tensor* d_default_value_t;
319     OP_REQUIRES_OK(context,
320                    context->allocate_output("d_default_value", TensorShape({}),
321                                             &d_default_value_t));
322     auto d_default_value = d_default_value_t->scalar<T>();
323 
324     OP_REQUIRES_OK(context,
325                    functor::SparseFillEmptyRowsGrad<Device, T, Tindex>()(
326                        context, reverse_index_map, grad_values, d_values,
327                        d_default_value));
328   }
329 };
330 
331 #define REGISTER_KERNELS(D, T, Tindex)                    \
332   REGISTER_KERNEL_BUILDER(Name("SparseFillEmptyRowsGrad") \
333                               .Device(DEVICE_##D)         \
334                               .TypeConstraint<T>("T"),    \
335                           SparseFillEmptyRowsGradOp<D##Device, T, Tindex>)
336 
337 #define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T, int64)
338 TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
339 #undef REGISTER_CPU_KERNELS
340 
341 #undef REGISTER_KERNELS
342 }  // namespace tensorflow
343