1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Contains OP to generate sparse crosses.
#include <assert.h>

#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/strong_hash.h"
#include "tensorflow/core/util/work_sharder.h"
36
37 namespace tensorflow {
38
39 namespace {
40 // An interface that represents a column with batches.
41 template <typename InternalType>
42 class ColumnInterface {
43 public:
44 // Returns the number of features in the specified batch.
45 virtual int64 FeatureCount(int64_t batch) const = 0;
46
47 // Returns the fingerprint of nth feature from the specified batch.
48 virtual InternalType Feature(int64_t batch, int64_t n,
49 bool strong_hash) const = 0;
50
~ColumnInterface()51 virtual ~ColumnInterface() {}
52 };
53
// A column that is backed by a sparse tensor.
template <typename InternalType>
class SparseTensorColumn : public ColumnInterface<InternalType> {
 public:
  // `values` is the flat values vector of the sparse tensor; it is held by
  // reference and must outlive this column. `feature_counts[b]` is the
  // number of features in batch b; `feature_start_indices[b]` is the offset
  // of batch b's first feature within `values`.
  SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
                     std::vector<int64> feature_start_indices)
      : values_(values),
        feature_counts_(std::move(feature_counts)),
        feature_start_indices_(std::move(feature_start_indices)) {
    // Both vectors are indexed by batch, so their sizes must agree.
    CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
  }

  int64 FeatureCount(int64_t batch) const override {
    return feature_counts_[batch];
  }

  InternalType Feature(int64_t batch, int64_t n,
                       bool strong_hash) const override;

  ~SparseTensorColumn() override {}

 private:
  const Tensor& values_;  // Not owned.
  std::vector<int64> feature_counts_;
  std::vector<int64> feature_start_indices_;
};
80
81 // A column that is backed by a sparse tensor.
82 template <typename InternalType>
83 class KeyedSparseTensorColumn : public ColumnInterface<InternalType> {
84 public:
KeyedSparseTensorColumn(const Tensor & values,std::vector<int64> feature_counts,std::vector<int64> feature_start_indices,std::vector<int64> key)85 KeyedSparseTensorColumn(const Tensor& values,
86 std::vector<int64> feature_counts,
87 std::vector<int64> feature_start_indices,
88 std::vector<int64> key)
89 : values_(values),
90 feature_counts_(std::move(feature_counts)),
91 feature_start_indices_(std::move(feature_start_indices)) {
92 DCHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
93 std::memcpy(key_, key.data(), sizeof(key_));
94 }
95
FeatureCount(int64_t batch) const96 int64 FeatureCount(int64_t batch) const override {
97 return feature_counts_[batch];
98 }
99
100 InternalType Feature(int64_t batch, int64_t n,
101 bool strong_hash) const override;
102
~KeyedSparseTensorColumn()103 ~KeyedSparseTensorColumn() override {}
104
105 private:
106 const Tensor& values_;
107 tensorflow::uint64 key_[2];
108 std::vector<int64> feature_counts_;
109 std::vector<int64> feature_start_indices_;
110 };
111
112 // InternalType is int64 only when using HashCrosser.
113 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const114 int64 SparseTensorColumn<int64>::Feature(int64_t batch, int64_t n,
115 bool strong_hash) const {
116 const int64_t start = feature_start_indices_[batch];
117 if (DT_STRING == values_.dtype())
118 return Fingerprint64(values_.vec<tstring>().data()[start + n]);
119 return values_.vec<int64>().data()[start + n];
120 }
121
// Keyed variant: uses StrongKeyedHash with key_ when strong_hash is set,
// otherwise falls back to the unkeyed Fingerprint64.
template <>
int64 KeyedSparseTensorColumn<int64>::Feature(int64_t batch, int64_t n,
                                              bool strong_hash) const {
  const int64_t start = feature_start_indices_[batch];
  if (strong_hash) {
    if (DT_STRING == values_.dtype()) {
      return StrongKeyedHash(key_, values_.vec<tstring>()(start + n));
    }
    // NOTE(review): sizeof(values_.dtype()) is the size of the DataType
    // enum, not of the int64 value, so only that many bytes of the value
    // are hashed. Changing it would alter existing hash outputs -- confirm
    // intent before "fixing".
    return StrongKeyedHash(
        key_, {reinterpret_cast<const char*>(&values_.vec<int64>()(start + n)),
               sizeof(values_.dtype())});
  }
  if (DT_STRING == values_.dtype())
    return Fingerprint64(values_.vec<tstring>()(start + n));
  return Fingerprint64(
      {reinterpret_cast<const char*>(&values_.vec<int64>()(start + n)),
       sizeof(values_.dtype())});
}
140
141 // InternalType is string or StringPiece when using StringCrosser.
142 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const143 tstring SparseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
144 bool strong_hash) const {
145 const int64_t start = feature_start_indices_[batch];
146 if (DT_STRING == values_.dtype())
147 return values_.vec<tstring>().data()[start + n];
148 return std::to_string(values_.vec<int64>().data()[start + n]);
149 }
150
151 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const152 tstring KeyedSparseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
153 bool strong_hash) const {
154 const int64_t start = feature_start_indices_[batch];
155 if (DT_STRING == values_.dtype())
156 return values_.vec<tstring>().data()[start + n];
157 return std::to_string(values_.vec<int64>().data()[start + n]);
158 }
159
160 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const161 StringPiece SparseTensorColumn<StringPiece>::Feature(int64_t batch, int64_t n,
162 bool strong_hash) const {
163 const int64_t start = feature_start_indices_[batch];
164 return values_.vec<tstring>().data()[start + n];
165 }
166
167 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const168 StringPiece KeyedSparseTensorColumn<StringPiece>::Feature(
169 int64_t batch, int64_t n, bool strong_hash) const {
170 const int64_t start = feature_start_indices_[batch];
171 return values_.vec<tstring>().data()[start + n];
172 }
173
// A column that is backed by a dense tensor.
template <typename InternalType>
class DenseTensorColumn : public ColumnInterface<InternalType> {
 public:
  // `tensor` is a rank-2 tensor (batch x features); held by reference and
  // must outlive this column.
  explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}

  // Dense tensors have the same number of features in every batch: dim 1.
  int64 FeatureCount(int64_t batch) const override {
    return tensor_.dim_size(1);
  }

  InternalType Feature(int64_t batch, int64_t n,
                       bool strong_hash) const override;

  ~DenseTensorColumn() override {}

 private:
  const Tensor& tensor_;  // Not owned.
};
192
193 // A column that is backed by a dense tensor.
194 template <typename InternalType>
195 class KeyedDenseTensorColumn : public ColumnInterface<InternalType> {
196 public:
KeyedDenseTensorColumn(const Tensor & tensor,std::vector<int64> key)197 explicit KeyedDenseTensorColumn(const Tensor& tensor, std::vector<int64> key)
198 : tensor_(tensor) {
199 std::memcpy(key_, key.data(), sizeof(key_));
200 }
201
FeatureCount(int64_t batch) const202 int64 FeatureCount(int64_t batch) const override {
203 return tensor_.dim_size(1);
204 }
205
206 InternalType Feature(int64_t batch, int64_t n,
207 bool strong_hash) const override;
208
~KeyedDenseTensorColumn()209 ~KeyedDenseTensorColumn() override {}
210
211 private:
212 const Tensor& tensor_;
213 tensorflow::uint64 key_[2];
214 };
215
216 // InternalType is int64 only when using HashCrosser.
217 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const218 int64 DenseTensorColumn<int64>::Feature(int64_t batch, int64_t n,
219 bool strong_hash) const {
220 if (DT_STRING == tensor_.dtype())
221 return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
222 return tensor_.matrix<int64>()(batch, n);
223 }
224
225 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const226 int64 KeyedDenseTensorColumn<int64>::Feature(int64_t batch, int64_t n,
227 bool strong_hash) const {
228 if (strong_hash) {
229 if (DT_STRING == tensor_.dtype()) {
230 return StrongKeyedHash(key_, tensor_.matrix<tstring>()(batch, n));
231 }
232 return StrongKeyedHash(
233 key_, {reinterpret_cast<const char*>(tensor_.matrix<int64>()(batch, n)),
234 sizeof(tensor_.dtype())});
235 }
236 if (DT_STRING == tensor_.dtype())
237 return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
238 return tensor_.matrix<int64>()(batch, n);
239 }
240
241 // Internal type is string or StringPiece when using StringCrosser.
242 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const243 tstring DenseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
244 bool strong_hash) const {
245 if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
246 return std::to_string(tensor_.matrix<int64>()(batch, n));
247 }
248
249 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const250 tstring KeyedDenseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
251 bool strong_hash) const {
252 if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
253 return std::to_string(tensor_.matrix<int64>()(batch, n));
254 }
255
256 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const257 StringPiece DenseTensorColumn<StringPiece>::Feature(int64_t batch, int64_t n,
258 bool strong_hash) const {
259 return tensor_.matrix<tstring>()(batch, n);
260 }
261
262 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const263 StringPiece KeyedDenseTensorColumn<StringPiece>::Feature(
264 int64_t batch, int64_t n, bool strong_hash) const {
265 return tensor_.matrix<tstring>()(batch, n);
266 }
267
268 // Updates Output tensors with sparse crosses.
269 template <typename OutType>
270 class OutputUpdater {
271 public:
OutputUpdater(const std::vector<int64> & output_start_indices,Tensor * indices_out,Tensor * values_out)272 OutputUpdater(const std::vector<int64>& output_start_indices,
273 Tensor* indices_out, Tensor* values_out)
274 : output_start_indices_(output_start_indices),
275 indices_out_(indices_out),
276 values_out_(values_out) {}
277
Update(const int64_t batch_index,const int64_t cross_count,const OutType & cross) const278 void Update(const int64_t batch_index, const int64_t cross_count,
279 const OutType& cross) const {
280 const int64_t output_index =
281 output_start_indices_[batch_index] + cross_count;
282
283 auto indices_matrix = indices_out_->matrix<int64>();
284 indices_matrix(output_index, 0) = batch_index;
285 indices_matrix(output_index, 1) = cross_count;
286
287 auto value_vec = values_out_->vec<OutType>();
288 value_vec(output_index) = cross;
289 }
290
291 private:
292 const std::vector<int64>& output_start_indices_;
293 Tensor* indices_out_;
294 Tensor* values_out_;
295 };
296
297 // Generates the sparse crosses as concatenation of strings.
298 template <typename InternalType>
299 class StringCrosser {
300 public:
StringCrosser(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,const int64_t num_buckets_unused,const uint64 hash_key_unused,const tstring k_feature_separator)301 StringCrosser(const std::vector<
302 std::unique_ptr<ColumnInterface<InternalType>>>& columns,
303 const int64_t num_buckets_unused, const uint64 hash_key_unused,
304 const tstring k_feature_separator)
305 : columns_(columns), k_feature_separator_(k_feature_separator) {}
306
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool unused_strong_hash) const307 string Generate(const int64_t batch_index,
308 const std::vector<int>& permutation,
309 bool unused_strong_hash) const {
310 gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
311 for (int i = 0; i < permutation.size(); i++) {
312 cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i], false);
313 }
314 // TODO(zakaria): this will copy the string twice, might effect
315 // performance.
316 return absl::StrJoin(cross_vec, k_feature_separator_);
317 }
318
319 private:
320 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
321 const tstring k_feature_separator_;
322 };
323
324 // Generates the sparse crosses as nested hash to avoid string manipulations.
325 class HashCrosser {
326 public:
HashCrosser(const std::vector<std::unique_ptr<ColumnInterface<int64>>> & columns,const int64_t num_buckets,const uint64 hash_key,const tstring k_feature_separator_unused)327 HashCrosser(
328 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
329 const int64_t num_buckets, const uint64 hash_key,
330 const tstring k_feature_separator_unused)
331 : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}
332
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool unused_strong_hash) const333 int64 Generate(const int64_t batch_index, const std::vector<int>& permutation,
334 bool unused_strong_hash) const {
335 // Do the fingerprint concatenation on uint64.
336 uint64 hashed_output = hash_key_;
337 for (size_t i = 0; i < permutation.size(); ++i) {
338 uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i], false);
339 hashed_output = FingerprintCat64(hashed_output, hash_i);
340 }
341 // The return value is int64 based on the number of buckets.
342 if (num_buckets_ > 0) {
343 return hashed_output % num_buckets_;
344 } else {
345 // To prevent negative output we take modulo to max int64.
346 return hashed_output % std::numeric_limits<int64>::max();
347 }
348 }
349
350 private:
351 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
352 const int64 num_buckets_;
353 const uint64 hash_key_;
354 };
355
356 // Generates the sparse crosses as nested hash to avoid string manipulations.
357 class HashCrosserV2 {
358 public:
HashCrosserV2(const std::vector<std::unique_ptr<ColumnInterface<int64>>> & columns,const int64_t num_buckets,const uint64 hash_key_unused,const tstring k_feature_separator_unused)359 HashCrosserV2(
360 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
361 const int64_t num_buckets, const uint64 hash_key_unused,
362 const tstring k_feature_separator_unused)
363 : columns_(columns), num_buckets_(num_buckets) {}
364
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool strong_hash) const365 int64 Generate(const int64_t batch_index, const std::vector<int>& permutation,
366 bool strong_hash) const {
367 // Do the fingerprint concatenation on uint64.
368 uint64 hashed_output =
369 columns_[0]->Feature(batch_index, permutation[0], strong_hash);
370 for (size_t i = 1; i < permutation.size(); ++i) {
371 uint64 hash_i =
372 columns_[i]->Feature(batch_index, permutation[i], strong_hash);
373 hashed_output = FingerprintCat64(hashed_output, hash_i);
374 }
375 // The return value is int64 based on the number of buckets.
376 if (num_buckets_ > 0) {
377 return hashed_output % num_buckets_;
378 } else {
379 // To prevent negative output we take modulo to max int64.
380 return hashed_output % std::numeric_limits<int64>::max();
381 }
382 }
383
384 private:
385 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
386 const int64 num_buckets_;
387 };
388
// ProductIterator generates cartesian products based on indices.
template <typename InternalType>
class ProductIterator {
 public:
  explicit ProductIterator(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64_t batch_index)
      : columns_(columns), batch_index_(batch_index) {
    // next_permutation_[i] is the feature index currently selected from
    // column i; the iterator behaves like an odometer over these digits.
    next_permutation_.resize(columns_.size(), 0);
    // Sets has_next_ to false if any feature column has 0 features.
    has_next_ = true;
    for (int i = 0; i < columns_.size(); i++) {
      if (columns_[i]->FeatureCount(batch_index_) == 0) {
        has_next_ = false;
        break;
      }
    }
  }

