1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Contains the op kernels that generate sparse feature crosses.
17 #include <assert.h>
18
19 #include <limits>
20 #include <string>
21 #include <vector>
22
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "tensorflow/core/framework/kernel_def_builder.h"
25 #include "tensorflow/core/framework/op_def_builder.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/lib/core/stringpiece.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/platform/fingerprint.h"
33 #include "tensorflow/core/platform/strong_hash.h"
34 #include "tensorflow/core/util/work_sharder.h"
35
36 namespace tensorflow {
37
38 namespace {
39 // An interface that represents a column with batches.
40 template <typename InternalType>
41 class ColumnInterface {
42 public:
43 // Returns the number of features in the specified batch.
44 virtual int64 FeatureCount(int64 batch) const = 0;
45
46 // Returns the fingerprint of nth feature from the specified batch.
47 virtual InternalType Feature(int64 batch, int64 n,
48 bool strong_hash) const = 0;
49
~ColumnInterface()50 virtual ~ColumnInterface() {}
51 };
52
53 // A column that is backed by a sparse tensor.
54 template <typename InternalType>
55 class SparseTensorColumn : public ColumnInterface<InternalType> {
56 public:
SparseTensorColumn(const Tensor & values,std::vector<int64> feature_counts,std::vector<int64> feature_start_indices)57 SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
58 std::vector<int64> feature_start_indices)
59 : values_(values),
60 feature_counts_(std::move(feature_counts)),
61 feature_start_indices_(std::move(feature_start_indices)) {
62 CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
63 }
64
FeatureCount(int64 batch) const65 int64 FeatureCount(int64 batch) const override {
66 return feature_counts_[batch];
67 }
68
69 InternalType Feature(int64 batch, int64 n, bool strong_hash) const override;
70
~SparseTensorColumn()71 ~SparseTensorColumn() override {}
72
73 private:
74 const Tensor& values_;
75 std::vector<int64> feature_counts_;
76 std::vector<int64> feature_start_indices_;
77 };
78
79 // A column that is backed by a sparse tensor.
80 template <typename InternalType>
81 class KeyedSparseTensorColumn : public ColumnInterface<InternalType> {
82 public:
KeyedSparseTensorColumn(const Tensor & values,std::vector<int64> feature_counts,std::vector<int64> feature_start_indices,std::vector<int64> key)83 KeyedSparseTensorColumn(const Tensor& values,
84 std::vector<int64> feature_counts,
85 std::vector<int64> feature_start_indices,
86 std::vector<int64> key)
87 : values_(values),
88 feature_counts_(std::move(feature_counts)),
89 feature_start_indices_(std::move(feature_start_indices)) {
90 DCHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
91 std::memcpy(key_, key.data(), sizeof(key_));
92 }
93
FeatureCount(int64 batch) const94 int64 FeatureCount(int64 batch) const override {
95 return feature_counts_[batch];
96 }
97
98 InternalType Feature(int64 batch, int64 n, bool strong_hash) const override;
99
~KeyedSparseTensorColumn()100 ~KeyedSparseTensorColumn() override {}
101
102 private:
103 const Tensor& values_;
104 tensorflow::uint64 key_[2];
105 std::vector<int64> feature_counts_;
106 std::vector<int64> feature_start_indices_;
107 };
108
109 // InternalType is int64 only when using HashCrosser.
110 template <>
Feature(int64 batch,int64 n,bool strong_hash) const111 int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n,
112 bool strong_hash) const {
113 const int64 start = feature_start_indices_[batch];
114 if (DT_STRING == values_.dtype())
115 return Fingerprint64(values_.vec<tstring>().data()[start + n]);
116 return values_.vec<int64>().data()[start + n];
117 }
118
// Hashed feature for keyed sparse columns: strong (keyed) hash when
// strong_hash is set, plain Fingerprint64 otherwise.
template <>
int64 KeyedSparseTensorColumn<int64>::Feature(int64 batch, int64 n,
                                              bool strong_hash) const {
  const int64 start = feature_start_indices_[batch];
  if (strong_hash) {
    if (DT_STRING == values_.dtype()) {
      return StrongKeyedHash(key_, values_.vec<tstring>()(start + n));
    }
    // NOTE(review): sizeof(values_.dtype()) is the size of the DataType
    // enum (4 bytes), not of the int64 value (8 bytes), so only the low
    // half of the value is hashed. Presumably sizeof(int64) was intended —
    // but changing it would alter existing hashed outputs; confirm before
    // fixing.
    return StrongKeyedHash(
        key_, {reinterpret_cast<const char*>(&values_.vec<int64>()(start + n)),
               sizeof(values_.dtype())});
  }
  if (DT_STRING == values_.dtype())
    return Fingerprint64(values_.vec<tstring>()(start + n));
  // Same caveat as above: only sizeof(DataType) bytes are fingerprinted.
  return Fingerprint64(
      {reinterpret_cast<const char*>(&values_.vec<int64>()(start + n)),
       sizeof(values_.dtype())});
}
137
138 // InternalType is string or StringPiece when using StringCrosser.
139 template <>
Feature(int64 batch,int64 n,bool strong_hash) const140 tstring SparseTensorColumn<tstring>::Feature(int64 batch, int64 n,
141 bool strong_hash) const {
142 const int64 start = feature_start_indices_[batch];
143 if (DT_STRING == values_.dtype())
144 return values_.vec<tstring>().data()[start + n];
145 return std::to_string(values_.vec<int64>().data()[start + n]);
146 }
147
148 template <>
Feature(int64 batch,int64 n,bool strong_hash) const149 tstring KeyedSparseTensorColumn<tstring>::Feature(int64 batch, int64 n,
150 bool strong_hash) const {
151 const int64 start = feature_start_indices_[batch];
152 if (DT_STRING == values_.dtype())
153 return values_.vec<tstring>().data()[start + n];
154 return std::to_string(values_.vec<int64>().data()[start + n]);
155 }
156
157 template <>
Feature(int64 batch,int64 n,bool strong_hash) const158 StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch, int64 n,
159 bool strong_hash) const {
160 const int64 start = feature_start_indices_[batch];
161 return values_.vec<tstring>().data()[start + n];
162 }
163
164 template <>
Feature(int64 batch,int64 n,bool strong_hash) const165 StringPiece KeyedSparseTensorColumn<StringPiece>::Feature(
166 int64 batch, int64 n, bool strong_hash) const {
167 const int64 start = feature_start_indices_[batch];
168 return values_.vec<tstring>().data()[start + n];
169 }
170
171 // A column that is backed by a dense tensor.
172 template <typename InternalType>
173 class DenseTensorColumn : public ColumnInterface<InternalType> {
174 public:
DenseTensorColumn(const Tensor & tensor)175 explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}
176
FeatureCount(int64 batch) const177 int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
178
179 InternalType Feature(int64 batch, int64 n, bool strong_hash) const override;
180
~DenseTensorColumn()181 ~DenseTensorColumn() override {}
182
183 private:
184 const Tensor& tensor_;
185 };
186
187 // A column that is backed by a dense tensor.
188 template <typename InternalType>
189 class KeyedDenseTensorColumn : public ColumnInterface<InternalType> {
190 public:
KeyedDenseTensorColumn(const Tensor & tensor,std::vector<int64> key)191 explicit KeyedDenseTensorColumn(const Tensor& tensor, std::vector<int64> key)
192 : tensor_(tensor) {
193 std::memcpy(key_, key.data(), sizeof(key_));
194 }
195
FeatureCount(int64 batch) const196 int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
197
198 InternalType Feature(int64 batch, int64 n, bool strong_hash) const override;
199
~KeyedDenseTensorColumn()200 ~KeyedDenseTensorColumn() override {}
201
202 private:
203 const Tensor& tensor_;
204 tensorflow::uint64 key_[2];
205 };
206
207 // InternalType is int64 only when using HashCrosser.
208 template <>
Feature(int64 batch,int64 n,bool strong_hash) const209 int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n,
210 bool strong_hash) const {
211 if (DT_STRING == tensor_.dtype())
212 return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
213 return tensor_.matrix<int64>()(batch, n);
214 }
215
216 template <>
Feature(int64 batch,int64 n,bool strong_hash) const217 int64 KeyedDenseTensorColumn<int64>::Feature(int64 batch, int64 n,
218 bool strong_hash) const {
219 if (strong_hash) {
220 if (DT_STRING == tensor_.dtype()) {
221 return StrongKeyedHash(key_, tensor_.matrix<tstring>()(batch, n));
222 }
223 return StrongKeyedHash(
224 key_, {reinterpret_cast<const char*>(tensor_.matrix<int64>()(batch, n)),
225 sizeof(tensor_.dtype())});
226 }
227 if (DT_STRING == tensor_.dtype())
228 return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
229 return tensor_.matrix<int64>()(batch, n);
230 }
231
232 // Internal type is string or StringPiece when using StringCrosser.
233 template <>
Feature(int64 batch,int64 n,bool strong_hash) const234 tstring DenseTensorColumn<tstring>::Feature(int64 batch, int64 n,
235 bool strong_hash) const {
236 if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
237 return std::to_string(tensor_.matrix<int64>()(batch, n));
238 }
239
240 template <>
Feature(int64 batch,int64 n,bool strong_hash) const241 tstring KeyedDenseTensorColumn<tstring>::Feature(int64 batch, int64 n,
242 bool strong_hash) const {
243 if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
244 return std::to_string(tensor_.matrix<int64>()(batch, n));
245 }
246
247 template <>
Feature(int64 batch,int64 n,bool strong_hash) const248 StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch, int64 n,
249 bool strong_hash) const {
250 return tensor_.matrix<tstring>()(batch, n);
251 }
252
253 template <>
Feature(int64 batch,int64 n,bool strong_hash) const254 StringPiece KeyedDenseTensorColumn<StringPiece>::Feature(
255 int64 batch, int64 n, bool strong_hash) const {
256 return tensor_.matrix<tstring>()(batch, n);
257 }
258
259 // Updates Output tensors with sparse crosses.
260 template <typename OutType>
261 class OutputUpdater {
262 public:
OutputUpdater(const std::vector<int64> & output_start_indices,Tensor * indices_out,Tensor * values_out)263 OutputUpdater(const std::vector<int64>& output_start_indices,
264 Tensor* indices_out, Tensor* values_out)
265 : output_start_indices_(output_start_indices),
266 indices_out_(indices_out),
267 values_out_(values_out) {}
268
Update(const int64 batch_index,const int64 cross_count,const OutType & cross) const269 void Update(const int64 batch_index, const int64 cross_count,
270 const OutType& cross) const {
271 const int64 output_index = output_start_indices_[batch_index] + cross_count;
272
273 auto indices_matrix = indices_out_->matrix<int64>();
274 indices_matrix(output_index, 0) = batch_index;
275 indices_matrix(output_index, 1) = cross_count;
276
277 auto value_vec = values_out_->vec<OutType>();
278 value_vec(output_index) = cross;
279 }
280
281 private:
282 const std::vector<int64>& output_start_indices_;
283 Tensor* indices_out_;
284 Tensor* values_out_;
285 };
286
287 // Generates the sparse crosses as concatenation of strings.
288 template <typename InternalType>
289 class StringCrosser {
290 public:
StringCrosser(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,const int64 num_buckets_unused,const uint64 hash_key_unused,const tstring k_feature_separator)291 StringCrosser(const std::vector<
292 std::unique_ptr<ColumnInterface<InternalType>>>& columns,
293 const int64 num_buckets_unused, const uint64 hash_key_unused,
294 const tstring k_feature_separator)
295 : columns_(columns), k_feature_separator_(k_feature_separator) {}
296
Generate(const int64 batch_index,const std::vector<int> & permutation,bool unused_strong_hash) const297 string Generate(const int64 batch_index, const std::vector<int>& permutation,
298 bool unused_strong_hash) const {
299 gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
300 for (int i = 0; i < permutation.size(); i++) {
301 cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i], false);
302 }
303 // TODO(zakaria): this will copy the string twice, might effect
304 // performance.
305 return absl::StrJoin(cross_vec, k_feature_separator_);
306 }
307
308 private:
309 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
310 const tstring k_feature_separator_;
311 };
312
313 // Generates the sparse crosses as nested hash to avoid string manipulations.
314 class HashCrosser {
315 public:
HashCrosser(const std::vector<std::unique_ptr<ColumnInterface<int64>>> & columns,const int64 num_buckets,const uint64 hash_key,const tstring k_feature_separator_unused)316 HashCrosser(
317 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
318 const int64 num_buckets, const uint64 hash_key,
319 const tstring k_feature_separator_unused)
320 : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}
321
Generate(const int64 batch_index,const std::vector<int> & permutation,bool unused_strong_hash) const322 int64 Generate(const int64 batch_index, const std::vector<int>& permutation,
323 bool unused_strong_hash) const {
324 // Do the fingerprint concatenation on uint64.
325 uint64 hashed_output = hash_key_;
326 for (size_t i = 0; i < permutation.size(); ++i) {
327 uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i], false);
328 hashed_output = FingerprintCat64(hashed_output, hash_i);
329 }
330 // The return value is int64 based on the number of buckets.
331 if (num_buckets_ > 0) {
332 return hashed_output % num_buckets_;
333 } else {
334 // To prevent negative output we take modulo to max int64.
335 return hashed_output % std::numeric_limits<int64>::max();
336 }
337 }
338
339 private:
340 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
341 const int64 num_buckets_;
342 const uint64 hash_key_;
343 };
344
345 // Generates the sparse crosses as nested hash to avoid string manipulations.
346 class HashCrosserV2 {
347 public:
HashCrosserV2(const std::vector<std::unique_ptr<ColumnInterface<int64>>> & columns,const int64 num_buckets,const uint64 hash_key_unused,const tstring k_feature_separator_unused)348 HashCrosserV2(
349 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
350 const int64 num_buckets, const uint64 hash_key_unused,
351 const tstring k_feature_separator_unused)
352 : columns_(columns), num_buckets_(num_buckets) {}
353
Generate(const int64 batch_index,const std::vector<int> & permutation,bool strong_hash) const354 int64 Generate(const int64 batch_index, const std::vector<int>& permutation,
355 bool strong_hash) const {
356 // Do the fingerprint concatenation on uint64.
357 uint64 hashed_output =
358 columns_[0]->Feature(batch_index, permutation[0], strong_hash);
359 for (size_t i = 1; i < permutation.size(); ++i) {
360 uint64 hash_i =
361 columns_[i]->Feature(batch_index, permutation[i], strong_hash);
362 hashed_output = FingerprintCat64(hashed_output, hash_i);
363 }
364 // The return value is int64 based on the number of buckets.
365 if (num_buckets_ > 0) {
366 return hashed_output % num_buckets_;
367 } else {
368 // To prevent negative output we take modulo to max int64.
369 return hashed_output % std::numeric_limits<int64>::max();
370 }
371 }
372
373 private:
374 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
375 const int64 num_buckets_;
376 };
377
// ProductIterator generates the cartesian product of the features of all
// columns for one batch index. Each Next() yields a permutation vector p
// where p[i] is the feature index chosen from column i.
template <typename InternalType>
class ProductIterator {
 public:
  explicit ProductIterator(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64 batch_index)
      : columns_(columns), batch_index_(batch_index) {
    next_permutation_.resize(columns_.size(), 0);
    // Sets has_next_ to false if any feature column has 0 features.
    has_next_ = true;
    for (int i = 0; i < columns_.size(); i++) {
      if (columns_[i]->FeatureCount(batch_index_) == 0) {
        has_next_ = false;
        break;
      }
    }
  }

