/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

template <class T>
using BatchedMap = std::vector<absl::flat_hash_map<int64, T>>;

namespace {
// TODO(momernick): Extend this function to work with outputs of rank > 2.
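// Emits the per-batch count maps as a sparse result: output 0 holds the
// [total_values, 1 or 2] indices, output 1 the counts, and output 2 the dense
// shape ([num_values] for 1-D input, [num_batches, num_values] otherwise).
// Entries within each batch are written in increasing value order.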
template <class T>
Status OutputSparse(const BatchedMap<T>& per_batch_counts, int num_values,
                    bool is_1d, OpKernelContext* context) {
  int total_values = 0;
  int num_batches = per_batch_counts.size();
  for (const auto& per_batch_count : per_batch_counts) {
    total_values += per_batch_count.size();
  }

  Tensor* indices;
  int inner_dim = is_1d ? 1 : 2;
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({total_values, inner_dim}), &indices));

  Tensor* values;
  TF_RETURN_IF_ERROR(
      context->allocate_output(1, TensorShape({total_values}), &values));

  auto output_indices = indices->matrix<int64>();
  auto output_values = values->flat<T>();
  int64 value_loc = 0;
  for (int b = 0; b < num_batches; ++b) {
    const auto& per_batch_count = per_batch_counts[b];
    std::vector<std::pair<int64, T>> pairs(per_batch_count.begin(),
                                           per_batch_count.end());
    std::sort(pairs.begin(), pairs.end());
    for (const auto& x : pairs) {
      if (is_1d) {
        output_indices(value_loc, 0) = x.first;
      } else {
        output_indices(value_loc, 0) = b;
        output_indices(value_loc, 1) = x.first;
      }
      output_values(value_loc) = x.second;
      ++value_loc;
    }
  }
  Tensor* dense_shape;
  if (is_1d) {
    TF_RETURN_IF_ERROR(
        context->allocate_output(2, TensorShape({1}), &dense_shape));
    dense_shape->flat<int64>().data()[0] = num_values;
  } else {
    TF_RETURN_IF_ERROR(
        context->allocate_output(2, TensorShape({2}), &dense_shape));
    dense_shape->flat<int64>().data()[0] = num_batches;
    dense_shape->flat<int64>().data()[1] = num_values;
  }

  return Status::OK();
}

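// Returns the size of the output's value dimension: maxlength when it is set
// (> 0), otherwise the larger of (max_seen + 1) and minlength.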
int GetOutputSize(int max_seen, int max_length, int min_length) {
  return max_length > 0 ? max_length : std::max((max_seen + 1), min_length);
}

}  // namespace

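// Counts the occurrences of each value in a dense [batch, values] (or 1-D
// [values]) tensor, optionally weighted or binarized, and emits the result as
// a sparse output via OutputSparse.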
template <class T, class W>
class DenseCount : public OpKernel {
 public:
  explicit DenseCount(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
    OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
    OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& weights = context->input(1);
    bool use_weights = weights.NumElements() > 0;

    OP_REQUIRES(context,
                TensorShapeUtils::IsVector(data.shape()) ||
                    TensorShapeUtils::IsMatrix(data.shape()),
                errors::InvalidArgument(
                    "Input must be a 1 or 2-dimensional tensor. Got: ",
                    data.shape().DebugString()));

    if (use_weights) {
      OP_REQUIRES(
          context, weights.shape() == data.shape(),
          errors::InvalidArgument(
              "Weights and data must have the same shape. Weight shape: ",
              weights.shape().DebugString(),
              "; data shape: ", data.shape().DebugString()));
    }

    bool is_1d = TensorShapeUtils::IsVector(data.shape());
    int negative_valued_axis = -1;
    int num_batch_dimensions = (data.shape().dims() + negative_valued_axis);

    int num_batch_elements = 1;
    for (int i = 0; i < num_batch_dimensions; ++i) {
      num_batch_elements *= data.shape().dim_size(i);
    }
    int num_value_elements = data.shape().num_elements() / num_batch_elements;
    auto per_batch_counts = BatchedMap<W>(num_batch_elements);

    T max_value = 0;

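    // Walk the flattened data in row-major order, accumulating a count (or
    // weight sum, or binary indicator) per value for each batch element. Only
    // non-negative values below maxlength (when set) are counted.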
    const auto data_values = data.flat<T>();
    const auto weight_values = weights.flat<W>();
    int i = 0;
    for (int b = 0; b < num_batch_elements; ++b) {
      for (int v = 0; v < num_value_elements; ++v) {
        const auto& value = data_values(i);
        if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
          if (binary_output_) {
            per_batch_counts[b][value] = 1;
          } else if (use_weights) {
            per_batch_counts[b][value] += weight_values(i);
          } else {
            per_batch_counts[b][value]++;
          }
          if (value > max_value) {
            max_value = value;
          }
        }
        ++i;
      }
    }

    int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
    OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
                                            is_1d, context));
  }

 private:
  int maxlength_;
  int minlength_;
  bool binary_output_;
};

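// Counts the occurrences of each value in a SparseTensor given by its indices,
// values, and dense shape, optionally weighted or binarized, and emits the
// result as a sparse output via OutputSparse.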
template <class T, class W>
class SparseCount : public OpKernel {
 public:
  explicit SparseCount(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
    OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
    OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& indices = context->input(0);
    const Tensor& values = context->input(1);
    const Tensor& shape = context->input(2);
    const Tensor& weights = context->input(3);
    bool use_weights = weights.NumElements() > 0;

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
                errors::InvalidArgument(
                    "Input indices must be a 2-dimensional tensor. Got: ",
                    indices.shape().DebugString()));

    if (use_weights) {
      OP_REQUIRES(
          context, weights.shape() == values.shape(),
          errors::InvalidArgument(
              "Weights and values must have the same shape. Weight shape: ",
              weights.shape().DebugString(),
              "; values shape: ", values.shape().DebugString()));
    }

    bool is_1d = shape.NumElements() == 1;
    int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
    int num_values = values.NumElements();

    OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
                errors::InvalidArgument(
                    "Number of values must match first dimension of indices. ",
                    "Got ", num_values,
                    " values, indices shape: ", indices.shape().DebugString()));

    const auto indices_values = indices.matrix<int64>();
    const auto values_values = values.flat<T>();
    const auto weight_values = weights.flat<W>();

    auto per_batch_counts = BatchedMap<W>(num_batches);

    T max_value = 0;

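    // Attribute each value's count (or weight sum, or binary indicator) to the
    // batch given by the first column of its index; a 1-D input is treated as
    // a single batch.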
    for (int idx = 0; idx < num_values; ++idx) {
      int batch = is_1d ? 0 : indices_values(idx, 0);
      const auto& value = values_values(idx);
      if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
        if (binary_output_) {
          per_batch_counts[batch][value] = 1;
        } else if (use_weights) {
          per_batch_counts[batch][value] += weight_values(idx);
        } else {
          per_batch_counts[batch][value]++;
        }
        if (value > max_value) {
          max_value = value;
        }
      }
    }

    int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
    OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
                                            is_1d, context));
  }

 private:
  int maxlength_;
  int minlength_;
  bool binary_output_;
  bool validate_;
};

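// Counts the occurrences of each value in a RaggedTensor given by its row
// splits and values, optionally weighted or binarized, and emits the result
// as a sparse output via OutputSparse.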
template <class T, class W>
class RaggedCount : public OpKernel {
 public:
  explicit RaggedCount(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
    OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
    OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& splits = context->input(0);
    const Tensor& values = context->input(1);
    const Tensor& weights = context->input(2);
    bool use_weights = weights.NumElements() > 0;
    bool is_1d = false;

    if (use_weights) {
      OP_REQUIRES(
          context, weights.shape() == values.shape(),
          errors::InvalidArgument(
              "Weights and values must have the same shape. Weight shape: ",
              weights.shape().DebugString(),
              "; values shape: ", values.shape().DebugString()));
    }

    const auto splits_values = splits.flat<int64>();
    const auto values_values = values.flat<T>();
    const auto weight_values = weights.flat<W>();
    int num_batches = splits.NumElements() - 1;
    int num_values = values.NumElements();

    OP_REQUIRES(
        context, num_batches > 0,
        errors::InvalidArgument(
            "Must provide at least 2 elements for the splits argument"));
    OP_REQUIRES(context, splits_values(0) == 0,
                errors::InvalidArgument("Splits must start with 0, not with ",
                                        splits_values(0)));
    OP_REQUIRES(context, splits_values(num_batches) == num_values,
                errors::InvalidArgument(
                    "Splits must end with the number of values, got ",
                    splits_values(num_batches), " instead of ", num_values));

    auto per_batch_counts = BatchedMap<W>(num_batches);
    T max_value = 0;
    int batch_idx = 0;

    for (int idx = 0; idx < num_values; ++idx) {
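      // Advance batch_idx until splits_values(batch_idx) > idx; the row that
      // contains idx is then batch_idx - 1.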
      while (idx >= splits_values(batch_idx)) {
        batch_idx++;
      }
      const auto& value = values_values(idx);
      if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
        if (binary_output_) {
          per_batch_counts[batch_idx - 1][value] = 1;
        } else if (use_weights) {
          per_batch_counts[batch_idx - 1][value] += weight_values(idx);
        } else {
          per_batch_counts[batch_idx - 1][value]++;
        }
        if (value > max_value) {
          max_value = value;
        }
      }
    }

    int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
    OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
                                            is_1d, context));
  }

 private:
  int maxlength_;
  int minlength_;
  bool binary_output_;
  bool validate_;
};

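// Registers CPU kernels for every combination of int32/int64 value type ("T")
// and integral, float, or double output (weight) type.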
#define REGISTER_W(W_TYPE) \
  REGISTER(int32, W_TYPE)  \
  REGISTER(int64, W_TYPE)

#define REGISTER(I_TYPE, W_TYPE)                                     \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput")             \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          DenseCount<I_TYPE, W_TYPE>)                \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput")            \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          SparseCount<I_TYPE, W_TYPE>)               \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput")            \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          RaggedCount<I_TYPE, W_TYPE>)

TF_CALL_INTEGRAL_TYPES(REGISTER_W);
TF_CALL_float(REGISTER_W);
TF_CALL_double(REGISTER_W);

#undef REGISTER_W
#undef REGISTER

}  // namespace tensorflow