  // Returns the current permutation and advances to the next one.
  // Must only be called while HasNext() is true.
  std::vector<int> Next() {
    std::vector<int> permutation(next_permutation_);

    // Generates next permutation, if available.
    // Odometer increment: bump the last digit; any digit that wraps to its
    // column's FeatureCount resets to 0 and carries into the digit to its
    // left. A carry out of the leftmost digit means we are exhausted.
    bool carry = true;
    for (int i = next_permutation_.size() - 1; i >= 0; i--) {
      if (carry) {
        next_permutation_[i] = next_permutation_[i] + 1;
      }
      if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
        next_permutation_[i] = 0;
      } else {
        carry = false;
        break;
      }
    }
    has_next_ = !carry;
    return permutation;
  }

  // True while at least one more permutation remains.
  bool HasNext() { return has_next_; }

 private:
  bool has_next_;
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
  const int64 batch_index_;
  std::vector<int> next_permutation_;
};
437
// Maps the HASHED_OUTPUT template flag to the crosser and updater types the
// kernels instantiate.
template <bool HASHED_OUTPUT, typename InternalType>
struct CrossTraits;

// Non-hashed crosses concatenate features as strings.
template <typename InternalType>
struct CrossTraits<false, InternalType> {
  typedef StringCrosser<InternalType> Crosser;
  typedef StringCrosser<InternalType> CrosserV2;
  typedef OutputUpdater<tstring> Updater;
};