  // Returns the current permutation and advances to the next, treating
  // next_permutation_ as a mixed-radix counter (last column varies fastest,
  // digit i wraps at columns_[i]->FeatureCount()).
  std::vector<int> Next() {
    std::vector<int> permutation(next_permutation_);

    // Generates next permutation, if available.
    bool carry = true;
    for (int i = next_permutation_.size() - 1; i >= 0; i--) {
      if (carry) {
        next_permutation_[i] = next_permutation_[i] + 1;
      }
      if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
        // Digit overflowed: reset it and let the carry propagate leftward.
        next_permutation_[i] = 0;
      } else {
        carry = false;
        break;
      }
    }
    // A carry out of the leftmost digit means the counter wrapped: done.
    has_next_ = !carry;
    return permutation;
  }

  // True while Next() has at least one more permutation to return.
  bool HasNext() { return has_next_; }

 private:
  bool has_next_;
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
  const int64 batch_index_;
  std::vector<int> next_permutation_;
};
426
// Maps the HASHED_OUTPUT compile-time flag to the crosser and output-updater
// implementations used by the kernels below.
template <bool HASHED_OUTPUT, typename InternalType>
struct CrossTraits;

// String output: both crosser versions join features with a separator.
template <typename InternalType>
struct CrossTraits<false, InternalType> {
  typedef StringCrosser<InternalType> Crosser;
  typedef StringCrosser<InternalType> CrosserV2;
  typedef OutputUpdater<tstring> Updater;
};

