/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

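// Per-batch map from value to its count (or weighted count / binary flag);
// entry b holds the counts for batch element b.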
template <class T>
using BatchedMap = std::vector<absl::flat_hash_map<int64, T>>;

namespace {
// TODO(momernick): Extend this function to work with outputs of rank > 2.
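// Writes the per-batch counts as the three components of a SparseTensor:
// output 0 is the [total_values, 1 or 2] indices matrix, output 1 is the
// values vector, and output 2 is the dense shape ({num_values} for 1-D
// input, {num_batches, num_values} otherwise).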
template <class T>
Status OutputSparse(const BatchedMap<T>& per_batch_counts, int num_values,
                    bool is_1d, OpKernelContext* context) {
  int total_values = 0;
  int num_batches = per_batch_counts.size();
  for (const auto& per_batch_count : per_batch_counts) {
    total_values += per_batch_count.size();
  }

  Tensor* indices;
  int inner_dim = is_1d ? 1 : 2;
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({total_values, inner_dim}), &indices));

  Tensor* values;
  TF_RETURN_IF_ERROR(
      context->allocate_output(1, TensorShape({total_values}), &values));

  auto output_indices = indices->matrix<int64>();
  auto output_values = values->flat<T>();
  int64 value_loc = 0;
  for (int b = 0; b < num_batches; ++b) {
    const auto& per_batch_count = per_batch_counts[b];
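    // Copy the hash map into a vector and sort it so that, within each batch,
    // indices are emitted in ascending value order.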
    std::vector<std::pair<int64, T>> pairs(per_batch_count.begin(),
                                           per_batch_count.end());
    std::sort(pairs.begin(), pairs.end());
    for (const auto& x : pairs) {
      if (is_1d) {
        output_indices(value_loc, 0) = x.first;
      } else {
        output_indices(value_loc, 0) = b;
        output_indices(value_loc, 1) = x.first;
      }
      output_values(value_loc) = x.second;
      ++value_loc;
    }
  }
  Tensor* dense_shape;
  if (is_1d) {
    TF_RETURN_IF_ERROR(
        context->allocate_output(2, TensorShape({1}), &dense_shape));
    dense_shape->flat<int64>().data()[0] = num_values;
  } else {
    TF_RETURN_IF_ERROR(
        context->allocate_output(2, TensorShape({2}), &dense_shape));
    dense_shape->flat<int64>().data()[0] = num_batches;
    dense_shape->flat<int64>().data()[1] = num_values;
  }

  return Status::OK();
}

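// An explicit maxlength caps the output size; otherwise the output must be
// large enough to index every value seen, but no smaller than minlength.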
int GetOutputSize(int max_seen, int max_length, int min_length) {
  return max_length > 0 ? max_length : std::max((max_seen + 1), min_length);
}

}  // namespace

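// Counts the occurrences of each value in a dense [batch, values] (or 1-D)
// tensor and emits the result as a SparseTensor. Illustrative example:
// data = [[1, 2, 2], [3, 3, 3]] with default attrs yields
// indices [[0, 1], [0, 2], [1, 3]], values [1, 2, 3], dense_shape [2, 4].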
template <class T, class W>
class DenseCount : public OpKernel {
 public:
  explicit DenseCount(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
    OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
    OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& data = context->input(0);
    const Tensor& weights = context->input(1);
    bool use_weights = weights.NumElements() > 0;

    OP_REQUIRES(context,
                TensorShapeUtils::IsVector(data.shape()) ||
                    TensorShapeUtils::IsMatrix(data.shape()),
                errors::InvalidArgument(
                    "Input must be a 1 or 2-dimensional tensor. Got: ",
                    data.shape().DebugString()));

    if (use_weights) {
      OP_REQUIRES(
          context, weights.shape() == data.shape(),
          errors::InvalidArgument(
              "Weights and data must have the same shape. Weight shape: ",
              weights.shape().DebugString(),
              "; data shape: ", data.shape().DebugString()));
    }

    bool is_1d = TensorShapeUtils::IsVector(data.shape());
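    // The counting axis is the innermost dimension (axis = -1); every
    // dimension before it is treated as a batch dimension.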
    int negative_valued_axis = -1;
    int num_batch_dimensions = (data.shape().dims() + negative_valued_axis);

    int num_batch_elements = 1;
    for (int i = 0; i < num_batch_dimensions; ++i) {
      num_batch_elements *= data.shape().dim_size(i);
    }
    // Guard the division below: a zero-sized batch dimension would otherwise
    // divide by zero.
    OP_REQUIRES(context, num_batch_elements > 0,
                errors::InvalidArgument(
                    "Number of batch elements must be positive. Got shape: ",
                    data.shape().DebugString()));
    int num_value_elements = data.shape().num_elements() / num_batch_elements;
    auto per_batch_counts = BatchedMap<W>(num_batch_elements);

    T max_value = 0;

    const auto data_values = data.flat<T>();
    const auto weight_values = weights.flat<W>();
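    // Walk the flattened data in row-major order, counting each in-range
    // value: binary presence, weighted sum, or a plain count.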
    int i = 0;
    for (int b = 0; b < num_batch_elements; ++b) {
      for (int v = 0; v < num_value_elements; ++v) {
        const auto& value = data_values(i);
        if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
          if (binary_output_) {
            per_batch_counts[b][value] = 1;
          } else if (use_weights) {
            per_batch_counts[b][value] += weight_values(i);
          } else {
            per_batch_counts[b][value]++;
          }
          if (value > max_value) {
            max_value = value;
          }
        }
        ++i;
      }
    }

    int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
    OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
                                            is_1d, context));
  }

 private:
  int maxlength_;
  int minlength_;
  bool binary_output_;
};

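// Counts the values of a SparseTensor, given as separate indices, values, and
// dense-shape inputs; row indices (column 0 of indices) select the batch.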
template <class T, class W>
class SparseCount : public OpKernel {
 public:
  explicit SparseCount(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
    OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
    OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& indices = context->input(0);
    const Tensor& values = context->input(1);
    const Tensor& shape = context->input(2);
    const Tensor& weights = context->input(3);
    bool use_weights = weights.NumElements() > 0;

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
                errors::InvalidArgument(
                    "Input indices must be a 2-dimensional tensor. Got: ",
                    indices.shape().DebugString()));

    if (use_weights) {
      OP_REQUIRES(
          context, weights.shape() == values.shape(),
          errors::InvalidArgument(
              "Weights and values must have the same shape. Weight shape: ",
              weights.shape().DebugString(),
              "; values shape: ", values.shape().DebugString()));
    }

    // Reading the shape input below requires a non-empty vector.
    OP_REQUIRES(context, shape.NumElements() > 0,
                errors::InvalidArgument(
                    "Input shape must be a non-empty vector. Got: ",
                    shape.shape().DebugString()));
    bool is_1d = shape.NumElements() == 1;
    int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
    int num_values = values.NumElements();

    OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
                errors::InvalidArgument(
                    "Number of values must match first dimension of indices. "
                    "Got ", num_values,
                    " values, indices shape: ", indices.shape().DebugString()));

    const auto indices_values = indices.matrix<int64>();
    const auto values_values = values.flat<T>();
    const auto weight_values = weights.flat<W>();

    auto per_batch_counts = BatchedMap<W>(num_batches);

    T max_value = 0;

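    // Count each sparse value into its batch, mirroring the dense kernel.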
    for (int idx = 0; idx < num_values; ++idx) {
      int batch = is_1d ? 0 : indices_values(idx, 0);
      // Validate the batch index before using it; an out-of-range index
      // would otherwise write past the end of per_batch_counts.
      OP_REQUIRES(context, batch >= 0 && batch < num_batches,
                  errors::InvalidArgument(
                      "Batch index out of range: got ", batch,
                      " for a dense shape with ", num_batches, " batches"));
      const auto& value = values_values(idx);
      if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
        if (binary_output_) {
          per_batch_counts[batch][value] = 1;
        } else if (use_weights) {
          per_batch_counts[batch][value] += weight_values(idx);
        } else {
          per_batch_counts[batch][value]++;
        }
        if (value > max_value) {
          max_value = value;
        }
      }
    }

    int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
    OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
                                            is_1d, context));
  }

 private:
  int maxlength_;
  int minlength_;
  bool binary_output_;
  bool validate_;
};

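// Counts the values of a RaggedTensor, given as a row-splits vector and a
// flat values vector; splits[i]..splits[i+1] delimit the values of row i.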
template <class T, class W>
class RaggedCount : public OpKernel {
 public:
  explicit RaggedCount(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("minlength", &minlength_));
    OP_REQUIRES_OK(context, context->GetAttr("maxlength", &maxlength_));
    OP_REQUIRES_OK(context, context->GetAttr("binary_output", &binary_output_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& splits = context->input(0);
    const Tensor& values = context->input(1);
    const Tensor& weights = context->input(2);
    bool use_weights = weights.NumElements() > 0;
    bool is_1d = false;

    if (use_weights) {
      OP_REQUIRES(
          context, weights.shape() == values.shape(),
          errors::InvalidArgument(
              "Weights and values must have the same shape. Weight shape: ",
              weights.shape().DebugString(),
              "; values shape: ", values.shape().DebugString()));
    }

    const auto splits_values = splits.flat<int64>();
    const auto values_values = values.flat<T>();
    const auto weight_values = weights.flat<W>();
    int num_batches = splits.NumElements() - 1;
    int num_values = values.NumElements();

    OP_REQUIRES(
        context, num_batches > 0,
        errors::InvalidArgument(
            "Must provide at least 2 elements for the splits argument"));
    OP_REQUIRES(context, splits_values(0) == 0,
                errors::InvalidArgument("Splits must start with 0, not with ",
                                        splits_values(0)));
    OP_REQUIRES(context, splits_values(num_batches) == num_values,
                errors::InvalidArgument(
                    "Splits must end with the number of values, got ",
                    splits_values(num_batches), " instead of ", num_values));

    auto per_batch_counts = BatchedMap<W>(num_batches);
    T max_value = 0;
    int batch_idx = 0;

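    // Advance batch_idx until splits_values(batch_idx) exceeds idx; because
    // splits_values(0) == 0 forces one initial increment, batch_idx - 1 is
    // the row that contains the current value.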
    for (int idx = 0; idx < num_values; ++idx) {
      while (idx >= splits_values(batch_idx)) {
        batch_idx++;
      }
      const auto& value = values_values(idx);
      if (value >= 0 && (maxlength_ <= 0 || value < maxlength_)) {
        if (binary_output_) {
          per_batch_counts[batch_idx - 1][value] = 1;
        } else if (use_weights) {
          per_batch_counts[batch_idx - 1][value] += weight_values(idx);
        } else {
          per_batch_counts[batch_idx - 1][value]++;
        }
        if (value > max_value) {
          max_value = value;
        }
      }
    }

    int num_output_values = GetOutputSize(max_value, maxlength_, minlength_);
    OP_REQUIRES_OK(context, OutputSparse<W>(per_batch_counts, num_output_values,
                                            is_1d, context));
  }

 private:
  int maxlength_;
  int minlength_;
  bool binary_output_;
  bool validate_;
};

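// Register CPU kernels for every (index type, weight type) pairing: int32 and
// int64 indices crossed with integral, float, and double output types.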
#define REGISTER_W(W_TYPE) \
  REGISTER(int32, W_TYPE)  \
  REGISTER(int64, W_TYPE)

#define REGISTER(I_TYPE, W_TYPE)                                     \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("DenseCountSparseOutput")             \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          DenseCount<I_TYPE, W_TYPE>)                \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("SparseCountSparseOutput")            \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          SparseCount<I_TYPE, W_TYPE>)               \
                                                                     \
  REGISTER_KERNEL_BUILDER(Name("RaggedCountSparseOutput")            \
                              .TypeConstraint<I_TYPE>("T")           \
                              .TypeConstraint<W_TYPE>("output_type") \
                              .Device(DEVICE_CPU),                   \
                          RaggedCount<I_TYPE, W_TYPE>)

TF_CALL_INTEGRAL_TYPES(REGISTER_W);
TF_CALL_float(REGISTER_W);
TF_CALL_double(REGISTER_W);

#undef REGISTER_W
#undef REGISTER

}  // namespace tensorflow