// Hashed crosses fingerprint features into int64 values.
template <>
struct CrossTraits<true, int64> {
  typedef HashCrosser Crosser;
  typedef HashCrosserV2 CrosserV2;
  typedef OutputUpdater<int64> Updater;
};
454 } // namespace
455
456 // Calculate the batch size from either the shapes input or the dense input.
CalculateBatchSize(const OpInputList & shapes_list_in,const OpInputList & dense_list_in)457 int64 CalculateBatchSize(const OpInputList& shapes_list_in,
458 const OpInputList& dense_list_in) {
459 if (shapes_list_in.size() > 0) {
460 return shapes_list_in[0].vec<int64>()(0);
461 }
462
463 if (dense_list_in.size() > 0) {
464 return dense_list_in[0].dim_size(0);
465 }
466
467 return 0;
468 }
469
// Validates input tensors.
// Checks, in order: the indices list (dtype, rank-2 shape, inner dim == 2),
// the values list (dtype compatibility, rank-1 shape, size matching its
// indices), the shapes list (dtype, rank-1, exactly 2 entries), the dense
// list (dtype compatibility, rank-2), and finally that every input agrees
// on the batch size. Returns the first InvalidArgument found, or OK.
Status ValidateInput(const OpInputList& indices_list_in,
                     const OpInputList& values_list_in,
                     const OpInputList& shapes_list_in,
                     const OpInputList& dense_list_in,
                     const DataType& internal_type) {
  const auto size = indices_list_in.size();
  // Only perform internal_type check for SparseCrossOp.
  // Check if the internal_type is not invalid before doing so.
  bool check_type = internal_type != DT_INVALID;
  // Validates indices_list_in OpInputList.
  for (int i = 0; i < size; i++) {
    if (check_type && indices_list_in[i].dtype() != DT_INT64) {
      return errors::InvalidArgument("Input indices should be of type ",
                                     DT_INT64, " but received ",
                                     indices_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input indices should be a matrix but received shape ",
          indices_list_in[i].shape().DebugString(), " at position ", i);
    }
    if (indices_list_in[i].shape().dim_size(1) != 2) {
      return errors::InvalidArgument("Expected D2 of index to be 2 got ",
                                     indices_list_in[i].shape().dim_size(1),
                                     " at position ", i);
    }
  }

  // Validates values_list_in OpInputList.
  if (values_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input values, got ",
                                   values_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    // Make sure to avoid the expected type to be string, but input values to be
    // int64.
    if (check_type && internal_type == DT_STRING &&
        values_list_in[i].dtype() == DT_INT64) {
      return errors::InvalidArgument("Input values should be of internal type ",
                                     internal_type, " but received ",
                                     values_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input values should be a vector but received shape ",
          values_list_in[i].shape().DebugString(), " at position ", i);
    }
    // Each sparse value must have a matching index row.
    if (indices_list_in[i].shape().dim_size(0) !=
        values_list_in[i].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Expected size of values to be ",
          indices_list_in[i].shape().dim_size(0), " got ",
          values_list_in[i].shape().dim_size(0), " at position ", i);
    }
  }

  // Validates shapes_list_in OpInputList
  if (shapes_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input shapes, got ",
                                   shapes_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    if (check_type && shapes_list_in[i].dtype() != DT_INT64) {
      return errors::InvalidArgument("Input shape should be of type ", DT_INT64,
                                     " but received ",
                                     shapes_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input shapes should be a vector but received shape ",
          shapes_list_in[i].shape().DebugString(), " at position ", i);
    }

    if (shapes_list_in[i].vec<int64>().size() != 2) {
      return errors::InvalidArgument("shape should imply a 2D tensor, but got ",
                                     shapes_list_in[i].shape().DebugString(),
                                     " at position ", i);
    }
  }

  // Validates dense_list_in OpInputList
  for (int i = 0; i < dense_list_in.size(); ++i) {
    // Make sure to avoid the expected type to be string, but input values to be
    // int64.
    if (check_type && internal_type == DT_STRING &&
        dense_list_in[i].dtype() == DT_INT64) {
      return errors::InvalidArgument("Dense inputs should be of internal type ",
                                     internal_type, " but received ",
                                     dense_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Dense inputs should be a matrix but received shape ",
          dense_list_in[i].shape().DebugString(), " at position ", i);
    }
  }

  // Validates batch sizes. (Note: we do this after validating the input
  // shapes, because CalculateBatchSize() depends on inputs having valid
  // shapes).
  const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
  for (int i = 0; i < size; i++) {
    if (shapes_list_in[i].vec<int64>()(0) != batch_size) {
      return errors::InvalidArgument("Expected batch size ", batch_size,
                                     " got ", shapes_list_in[i].vec<int64>()(0),
                                     " at position ", i);
    }
  }
  for (int i = 0; i < dense_list_in.size(); ++i) {
    if (dense_list_in[i].dim_size(0) != batch_size) {
      return errors::InvalidArgument("Expected batch size ", batch_size,
                                     " got ", dense_list_in[i].dim_size(0),
                                     " at dense tensor ", i);
    }
  }

  return Status::OK();
}
589
590 // Extracts data about the features and populates feature data.
ExtractFeatureData(const OpInputList & indices_list_in,int64_t batch_size,std::vector<std::vector<int64>> * feature_counts,std::vector<std::vector<int64>> * feature_start_indices)591 void ExtractFeatureData(
592 const OpInputList& indices_list_in, int64_t batch_size,
593 std::vector<std::vector<int64>>* feature_counts,
594 std::vector<std::vector<int64>>* feature_start_indices) {
595 gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0);
596 for (int b = 0; b < batch_size; b++) {
597 for (int i = 0; i < indices_list_in.size(); i++) {
598 const auto indices = indices_list_in[i].matrix<int64>();
599 int64_t feature_count = 0;
600 int64_t start_index = current_row[i];
601 // Loops until we reach next batch index for current feature column.
602 while (current_row[i] < indices_list_in[i].dim_size(0) &&
603 indices(current_row[i], 0) == b) {
604 feature_count++;
605 current_row[i]++;
606 }
607 (*feature_counts)[i].push_back(feature_count);
608 (*feature_start_indices)[i].push_back(start_index);
609 }
610 }
611 }
612
// Returns number of crosses for a given batch_index
// (the product of the per-column feature counts, or 0 if any column is
// empty for that batch).
template <typename InternalType>
int64 CrossCountByBatchIndex(
    const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
    int batch_index) {
  int64_t cross_count = 1;
  for (int i = 0; i < columns.size(); i++) {
    const auto feature_count = columns[i]->FeatureCount(batch_index);
    // If one column is missing any feature, there won't be any cross.
    if (feature_count == 0) {
      return 0;
    }
    // NOTE(review): the running product is not checked for int64 overflow;
    // presumably input sizes are bounded upstream -- confirm.
    cross_count *= feature_count;
  }
  return cross_count;
}
629
630 // Generate the columns given the sparse and dense inputs.
631 template <typename InternalType>
632 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in)633 GenerateColumnsFromInput(const OpInputList& indices_list_in,
634 const OpInputList& values_list_in,
635 const OpInputList& shapes_list_in,
636 const OpInputList& dense_list_in) {
637 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
638 const int64_t batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
639 const int64_t number_of_columns = shapes_list_in.size();
640
641 std::vector<std::vector<int64>> feature_counts(number_of_columns,
642 std::vector<int64>());
643 std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
644 std::vector<int64>());
645
646 ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
647 &feature_start_indices);
648
649 columns.reserve(values_list_in.size());
650 for (int i = 0; i < values_list_in.size(); ++i) {
651 columns.emplace_back(new SparseTensorColumn<InternalType>(
652 values_list_in[i], std::move(feature_counts[i]),
653 std::move(feature_start_indices[i])));
654 }
655 for (int i = 0; i < dense_list_in.size(); ++i) {
656 columns.emplace_back(new DenseTensorColumn<InternalType>(dense_list_in[i]));
657 }
658
659 return columns;
660 }
661
662 // Generate the columns given the sparse and dense inputs.
663 template <typename InternalType>
664 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateKeyedColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in,std::vector<int64> keys)665 GenerateKeyedColumnsFromInput(const OpInputList& indices_list_in,
666 const OpInputList& values_list_in,
667 const OpInputList& shapes_list_in,
668 const OpInputList& dense_list_in,
669 std::vector<int64> keys) {
670 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
671 const int64_t batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
672 const int64_t number_of_columns = shapes_list_in.size();
673
674 std::vector<std::vector<int64>> feature_counts(number_of_columns,
675 std::vector<int64>());
676 std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
677 std::vector<int64>());
678
679 ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
680 &feature_start_indices);
681
682 columns.reserve(values_list_in.size());
683 for (int i = 0; i < values_list_in.size(); ++i) {
684 columns.emplace_back(new KeyedSparseTensorColumn<InternalType>(
685 values_list_in[i], std::move(feature_counts[i]),
686 std::move(feature_start_indices[i]), keys));
687 }
688 for (int i = 0; i < dense_list_in.size(); ++i) {
689 columns.emplace_back(
690 new KeyedDenseTensorColumn<InternalType>(dense_list_in[i], keys));
691 }
692
693 return columns;
694 }
695
// Allocates output tensors with proper size and sets the shape tensor of
// the output SparseTensor.
// It also output_start_indices which contains the start indices for each
// input in the output SparseTensor.
template <typename InternalType>
Status CreateOutputTensors(
    const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
    int64_t batch_size, OpKernelContext* context, Tensor** indices_out,
    Tensor** values_out, Tensor** shape_out,
    std::vector<int64>* output_start_indices) {
  // Calculates dimensions for output tensors.
  // NOTE(review): cross_count_total accumulates without an overflow check;
  // presumably bounded by upstream input validation -- confirm.
  int64_t cross_count_total = 0;
  int64_t max_cross_count = 0;
  for (int64_t b = 0; b < batch_size; b++) {
    // For each input, sets starting indices in output SparseTensor
    (*output_start_indices)[b] = cross_count_total;
    const auto cross_count = CrossCountByBatchIndex(columns, b);
    max_cross_count = std::max(max_cross_count, cross_count);
    cross_count_total += cross_count;
  }

  // Allocates tensors.
  // Output 0: indices matrix (one [batch, position] row per cross).
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({cross_count_total, 2}), indices_out));
  // Output 1: flat values vector, one entry per cross.
  TF_RETURN_IF_ERROR(context->allocate_output(
      1, TensorShape({cross_count_total}), values_out));
  // Output 2: dense shape of the output SparseTensor.
  TF_RETURN_IF_ERROR(context->allocate_output(2, TensorShape({2}), shape_out));

  // Sets shape.
  auto shape_vec = (*shape_out)->vec<int64>();
  shape_vec(0) = batch_size;
  shape_vec(1) = max_cross_count;

  return Status::OK();
}
731
// Kernel that computes sparse feature crosses. HASHED_OUTPUT selects the
// fingerprinting crosser (int64 output) vs. string concatenation; see
// CrossTraits.
template <bool HASHED_OUTPUT, typename InternalType>
class SparseCrossOp : public OpKernel {
 public:
  explicit SparseCrossOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    // Read signed_hash_key_ as int64 since uint64 attributes are not
    // supported by REGISTER_OP.
    int64_t signed_hash_key_;
    OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
    hash_key_ = static_cast<uint64>(signed_hash_key_);
    OP_REQUIRES_OK(context, context->GetAttr("internal_type", &internal_type_));
  }

  void Compute(OpKernelContext* context) override {
    // Gather the four input lists: sparse indices/values/shapes plus dense
    // tensors.
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    DataType internal_type = internal_type_;
    OP_REQUIRES_OK(
        context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
                               dense_list_in, internal_type));

    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
        GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in,
                                               shapes_list_in, dense_list_in);

    const tstring k_feature_separator = "_X_";
    typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser(
        columns, num_buckets_, hash_key_, k_feature_separator);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64_t batch_size =
        CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64> output_start_indices(batch_size);
    OP_REQUIRES_OK(
        context,
        CreateOutputTensors(columns, batch_size, context, &indices_out,
                            &values_out, &shape_out, &output_start_indices));

    typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
        output_start_indices, indices_out, values_out);
    // crosser/updater are captured by copy so each shard has its own;
    // columns is shared by reference (read-only here).
    auto do_work = [&columns, crosser, updater](int64_t begin, int64_t end) {
      for (int b = begin; b < end; b++) {
        // Enumerate the cartesian product of this batch's features and
        // write each cross into its reserved output row.
        ProductIterator<InternalType> product_iterator(columns, b);
        int64_t cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count,
                         crosser.Generate(b, permutation, false));
          cross_count++;
        }
      }
    };

    // Shard the batches across the CPU worker threads.
    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }

 private:
  int64 num_buckets_;   // 0 means no bucketing (modulo max int64 instead).
  uint64 hash_key_;     // Seed for FingerprintCat64-based crossing.
  DataType internal_type_;
};
806
807 class SparseCrossV2Op : public OpKernel {
808 public:
SparseCrossV2Op(OpKernelConstruction * context)809 explicit SparseCrossV2Op(OpKernelConstruction* context) : OpKernel(context) {}
810
Compute(OpKernelContext * context)811 void Compute(OpKernelContext* context) override {
812 OpInputList indices_list_in;
813 OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
814 OpInputList values_list_in;
815 OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
816 OpInputList shapes_list_in;
817 OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
818 OpInputList dense_list_in;
819 OP_REQUIRES_OK(context,
820 context->input_list("dense_inputs", &dense_list_in));
821
822 // Set internal_type to invalid_type so that the check will be ignored.
823 DataType internal_type = DT_INVALID;
824 OP_REQUIRES_OK(
825 context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
826 dense_list_in, internal_type));
827
828 const Tensor* sep_t;
829 OP_REQUIRES_OK(context, context->input("sep", &sep_t));
830 const tstring separator = sep_t->scalar<tstring>()();
831
832 std::vector<std::unique_ptr<ColumnInterface<tstring>>> columns =
833 GenerateColumnsFromInput<tstring>(indices_list_in, values_list_in,
834 shapes_list_in, dense_list_in);
835 Tensor* indices_out;
836 Tensor* values_out;
837 Tensor* shape_out;
838 const int64_t batch_size =
839 CalculateBatchSize(shapes_list_in, dense_list_in);
840 std::vector<int64> output_start_indices(batch_size);
841 OP_REQUIRES_OK(
842 context,
843 CreateOutputTensors(columns, batch_size, context, &indices_out,
844 &values_out, &shape_out, &output_start_indices));
845 StringCrosser<tstring> crosser(columns, 0, 0, separator);
846 OutputUpdater<tstring> updater(output_start_indices, indices_out,
847 values_out);
848 auto do_work = [&columns, crosser, updater](int64_t begin, int64_t end) {
849 for (int b = begin; b < end; b++) {
850 ProductIterator<tstring> product_iterator(columns, b);
851 int64_t cross_count = 0;
852 while (product_iterator.HasNext()) {
853 const auto permutation = product_iterator.Next();
854 updater.Update(b, cross_count,
855 crosser.Generate(b, permutation, false));
856 cross_count++;
857 }
858 }
859 };
860
861 auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
862 // TODO(zakaria): optimize kCostPerUnit
863 const int kCostPerUnit = 5000 * indices_list_in.size();
864 Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
865 kCostPerUnit, do_work);
866 }
867 };
868
869 class SparseCrossHashedOp : public OpKernel {
870 public:
SparseCrossHashedOp(OpKernelConstruction * context)871 explicit SparseCrossHashedOp(OpKernelConstruction* context)
872 : OpKernel(context) {}
873
Compute(OpKernelContext * context)874 void Compute(OpKernelContext* context) override {
875 OpInputList indices_list_in;
876 OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
877 OpInputList values_list_in;
878 OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
879 OpInputList shapes_list_in;
880 OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
881 OpInputList dense_list_in;
882 OP_REQUIRES_OK(context,
883 context->input_list("dense_inputs", &dense_list_in));
884
885 // Set internal_type to invalid_type so that the check will be ignored.
886 DataType internal_type = DT_INVALID;
887 OP_REQUIRES_OK(
888 context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
889 dense_list_in, internal_type));
890
891 const Tensor* num_buckets_t;
892 OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t));
893 const int64_t num_buckets = num_buckets_t->scalar<int64>()();
894
895 const Tensor* strong_hash_t;
896 OP_REQUIRES_OK(context, context->input("strong_hash", &strong_hash_t));
897 const bool strong_hash = strong_hash_t->scalar<bool>()();
898
899 const Tensor* salt_t;
900 OP_REQUIRES_OK(context, context->input("salt", &salt_t));
901 const auto salt = salt_t->flat<int64>();
902 std::vector<int64> key_{salt(0), salt(1)};
903
904 std::vector<std::unique_ptr<ColumnInterface<int64>>> columns =
905 GenerateKeyedColumnsFromInput<int64>(indices_list_in, values_list_in,
906 shapes_list_in, dense_list_in,
907 key_);
908 Tensor* indices_out;
909 Tensor* values_out;
910 Tensor* shape_out;
911 const int64_t batch_size =
912 CalculateBatchSize(shapes_list_in, dense_list_in);
913 std::vector<int64> output_start_indices(batch_size);
914 OP_REQUIRES_OK(
915 context,
916 CreateOutputTensors(columns, batch_size, context, &indices_out,
917 &values_out, &shape_out, &output_start_indices));
918 const tstring unused_sep;
919 HashCrosserV2 crosser(columns, num_buckets, 0, unused_sep);
920 OutputUpdater<int64> updater(output_start_indices, indices_out, values_out);
921 auto do_work = [&columns, crosser, updater, strong_hash](int64_t begin,
922 int64_t end) {
923 for (int b = begin; b < end; b++) {
924 ProductIterator<int64> product_iterator(columns, b);
925 int64_t cross_count = 0;
926 while (product_iterator.HasNext()) {
927 const auto permutation = product_iterator.Next();
928 updater.Update(b, cross_count,
929 crosser.Generate(b, permutation, strong_hash));
930 cross_count++;
931 }
932 }
933 };
934
935 auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
936 // TODO(zakaria): optimize kCostPerUnit
937 const int kCostPerUnit = 5000 * indices_list_in.size();
938 Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
939 kCostPerUnit, do_work);
940 }
941 };
942
// CPU kernel registrations. For the original SparseCross op the template
// arguments are selected by the out_type/internal_type attr constraints:
// HASHED_OUTPUT == (out_type is int64), and InternalType is StringPiece
// for string internals (zero-copy views) or tstring when values must be
// materialized.

// String output, string internal values: cross via string concatenation.
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type")
                            .TypeConstraint<tstring>("internal_type"),
                        SparseCrossOp<false, StringPiece>);

// String output, int64 internal values: numbers are stringified, so the
// materializing tstring internal type is used rather than StringPiece.
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseCrossOp<false, tstring>);

// Hashed (int64) output: both internal types share the int64 fingerprint
// pipeline, hence the identical SparseCrossOp<true, int64> instantiation.
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<tstring>("internal_type"),
                        SparseCrossOp<true, int64>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseCrossOp<true, int64>);

// V2 (runtime separator) and hashed-V2 variants have no type attrs; a single
// kernel each serves all inputs.
REGISTER_KERNEL_BUILDER(Name("SparseCrossV2").Device(DEVICE_CPU),
                        SparseCrossV2Op);

REGISTER_KERNEL_BUILDER(Name("SparseCrossHashed").Device(DEVICE_CPU),
                        SparseCrossHashedOp);
972
973 } // namespace tensorflow
974