// Hashed output: fingerprints are combined into int64 bucket ids.
template <>
struct CrossTraits<true, int64> {
  typedef HashCrosser Crosser;
  typedef HashCrosserV2 CrosserV2;
  typedef OutputUpdater<int64> Updater;
};
443 } // namespace
444
445 // Calculate the batch size from either the shapes input or the dense input.
CalculateBatchSize(const OpInputList & shapes_list_in,const OpInputList & dense_list_in)446 int64 CalculateBatchSize(const OpInputList& shapes_list_in,
447 const OpInputList& dense_list_in) {
448 if (shapes_list_in.size() > 0) {
449 return shapes_list_in[0].vec<int64>()(0);
450 }
451
452 if (dense_list_in.size() > 0) {
453 return dense_list_in[0].dim_size(0);
454 }
455
456 return 0;
457 }
458
// Validates input tensors: checks the rank/shape of each list entry, that
// all three sparse lists have the same length, and that every input reports
// the same batch size. Returns InvalidArgument naming the offending
// position on failure.
Status ValidateInput(const OpInputList& indices_list_in,
                     const OpInputList& values_list_in,
                     const OpInputList& shapes_list_in,
                     const OpInputList& dense_list_in) {
  const auto size = indices_list_in.size();
  // Validates indices_list_in OpInputList: each entry must be an [nnz, 2]
  // matrix (row, column pairs of a rank-2 SparseTensor).
  for (int i = 0; i < size; i++) {
    if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input indices should be a matrix but received shape ",
          indices_list_in[i].shape().DebugString(), " at position ", i);
    }
    if (indices_list_in[i].shape().dim_size(1) != 2) {
      return errors::InvalidArgument("Expected D2 of index to be 2 got ",
                                     indices_list_in[i].shape().dim_size(1),
                                     " at position ", i);
    }
  }

  // Validates values_list_in OpInputList: one vector per indices entry,
  // with exactly one value per index row.
  if (values_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input values, got ",
                                   values_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input values should be a vector but received shape ",
          values_list_in[i].shape().DebugString(), " at position ", i);
    }
    if (indices_list_in[i].shape().dim_size(0) !=
        values_list_in[i].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Expected size of values to be ",
          indices_list_in[i].shape().dim_size(0), " got ",
          values_list_in[i].shape().dim_size(0), " at position ", i);
    }
  }

  // Validates shapes_list_in OpInputList: each must be a length-2 vector
  // (the dense shape of a rank-2 SparseTensor).
  if (shapes_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input shapes, got ",
                                   shapes_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input shapes should be a vector but received shape ",
          shapes_list_in[i].shape().DebugString(), " at position ", i);
    }

    if (shapes_list_in[i].vec<int64>().size() != 2) {
      return errors::InvalidArgument("shape should imply a 2D tensor, but got ",
                                     shapes_list_in[i].shape().DebugString(),
                                     " at position ", i);
    }
  }

  // Validates dense_list_in OpInputList: dense inputs are [batch, features]
  // matrices.
  for (int i = 0; i < dense_list_in.size(); ++i) {
    if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Dense inputs should be a matrix but received shape ",
          dense_list_in[i].shape().DebugString(), " at position ", i);
    }
  }

  // Validates batch sizes. (Note: we do this after validating the input
  // shapes, because CalculateBatchSize() depends on inputs having valid
  // shapes).
  const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
  for (int i = 0; i < size; i++) {
    if (shapes_list_in[i].vec<int64>()(0) != batch_size) {
      return errors::InvalidArgument("Expected batch size ", batch_size,
                                     " got ", shapes_list_in[i].vec<int64>()(0),
                                     " at position ", i);
    }
  }
  for (int i = 0; i < dense_list_in.size(); ++i) {
    if (dense_list_in[i].dim_size(0) != batch_size) {
      return errors::InvalidArgument("Expected batch size ", batch_size,
                                     " got ", dense_list_in[i].dim_size(0),
                                     " at dense tensor ", i);
    }
  }

  return Status::OK();
}
548
549 // Extracts data about the features and populates feature data.
ExtractFeatureData(const OpInputList & indices_list_in,int64 batch_size,std::vector<std::vector<int64>> * feature_counts,std::vector<std::vector<int64>> * feature_start_indices)550 void ExtractFeatureData(
551 const OpInputList& indices_list_in, int64 batch_size,
552 std::vector<std::vector<int64>>* feature_counts,
553 std::vector<std::vector<int64>>* feature_start_indices) {
554 gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0);
555 for (int b = 0; b < batch_size; b++) {
556 for (int i = 0; i < indices_list_in.size(); i++) {
557 const auto indices = indices_list_in[i].matrix<int64>();
558 int64 feature_count = 0;
559 int64 start_index = current_row[i];
560 // Loops until we reach next batch index for current feature column.
561 while (current_row[i] < indices_list_in[i].dim_size(0) &&
562 indices(current_row[i], 0) == b) {
563 feature_count++;
564 current_row[i]++;
565 }
566 (*feature_counts)[i].push_back(feature_count);
567 (*feature_start_indices)[i].push_back(start_index);
568 }
569 }
570 }
571
572 // Returns number of crosses for a given batch_index
573 template <typename InternalType>
CrossCountByBatchIndex(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,int batch_index)574 int64 CrossCountByBatchIndex(
575 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
576 int batch_index) {
577 int64 cross_count = 1;
578 for (int i = 0; i < columns.size(); i++) {
579 const auto feature_count = columns[i]->FeatureCount(batch_index);
580 // If one column is missing any feature, there won't be any cross.
581 if (feature_count == 0) {
582 return 0;
583 }
584 cross_count *= feature_count;
585 }
586 return cross_count;
587 }
588
589 // Generate the columns given the sparse and dense inputs.
590 template <typename InternalType>
591 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in)592 GenerateColumnsFromInput(const OpInputList& indices_list_in,
593 const OpInputList& values_list_in,
594 const OpInputList& shapes_list_in,
595 const OpInputList& dense_list_in) {
596 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
597 const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
598 const int64 number_of_columns = shapes_list_in.size();
599
600 std::vector<std::vector<int64>> feature_counts(number_of_columns,
601 std::vector<int64>());
602 std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
603 std::vector<int64>());
604
605 ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
606 &feature_start_indices);
607
608 columns.reserve(values_list_in.size());
609 for (int i = 0; i < values_list_in.size(); ++i) {
610 columns.emplace_back(new SparseTensorColumn<InternalType>(
611 values_list_in[i], std::move(feature_counts[i]),
612 std::move(feature_start_indices[i])));
613 }
614 for (int i = 0; i < dense_list_in.size(); ++i) {
615 columns.emplace_back(new DenseTensorColumn<InternalType>(dense_list_in[i]));
616 }
617
618 return columns;
619 }
620
621 // Generate the columns given the sparse and dense inputs.
622 template <typename InternalType>
623 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateKeyedColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in,std::vector<int64> keys)624 GenerateKeyedColumnsFromInput(const OpInputList& indices_list_in,
625 const OpInputList& values_list_in,
626 const OpInputList& shapes_list_in,
627 const OpInputList& dense_list_in,
628 std::vector<int64> keys) {
629 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
630 const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
631 const int64 number_of_columns = shapes_list_in.size();
632
633 std::vector<std::vector<int64>> feature_counts(number_of_columns,
634 std::vector<int64>());
635 std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
636 std::vector<int64>());
637
638 ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
639 &feature_start_indices);
640
641 columns.reserve(values_list_in.size());
642 for (int i = 0; i < values_list_in.size(); ++i) {
643 columns.emplace_back(new KeyedSparseTensorColumn<InternalType>(
644 values_list_in[i], std::move(feature_counts[i]),
645 std::move(feature_start_indices[i]), keys));
646 }
647 for (int i = 0; i < dense_list_in.size(); ++i) {
648 columns.emplace_back(
649 new KeyedDenseTensorColumn<InternalType>(dense_list_in[i], keys));
650 }
651
652 return columns;
653 }
654
// Allocates output tensors with proper size and sets the shape tensor of
// the output SparseTensor.
// It also fills output_start_indices, which contains the start index of
// each batch's crosses in the output SparseTensor.
template <typename InternalType>
Status CreateOutputTensors(
    const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
    int64 batch_size, OpKernelContext* context, Tensor** indices_out,
    Tensor** values_out, Tensor** shape_out,
    std::vector<int64>* output_start_indices) {
  // Calculates dimensions for output tensors.
  int64 cross_count_total = 0;
  int64 max_cross_count = 0;
  for (int64 b = 0; b < batch_size; b++) {
    // For each input, sets starting indices in output SparseTensor
    (*output_start_indices)[b] = cross_count_total;
    const auto cross_count = CrossCountByBatchIndex(columns, b);
    max_cross_count = std::max(max_cross_count, cross_count);
    // NOTE(review): cross_count_total can overflow int64 for adversarially
    // large inputs — confirm upstream limits before relying on it.
    cross_count_total += cross_count;
  }

  // Allocates tensors: output 0 = [nnz, 2] indices, output 1 = [nnz]
  // values, output 2 = [2] dense shape.
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({cross_count_total, 2}), indices_out));
  TF_RETURN_IF_ERROR(context->allocate_output(
      1, TensorShape({cross_count_total}), values_out));
  TF_RETURN_IF_ERROR(context->allocate_output(2, TensorShape({2}), shape_out));

  // Sets shape: [batch_size, max crosses in any single batch].
  auto shape_vec = (*shape_out)->vec<int64>();
  shape_vec(0) = batch_size;
  shape_vec(1) = max_cross_count;

  return Status::OK();
}
690
// Kernel that computes the sparse cross of the input sparse and dense
// columns. HASHED_OUTPUT selects, via CrossTraits, between string
// concatenation (StringCrosser) and fingerprint hashing (HashCrosser).
template <bool HASHED_OUTPUT, typename InternalType>
class SparseCrossOp : public OpKernel {
 public:
  explicit SparseCrossOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    // Read signed_hash_key_ as int64 since uint64 attributes are not
    // supported by REGISTER_OP.
    int64 signed_hash_key_;
    OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
    hash_key_ = static_cast<uint64>(signed_hash_key_);
  }

  void Compute(OpKernelContext* context) override {
    // Gather the op's four input lists.
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
                                          shapes_list_in, dense_list_in));

    // Wrap every input in a ColumnInterface for uniform feature access.
    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
        GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in,
                                               shapes_list_in, dense_list_in);

    const tstring k_feature_separator = "_X_";
    typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser(
        columns, num_buckets_, hash_key_, k_feature_separator);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64> output_start_indices(batch_size);
    OP_REQUIRES_OK(
        context,
        CreateOutputTensors(columns, batch_size, context, &indices_out,
                            &values_out, &shape_out, &output_start_indices));

    typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
        output_start_indices, indices_out, values_out);
    // crosser/updater are captured by value, but they hold only references
    // and pointers, so the copies are cheap and share the same outputs.
    auto do_work = [&columns, crosser, updater](int64 begin, int64 end) {
      for (int b = begin; b < end; b++) {
        // Enumerate the cartesian product of features for this batch.
        ProductIterator<InternalType> product_iterator(columns, b);
        int64 cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count,
                         crosser.Generate(b, permutation, false));
          cross_count++;
        }
      }
    };

    // Shard batches across CPU worker threads; each batch writes to a
    // disjoint output range, so no synchronization is needed.
    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }

 private:
  int64 num_buckets_;  // 0 means no bucketing (modulo int64 max instead).
  uint64 hash_key_;    // Seed for the fingerprint chain.
};
760
// V2 string-cross kernel: like SparseCrossOp<false, tstring> but the join
// separator comes from the "sep" input tensor instead of being hard-coded.
class SparseCrossV2Op : public OpKernel {
 public:
  explicit SparseCrossV2Op(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Gather the op's four input lists.
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
                                          shapes_list_in, dense_list_in));

    // The separator is a scalar string input.
    const Tensor* sep_t;
    OP_REQUIRES_OK(context, context->input("sep", &sep_t));
    const tstring separator = sep_t->scalar<tstring>()();

    // Wrap every input in a ColumnInterface for uniform feature access.
    std::vector<std::unique_ptr<ColumnInterface<tstring>>> columns =
        GenerateColumnsFromInput<tstring>(indices_list_in, values_list_in,
                                          shapes_list_in, dense_list_in);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64> output_start_indices(batch_size);
    OP_REQUIRES_OK(
        context,
        CreateOutputTensors(columns, batch_size, context, &indices_out,
                            &values_out, &shape_out, &output_start_indices));
    // num_buckets/hash_key are unused by StringCrosser, hence the zeros.
    StringCrosser<tstring> crosser(columns, 0, 0, separator);
    OutputUpdater<tstring> updater(output_start_indices, indices_out,
                                   values_out);
    // crosser/updater are captured by value, but they hold only references
    // and pointers, so the copies are cheap and share the same outputs.
    auto do_work = [&columns, crosser, updater](int64 begin, int64 end) {
      for (int b = begin; b < end; b++) {
        // Enumerate the cartesian product of features for this batch.
        ProductIterator<tstring> product_iterator(columns, b);
        int64 cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count,
                         crosser.Generate(b, permutation, false));
          cross_count++;
        }
      }
    };

    // Shard batches across CPU worker threads; each batch writes to a
    // disjoint output range, so no synchronization is needed.
    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }
};
818
819 class SparseCrossHashedOp : public OpKernel {
820 public:
SparseCrossHashedOp(OpKernelConstruction * context)821 explicit SparseCrossHashedOp(OpKernelConstruction* context)
822 : OpKernel(context) {}
823
Compute(OpKernelContext * context)824 void Compute(OpKernelContext* context) override {
825 OpInputList indices_list_in;
826 OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
827 OpInputList values_list_in;
828 OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
829 OpInputList shapes_list_in;
830 OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
831 OpInputList dense_list_in;
832 OP_REQUIRES_OK(context,
833 context->input_list("dense_inputs", &dense_list_in));
834
835 OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
836 shapes_list_in, dense_list_in));
837
838 const Tensor* num_buckets_t;
839 OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t));
840 const int64 num_buckets = num_buckets_t->scalar<int64>()();
841
842 const Tensor* strong_hash_t;
843 OP_REQUIRES_OK(context, context->input("strong_hash", &strong_hash_t));
844 const bool strong_hash = strong_hash_t->scalar<bool>()();
845
846 const Tensor* salt_t;
847 OP_REQUIRES_OK(context, context->input("salt", &salt_t));
848 const auto salt = salt_t->flat<int64>();
849 std::vector<int64> key_{salt(0), salt(1)};
850
851 std::vector<std::unique_ptr<ColumnInterface<int64>>> columns =
852 GenerateKeyedColumnsFromInput<int64>(indices_list_in, values_list_in,
853 shapes_list_in, dense_list_in,
854 key_);
855 Tensor* indices_out;
856 Tensor* values_out;
857 Tensor* shape_out;
858 const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
859 std::vector<int64> output_start_indices(batch_size);
860 OP_REQUIRES_OK(
861 context,
862 CreateOutputTensors(columns, batch_size, context, &indices_out,
863 &values_out, &shape_out, &output_start_indices));
864 const tstring unused_sep;
865 HashCrosserV2 crosser(columns, num_buckets, 0, unused_sep);
866 OutputUpdater<int64> updater(output_start_indices, indices_out, values_out);
867 auto do_work = [&columns, crosser, updater, strong_hash](int64 begin,
868 int64 end) {
869 for (int b = begin; b < end; b++) {
870 ProductIterator<int64> product_iterator(columns, b);
871 int64 cross_count = 0;
872 while (product_iterator.HasNext()) {
873 const auto permutation = product_iterator.Next();
874 updater.Update(b, cross_count,
875 crosser.Generate(b, permutation, strong_hash));
876 cross_count++;
877 }
878 }
879 };
880
881 auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
882 // TODO(zakaria): optimize kCostPerUnit
883 const int kCostPerUnit = 5000 * indices_list_in.size();
884 Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
885 kCostPerUnit, do_work);
886 }
887 };
888
// CPU kernel registrations. For the legacy "SparseCross" op the kernel is
// selected by the (out_type, internal_type) attr pair and instantiated as
// SparseCrossOp<HASHED_OUTPUT, InternalType>.
//
// tstring out_type -> string crosses (HASHED_OUTPUT = false); the internal
// representation is StringPiece when inputs are strings (avoids copies) and
// tstring when inputs are int64 (values must be materialized as strings).
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type")
                            .TypeConstraint<tstring>("internal_type"),
                        SparseCrossOp<false, StringPiece>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseCrossOp<false, tstring>);

// int64 out_type -> hashed crosses (HASHED_OUTPUT = true). Both
// internal_type variants deliberately map to the same int64-internal kernel:
// hashed output always fingerprints features to int64 first, so the declared
// internal_type does not change the implementation.
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<tstring>("internal_type"),
                        SparseCrossOp<true, int64>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseCrossOp<true, int64>);

// The V2/hashed ops have no type attrs; each has a single CPU kernel.
REGISTER_KERNEL_BUILDER(Name("SparseCrossV2").Device(DEVICE_CPU),
                        SparseCrossV2Op);

REGISTER_KERNEL_BUILDER(Name("SparseCrossHashed").Device(DEVICE_CPU),
                        SparseCrossHashedOp);
918
919 } // namespace tensorflow
920