1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "absl/container/flat_hash_map.h"
17 #include "tensorflow/core/framework/op_kernel.h"
18 #include "tensorflow/core/framework/op_requires.h"
19 #include "tensorflow/core/framework/register_types.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/platform/errors.h"
22 #include "tensorflow/core/platform/types.h"
23 
24 namespace tensorflow {
25 
26 template <class T>
27 using BatchedMap = std::vector<absl::flat_hash_map<int64, T>>;
28 
29 namespace {
30 // TODO(momernick): Extend this function to work with outputs of rank > 2.
31 template <class T>
OutputSparse(const BatchedMap<T> & per_batch_counts,int num_values,bool is_1d,OpKernelContext * context)32 Status OutputSparse(const BatchedMap<T>& per_batch_counts, int num_values,
33                     bool is_1d, OpKernelContext* context) {
34   int total_values = 0;
35   int num_batches = per_batch_counts.size();
36   for (const auto& per_batch_count : per_batch_counts) {
37     total_values += per_batch_count.size();
38   }
39 
40   Tensor* indices;
41   int inner_dim = is_1d ? 1 : 2;
42   TF_RETURN_IF_ERROR(context->allocate_output(
43       0, TensorShape({total_values, inner_dim}), &indices));
44 
45   Tensor* values;
46   TF_RETURN_IF_ERROR(
47       context->allocate_output(1, TensorShape({total_values}), &values));
48 
49   auto output_indices = indices->matrix<int64>();
50   auto output_values = values->flat<T>();
51   int64_t value_loc = 0;
52   for (int b = 0; b < num_batches; ++b) {
53     const auto& per_batch_count = per_batch_counts[b];
54     std::vector<std::pair<int, T>> pairs(per_batch_count.begin(),
55                                          per_batch_count.end());
56     std::sort(pairs.begin(), pairs.end());
57     for (const auto& x : pairs) {
58       if (is_1d) {
59         output_indices(value_loc, 0) = x.first;
60       } else {
61         output_indices(value_loc, 0) = b;
62         output_indices(value_loc, 1) = x.first;
63       }
64       output_values(value_loc) = x.second;
65       ++value_loc;
66     }
67   }
68   Tensor* dense_shape;
69   if (is_1d) {
70     TF_RETURN_IF_ERROR(
71         context->allocate_output(2, TensorShape({1}), &dense_shape));
72     dense_shape->flat<int64>().data()[0] = num_values;
73   } else {
74     TF_RETURN_IF_ERROR(
75         context->allocate_output(2, TensorShape({2}), &dense_shape));
76     dense_shape->flat<int64>().data()[0] = num_batches;
77     dense_shape->flat<int64>().data()[1] = num_values;
78   }
79 
80   return Status::OK();
81 }
82 
GetOutputSize(int max_seen,int max_length,int min_length)83 int GetOutputSize(int max_seen, int max_length, int min_length) {
84   return max_length > 0 ? max_length : std::max((max_seen + 1), min_length);
85 }
86 
87 }  // namespace
88 
89 template <class T, class W>
90 class DenseCount : public OpKernel {
91  public:
DenseCount(OpKernelConstruction * context)92   explicit DenseCount(OpKernelConstruction* context) : OpKernel(context) {
93     OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
94     OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
95     OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
96   }
97 
Compute(OpKernelContext * context)98   void Compute(OpKernelContext* context) override {
99     const Tensor& data = context->input(0);
100     const Tensor& weights = context->input(1);
101     bool use_weights = weights.NumElements() > 0;
102 
103     OP_REQUIRES(context,
104                 TensorShapeUtils::IsVector(data.shape()) ||
105                     TensorShapeUtils::IsMatrix(data.shape()),
106                 errors::InvalidArgument(
107                     "Input must be a 1 or 2-dimensional tensor. Got: ",
108                     data.shape().DebugString()));
109 
110     if (use_weights) {
111       OP_REQUIRES(
112           context, weights.shape() == data.shape(),
113           errors::InvalidArgument(
114               "Weights and data must have the same shape. Weight shape: ",
115               weights.shape().DebugString(),
116               "; data shape: ", data.shape().DebugString()));
117     }
118 
119     bool is_1d = TensorShapeUtils::IsVector(data.shape());
120     int negative_valued_axis = -1;
121     int num_batch_dimensions = (data.shape().dims() + negative_valued_axis);
122 
123     int num_batch_elements = 1;
124     for (int i = 0; i < num_batch_dimensions; ++i) {
125       OP_REQUIRES(context, data.shape().dim_size(i) != 0,
126                   errors::InvalidArgument(
127                       "Invalid input: Shapes dimension cannot be 0."));
128       num_batch_elements *= data.shape().dim_size(i);
129     }
130     int num_value_elements = data.shape().num_elements() / num_batch_elements;
131     auto per_batch_counts = BatchedMap<W>(num_batch_elements);
132 
133     T max_value = 0;
134 
135     const auto data_values = data.flat<T>();
136     const auto weight_values = weights.flat<W>();
137     int i = 0;
138     for (int b = 0; b < num_batch_elements; ++b) {
139       for (int v = 0; v < num_value_elements; ++v) {
140         const auto& value = data_values(i);
141         if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
142           if (binary_output_) {
143             per_batch_counts[b][value] = 1;
144           } else if (use_weights) {
145             per_batch_counts[b][value] += weight_values(i);
146           } else {
147             per_batch_counts[b][value]++;
148           }
149           if (value > max_value) {
150             max_value = value;
151           }
152         }
153         ++i;
154       }
155     }
156 
157     int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
158     OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
159                                             is_1d, context));
160   }
161 
162  private:
163   int maxlength_;
164   int minlength_;
165   bool binary_output_;
166 };
167 
168 template <class T, class W>
169 class SparseCount : public OpKernel {
170  public:
SparseCount(OpKernelConstruction * context)171   explicit SparseCount(OpKernelConstruction* context) : OpKernel(context) {
172     OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
173     OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
174     OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
175   }
176 
Compute(OpKernelContext * context)177   void Compute(OpKernelContext* context) override {
178     const Tensor& indices = context->input(0);
179     const Tensor& values = context->input(1);
180     const Tensor& shape = context->input(2);
181     const Tensor& weights = context->input(3);
182     bool use_weights = weights.NumElements() > 0;
183 
184     OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
185                 errors::InvalidArgument(
186                     "Input indices must be a 2-dimensional tensor. Got: ",
187                     indices.shape().DebugString()));
188 
189     if (use_weights) {
190       OP_REQUIRES(
191           context, weights.shape() == values.shape(),
192           errors::InvalidArgument(
193               "Weights and values must have the same shape. Weight shape: ",
194               weights.shape().DebugString(),
195               "; values shape: ", values.shape().DebugString()));
196     }
197 
198     OP_REQUIRES(context, shape.NumElements() != 0,
199                 errors::InvalidArgument(
200                     "The shape argument requires at least one element."));
201 
202     bool is_1d = shape.NumElements() == 1;
203     auto shape_vector = shape.flat<int64>();
204     int num_batches = is_1d ? 1 : shape_vector(0);
205     int num_values = values.NumElements();
206 
207     for (int b = 0; b < shape_vector.size(); b++) {
208       OP_REQUIRES(context, shape_vector(b) >= 0,
209                   errors::InvalidArgument(
210                       "Elements in dense_shape must be >= 0. Instead got:",
211                       shape.DebugString()));
212     }
213 
214     OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
215                 errors::InvalidArgument(
216                     "Number of values must match first dimension of indices.",
217                     "Got ", num_values,
218                     " values, indices shape: ", indices.shape().DebugString()));
219 
220     const auto indices_values = indices.matrix<int64>();
221     const auto values_values = values.flat<T>();
222     const auto weight_values = weights.flat<W>();
223 
224     auto per_batch_counts = BatchedMap<W>(num_batches);
225 
226     T max_value = 0;
227 
228     OP_REQUIRES(context, num_values <= indices.shape().dim_size(0),
229                 errors::InvalidArgument(
230                     "The first dimension of indices must be equal to or "
231                     "greather than number of values. ( ",
232                     indices.shape().dim_size(0), " vs. ", num_values, " )"));
233     OP_REQUIRES(context, indices.shape().dim_size(1) > 0,
234                 errors::InvalidArgument("The second dimension of indices must "
235                                         "be greater than 0. Received: ",
236                                         indices.shape().dim_size(1)));
237 
238     for (int idx = 0; idx < num_values; ++idx) {
239       int batch = is_1d ? 0 : indices_values(idx, 0);
240       if (batch >= num_batches) {
241         OP_REQUIRES(context, batch < num_batches,
242                     errors::InvalidArgument(
243                         "Indices value along the first dimension must be ",
244                         "lower than the first index of the shape.", "Got ",
245                         batch, " as batch and ", num_batches,
246                         " as the first dimension of the shape."));
247       }
248       const auto& value = values_values(idx);
249       if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
250         if (binary_output_) {
251           per_batch_counts[batch][value] = 1;
252         } else if (use_weights) {
253           per_batch_counts[batch][value] += weight_values(idx);
254         } else {
255           per_batch_counts[batch][value]++;
256         }
257         if (value > max_value) {
258           max_value = value;
259         }
260       }
261     }
262 
263     int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
264     OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
265                                             is_1d, context));
266   }
267 
268  private:
269   int maxlength_;
270   int minlength_;
271   bool binary_output_;
272   bool validate_;
273 };
274 
275 template <class T, class W>
276 class RaggedCount : public OpKernel {
277  public:
RaggedCount(OpKernelConstruction * context)278   explicit RaggedCount(OpKernelConstruction* context) : OpKernel(context) {
279     OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
280     OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
281     OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
282   }
283 
Compute(OpKernelContext * context)284   void Compute(OpKernelContext* context) override {
285     const Tensor& splits = context->input(0);
286     const Tensor& values = context->input(1);
287     const Tensor& weights = context->input(2);
288     bool use_weights = weights.NumElements() > 0;
289     bool is_1d = false;
290 
291     if (use_weights) {
292       OP_REQUIRES(
293           context, weights.shape() == values.shape(),
294           errors::InvalidArgument(
295               "Weights and values must have the same shape. Weight shape: ",
296               weights.shape().DebugString(),
297               "; values shape: ", values.shape().DebugString()));
298     }
299 
300     const auto splits_values = splits.flat<int64>();
301     const auto values_values = values.flat<T>();
302     const auto weight_values = weights.flat<W>();
303     int num_batches = splits.NumElements() - 1;
304     int num_values = values.NumElements();
305 
306     OP_REQUIRES(
307         context, num_batches > 0,
308         errors::InvalidArgument(
309             "Must provide at least 2 elements for the splits argument"));
310     OP_REQUIRES(context, splits_values(0) == 0,
311                 errors::InvalidArgument("Splits must start with 0, not with ",
312                                         splits_values(0)));
313     OP_REQUIRES(context, splits_values(num_batches) == num_values,
314                 errors::InvalidArgument(
315                     "Splits must end with the number of values, got ",
316                     splits_values(num_batches), " instead of ", num_values));
317 
318     auto per_batch_counts = BatchedMap<W>(num_batches);
319     T max_value = 0;
320     int batch_idx = 0;
321 
322     for (int idx = 0; idx < num_values; ++idx) {
323       while (idx >= splits_values(batch_idx)) {
324         batch_idx++;
325       }
326       const auto& value = values_values(idx);
327       if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
328         if (binary_output_) {
329           per_batch_counts[batch_idx - 1][value] = 1;
330         } else if (use_weights) {
331           per_batch_counts[batch_idx - 1][value] += weight_values(idx);
332         } else {
333           per_batch_counts[batch_idx - 1][value]++;
334         }
335         if (value > max_value) {
336           max_value = value;
337         }
338       }
339     }
340 
341     int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
342     OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
343                                             is_1d, context));
344   }
345 
346  private:
347   int maxlength_;
348   int minlength_;
349   bool binary_output_;
350   bool validate_;
351 };
352 
353 #define REGISTER_W(W_TYPE) \
354   REGISTER(int32, W_TYPE)  \
355   REGISTER(int64, W_TYPE)
356 
357 #define REGISTER(I_TYPE, W_TYPE)                                     \
358                                                                      \
359   REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput")             \
360                               .TypeConstraint<I_TYPE>("T")           \
361                               .TypeConstraint<W_TYPE>("output_type") \
362                               .Device(DEVICE_CPU),                   \
363                           DenseCount<I_TYPE, W_TYPE>)                \
364                                                                      \
365   REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput")            \
366                               .TypeConstraint<I_TYPE>("T")           \
367                               .TypeConstraint<W_TYPE>("output_type") \
368                               .Device(DEVICE_CPU),                   \
369                           SparseCount<I_TYPE, W_TYPE>)               \
370                                                                      \
371   REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput")            \
372                               .TypeConstraint<I_TYPE>("T")           \
373                               .TypeConstraint<W_TYPE>("output_type") \
374                               .Device(DEVICE_CPU),                   \
375                           RaggedCount<I_TYPE, W_TYPE>)
376 
377 TF_CALL_INTEGRAL_TYPES(REGISTER_W);
378 TF_CALL_float(REGISTER_W);
379 TF_CALL_double(REGISTER_W);
380 
381 #undef REGISTER_W
382 #undef REGISTER
383 
384 }  // namespace tensorflow
385