1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// Contains the kernels that generate sparse feature crosses.
#include <assert.h>

#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/strong_hash.h"
#include "tensorflow/core/util/work_sharder.h"
38
39 namespace tensorflow {
40
41 namespace {
// An interface that represents a column with batches. Implementations expose
// per-batch feature counts and per-feature values (or fingerprints).
template <typename InternalType>
class ColumnInterface {
 public:
  // Returns the number of features in the specified batch.
  virtual int64_t FeatureCount(int64_t batch) const = 0;

  // Returns the fingerprint of nth feature from the specified batch.
  // `strong_hash` selects keyed (strong) hashing where supported.
  virtual InternalType Feature(int64_t batch, int64_t n,
                               bool strong_hash) const = 0;

  virtual ~ColumnInterface() = default;
};
55
56 // A column that is backed by a sparse tensor.
57 template <typename InternalType>
58 class SparseTensorColumn : public ColumnInterface<InternalType> {
59 public:
SparseTensorColumn(const Tensor & values,std::vector<int64_t> feature_counts,std::vector<int64_t> feature_start_indices)60 SparseTensorColumn(const Tensor& values, std::vector<int64_t> feature_counts,
61 std::vector<int64_t> feature_start_indices)
62 : values_(values),
63 feature_counts_(std::move(feature_counts)),
64 feature_start_indices_(std::move(feature_start_indices)) {
65 CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
66 }
67
FeatureCount(int64_t batch) const68 int64_t FeatureCount(int64_t batch) const override {
69 return feature_counts_[batch];
70 }
71
72 InternalType Feature(int64_t batch, int64_t n,
73 bool strong_hash) const override;
74
~SparseTensorColumn()75 ~SparseTensorColumn() override {}
76
77 private:
78 const Tensor& values_;
79 std::vector<int64_t> feature_counts_;
80 std::vector<int64_t> feature_start_indices_;
81 };
82
83 // A column that is backed by a sparse tensor.
84 template <typename InternalType>
85 class KeyedSparseTensorColumn : public ColumnInterface<InternalType> {
86 public:
KeyedSparseTensorColumn(const Tensor & values,std::vector<int64_t> feature_counts,std::vector<int64_t> feature_start_indices,std::vector<int64_t> key)87 KeyedSparseTensorColumn(const Tensor& values,
88 std::vector<int64_t> feature_counts,
89 std::vector<int64_t> feature_start_indices,
90 std::vector<int64_t> key)
91 : values_(values),
92 feature_counts_(std::move(feature_counts)),
93 feature_start_indices_(std::move(feature_start_indices)) {
94 DCHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
95 std::memcpy(key_, key.data(), sizeof(key_));
96 }
97
FeatureCount(int64_t batch) const98 int64_t FeatureCount(int64_t batch) const override {
99 return feature_counts_[batch];
100 }
101
102 InternalType Feature(int64_t batch, int64_t n,
103 bool strong_hash) const override;
104
~KeyedSparseTensorColumn()105 ~KeyedSparseTensorColumn() override {}
106
107 private:
108 const Tensor& values_;
109 tensorflow::uint64 key_[2];
110 std::vector<int64_t> feature_counts_;
111 std::vector<int64_t> feature_start_indices_;
112 };
113
114 // InternalType is int64 only when using HashCrosser.
115 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const116 int64_t SparseTensorColumn<int64_t>::Feature(int64_t batch, int64_t n,
117 bool strong_hash) const {
118 const int64_t start = feature_start_indices_[batch];
119 if (DT_STRING == values_.dtype())
120 return Fingerprint64(values_.vec<tstring>().data()[start + n]);
121 return values_.vec<int64_t>().data()[start + n];
122 }
123
// Keyed int64 feature: with strong_hash, values are hashed with the 128-bit
// key; otherwise they fall back to Fingerprint64 like the unkeyed column.
template <>
int64_t KeyedSparseTensorColumn<int64_t>::Feature(int64_t batch, int64_t n,
                                                  bool strong_hash) const {
  const int64_t start = feature_start_indices_[batch];
  if (strong_hash) {
    if (DT_STRING == values_.dtype()) {
      return StrongKeyedHash(key_, values_.vec<tstring>()(start + n));
    }
    // NOTE(review): sizeof(values_.dtype()) is the size of the DataType enum,
    // not of the int64 element, so only that many leading bytes are hashed.
    // Changing this would alter every produced hash, so it is left as is.
    return StrongKeyedHash(
        key_,
        {reinterpret_cast<const char*>(&values_.vec<int64_t>()(start + n)),
         sizeof(values_.dtype())});
  }
  if (DT_STRING == values_.dtype())
    return Fingerprint64(values_.vec<tstring>()(start + n));
  return Fingerprint64(
      {reinterpret_cast<const char*>(&values_.vec<int64_t>()(start + n)),
       sizeof(values_.dtype())});
}
143
144 // InternalType is string or StringPiece when using StringCrosser.
145 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const146 tstring SparseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
147 bool strong_hash) const {
148 const int64_t start = feature_start_indices_[batch];
149 if (DT_STRING == values_.dtype())
150 return values_.vec<tstring>().data()[start + n];
151 return std::to_string(values_.vec<int64_t>().data()[start + n]);
152 }
153
154 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const155 tstring KeyedSparseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
156 bool strong_hash) const {
157 const int64_t start = feature_start_indices_[batch];
158 if (DT_STRING == values_.dtype())
159 return values_.vec<tstring>().data()[start + n];
160 return std::to_string(values_.vec<int64_t>().data()[start + n]);
161 }
162
163 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const164 StringPiece SparseTensorColumn<StringPiece>::Feature(int64_t batch, int64_t n,
165 bool strong_hash) const {
166 const int64_t start = feature_start_indices_[batch];
167 return values_.vec<tstring>().data()[start + n];
168 }
169
170 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const171 StringPiece KeyedSparseTensorColumn<StringPiece>::Feature(
172 int64_t batch, int64_t n, bool strong_hash) const {
173 const int64_t start = feature_start_indices_[batch];
174 return values_.vec<tstring>().data()[start + n];
175 }
176
177 // A column that is backed by a dense tensor.
178 template <typename InternalType>
179 class DenseTensorColumn : public ColumnInterface<InternalType> {
180 public:
DenseTensorColumn(const Tensor & tensor)181 explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}
182
FeatureCount(int64_t batch) const183 int64_t FeatureCount(int64_t batch) const override {
184 return tensor_.dim_size(1);
185 }
186
187 InternalType Feature(int64_t batch, int64_t n,
188 bool strong_hash) const override;
189
~DenseTensorColumn()190 ~DenseTensorColumn() override {}
191
192 private:
193 const Tensor& tensor_;
194 };
195
196 // A column that is backed by a dense tensor.
197 template <typename InternalType>
198 class KeyedDenseTensorColumn : public ColumnInterface<InternalType> {
199 public:
KeyedDenseTensorColumn(const Tensor & tensor,std::vector<int64_t> key)200 explicit KeyedDenseTensorColumn(const Tensor& tensor,
201 std::vector<int64_t> key)
202 : tensor_(tensor) {
203 std::memcpy(key_, key.data(), sizeof(key_));
204 }
205
FeatureCount(int64_t batch) const206 int64_t FeatureCount(int64_t batch) const override {
207 return tensor_.dim_size(1);
208 }
209
210 InternalType Feature(int64_t batch, int64_t n,
211 bool strong_hash) const override;
212
~KeyedDenseTensorColumn()213 ~KeyedDenseTensorColumn() override {}
214
215 private:
216 const Tensor& tensor_;
217 tensorflow::uint64 key_[2];
218 };
219
220 // InternalType is int64 only when using HashCrosser.
221 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const222 int64_t DenseTensorColumn<int64_t>::Feature(int64_t batch, int64_t n,
223 bool strong_hash) const {
224 if (DT_STRING == tensor_.dtype())
225 return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
226 return tensor_.matrix<int64_t>()(batch, n);
227 }
228
229 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const230 int64_t KeyedDenseTensorColumn<int64_t>::Feature(int64_t batch, int64_t n,
231 bool strong_hash) const {
232 if (strong_hash) {
233 if (DT_STRING == tensor_.dtype()) {
234 return StrongKeyedHash(key_, tensor_.matrix<tstring>()(batch, n));
235 }
236 return StrongKeyedHash(
237 key_,
238 {reinterpret_cast<const char*>(tensor_.matrix<int64_t>()(batch, n)),
239 sizeof(tensor_.dtype())});
240 }
241 if (DT_STRING == tensor_.dtype())
242 return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
243 return tensor_.matrix<int64_t>()(batch, n);
244 }
245
246 // Internal type is string or StringPiece when using StringCrosser.
247 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const248 tstring DenseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
249 bool strong_hash) const {
250 if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
251 return std::to_string(tensor_.matrix<int64_t>()(batch, n));
252 }
253
254 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const255 tstring KeyedDenseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
256 bool strong_hash) const {
257 if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
258 return std::to_string(tensor_.matrix<int64_t>()(batch, n));
259 }
260
261 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const262 StringPiece DenseTensorColumn<StringPiece>::Feature(int64_t batch, int64_t n,
263 bool strong_hash) const {
264 return tensor_.matrix<tstring>()(batch, n);
265 }
266
267 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const268 StringPiece KeyedDenseTensorColumn<StringPiece>::Feature(
269 int64_t batch, int64_t n, bool strong_hash) const {
270 return tensor_.matrix<tstring>()(batch, n);
271 }
272
273 // Updates Output tensors with sparse crosses.
274 template <typename OutType>
275 class OutputUpdater {
276 public:
OutputUpdater(const std::vector<int64_t> & output_start_indices,Tensor * indices_out,Tensor * values_out)277 OutputUpdater(const std::vector<int64_t>& output_start_indices,
278 Tensor* indices_out, Tensor* values_out)
279 : output_start_indices_(output_start_indices),
280 indices_out_(indices_out),
281 values_out_(values_out) {}
282
Update(const int64_t batch_index,const int64_t cross_count,const OutType & cross) const283 void Update(const int64_t batch_index, const int64_t cross_count,
284 const OutType& cross) const {
285 const int64_t output_index =
286 output_start_indices_[batch_index] + cross_count;
287
288 auto indices_matrix = indices_out_->matrix<int64_t>();
289 indices_matrix(output_index, 0) = batch_index;
290 indices_matrix(output_index, 1) = cross_count;
291
292 auto value_vec = values_out_->vec<OutType>();
293 value_vec(output_index) = cross;
294 }
295
296 private:
297 const std::vector<int64_t>& output_start_indices_;
298 Tensor* indices_out_;
299 Tensor* values_out_;
300 };
301
302 // Generates the sparse crosses as concatenation of strings.
303 template <typename InternalType>
304 class StringCrosser {
305 public:
StringCrosser(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,const int64_t num_buckets_unused,const uint64 hash_key_unused,const tstring k_feature_separator)306 StringCrosser(const std::vector<
307 std::unique_ptr<ColumnInterface<InternalType>>>& columns,
308 const int64_t num_buckets_unused, const uint64 hash_key_unused,
309 const tstring k_feature_separator)
310 : columns_(columns), k_feature_separator_(k_feature_separator) {}
311
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool unused_strong_hash) const312 string Generate(const int64_t batch_index,
313 const std::vector<int>& permutation,
314 bool unused_strong_hash) const {
315 gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
316 for (int i = 0; i < permutation.size(); i++) {
317 cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i], false);
318 }
319 // TODO(zakaria): this will copy the string twice, might effect
320 // performance.
321 return absl::StrJoin(cross_vec, k_feature_separator_);
322 }
323
324 private:
325 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
326 const tstring k_feature_separator_;
327 };
328
329 // Generates the sparse crosses as nested hash to avoid string manipulations.
330 class HashCrosser {
331 public:
HashCrosser(const std::vector<std::unique_ptr<ColumnInterface<int64_t>>> & columns,const int64_t num_buckets,const uint64 hash_key,const tstring k_feature_separator_unused)332 HashCrosser(
333 const std::vector<std::unique_ptr<ColumnInterface<int64_t>>>& columns,
334 const int64_t num_buckets, const uint64 hash_key,
335 const tstring k_feature_separator_unused)
336 : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}
337
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool unused_strong_hash) const338 int64_t Generate(const int64_t batch_index,
339 const std::vector<int>& permutation,
340 bool unused_strong_hash) const {
341 // Do the fingerprint concatenation on uint64.
342 uint64 hashed_output = hash_key_;
343 for (size_t i = 0; i < permutation.size(); ++i) {
344 uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i], false);
345 hashed_output = FingerprintCat64(hashed_output, hash_i);
346 }
347 // The return value is int64 based on the number of buckets.
348 if (num_buckets_ > 0) {
349 return hashed_output % num_buckets_;
350 } else {
351 // To prevent negative output we take modulo to max int64.
352 return hashed_output % std::numeric_limits<int64_t>::max();
353 }
354 }
355
356 private:
357 const std::vector<std::unique_ptr<ColumnInterface<int64_t>>>& columns_;
358 const int64_t num_buckets_;
359 const uint64 hash_key_;
360 };
361
362 // Generates the sparse crosses as nested hash to avoid string manipulations.
363 class HashCrosserV2 {
364 public:
HashCrosserV2(const std::vector<std::unique_ptr<ColumnInterface<int64_t>>> & columns,const int64_t num_buckets,const uint64 hash_key_unused,const tstring k_feature_separator_unused)365 HashCrosserV2(
366 const std::vector<std::unique_ptr<ColumnInterface<int64_t>>>& columns,
367 const int64_t num_buckets, const uint64 hash_key_unused,
368 const tstring k_feature_separator_unused)
369 : columns_(columns), num_buckets_(num_buckets) {}
370
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool strong_hash) const371 int64_t Generate(const int64_t batch_index,
372 const std::vector<int>& permutation,
373 bool strong_hash) const {
374 // Do the fingerprint concatenation on uint64.
375 uint64 hashed_output =
376 columns_[0]->Feature(batch_index, permutation[0], strong_hash);
377 for (size_t i = 1; i < permutation.size(); ++i) {
378 uint64 hash_i =
379 columns_[i]->Feature(batch_index, permutation[i], strong_hash);
380 hashed_output = FingerprintCat64(hashed_output, hash_i);
381 }
382 // The return value is int64 based on the number of buckets.
383 if (num_buckets_ > 0) {
384 return hashed_output % num_buckets_;
385 } else {
386 // To prevent negative output we take modulo to max int64.
387 return hashed_output % std::numeric_limits<int64_t>::max();
388 }
389 }
390
391 private:
392 const std::vector<std::unique_ptr<ColumnInterface<int64_t>>>& columns_;
393 const int64_t num_buckets_;
394 };
395
// ProductIterator generates cartesian products based on indices: each call to
// Next() yields one permutation (one feature index per column), cycling
// through every combination like an odometer with per-column bases.
template <typename InternalType>
class ProductIterator {
 public:
  explicit ProductIterator(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64_t batch_index)
      : columns_(columns), batch_index_(batch_index) {
    next_permutation_.resize(columns_.size(), 0);
    // Sets has_next_ to false if any feature column has 0 features.
    has_next_ = true;
    for (int i = 0; i < columns_.size(); i++) {
      if (columns_[i]->FeatureCount(batch_index_) == 0) {
        has_next_ = false;
        break;
      }
    }
  }

  // Returns the current permutation, then advances the internal state.
  // Precondition: HasNext() is true.
  std::vector<int> Next() {
    std::vector<int> permutation(next_permutation_);

    // Generates next permutation, if available.
    // Odometer-style increment from the last (fastest-varying) column: a
    // digit that reaches its column's FeatureCount wraps to 0 and carries
    // into the next column to the left.
    bool carry = true;
    for (int i = next_permutation_.size() - 1; i >= 0; i--) {
      if (carry) {
        next_permutation_[i] = next_permutation_[i] + 1;
      }
      if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
        next_permutation_[i] = 0;
      } else {
        carry = false;
        break;
      }
    }
    // A carry out of the leftmost column means every combination was emitted.
    has_next_ = !carry;
    return permutation;
  }

  // True while there is at least one unreturned permutation.
  bool HasNext() { return has_next_; }

 private:
  bool has_next_;
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
  const int64_t batch_index_;
  std::vector<int> next_permutation_;
};
444
445 template <bool HASHED_OUTPUT, typename InternalType>
446 struct CrossTraits;
447
448 template <typename InternalType>
449 struct CrossTraits<false, InternalType> {
450 typedef StringCrosser<InternalType> Crosser;
451 typedef StringCrosser<InternalType> CrosserV2;
452 typedef OutputUpdater<tstring> Updater;
453 };
454
455 template <>
456 struct CrossTraits<true, int64_t> {
457 typedef HashCrosser Crosser;
458 typedef HashCrosserV2 CrosserV2;
459 typedef OutputUpdater<int64_t> Updater;
460 };
461 } // namespace
462
463 // Calculate the batch size from either the shapes input or the dense input.
CalculateBatchSize(const OpInputList & shapes_list_in,const OpInputList & dense_list_in)464 int64_t CalculateBatchSize(const OpInputList& shapes_list_in,
465 const OpInputList& dense_list_in) {
466 if (shapes_list_in.size() > 0) {
467 return shapes_list_in[0].vec<int64_t>()(0);
468 }
469
470 if (dense_list_in.size() > 0) {
471 return dense_list_in[0].dim_size(0);
472 }
473
474 return 0;
475 }
476
// Validates input tensors: checks dtypes (when internal_type is valid),
// ranks, cross-list size agreement, indices/values row agreement, and that
// every input implies the same batch size. Returns InvalidArgument on the
// first violation found.
Status ValidateInput(const OpInputList& indices_list_in,
                     const OpInputList& values_list_in,
                     const OpInputList& shapes_list_in,
                     const OpInputList& dense_list_in,
                     const DataType& internal_type) {
  const auto size = indices_list_in.size();
  // Only perform internal_type check for SparseCrossOp.
  // Check if the internal_type is not invalid before doing so.
  // (SparseCrossV2 passes DT_INVALID and skips the dtype checks.)
  bool check_type = internal_type != DT_INVALID;
  // Validates indices_list_in OpInputList.
  for (int i = 0; i < size; i++) {
    if (check_type && indices_list_in[i].dtype() != DT_INT64) {
      return errors::InvalidArgument("Input indices should be of type ",
                                     DT_INT64, " but received ",
                                     indices_list_in[i].dtype());
    }
    // SparseTensor indices must be a [nnz, 2] matrix (batch, position).
    if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input indices should be a matrix but received shape ",
          indices_list_in[i].shape().DebugString(), " at position ", i);
    }
    if (indices_list_in[i].shape().dim_size(1) != 2) {
      return errors::InvalidArgument("Expected D2 of index to be 2 got ",
                                     indices_list_in[i].shape().dim_size(1),
                                     " at position ", i);
    }
  }

  // Validates values_list_in OpInputList.
  if (values_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input values, got ",
                                   values_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    // Make sure to avoid the expected type to be string, but input values to be
    // int64.
    if (check_type && internal_type == DT_STRING &&
        values_list_in[i].dtype() == DT_INT64) {
      return errors::InvalidArgument("Input values should be of internal type ",
                                     internal_type, " but received ",
                                     values_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input values should be a vector but received shape ",
          values_list_in[i].shape().DebugString(), " at position ", i);
    }
    // One value per indices row.
    if (indices_list_in[i].shape().dim_size(0) !=
        values_list_in[i].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Expected size of values to be ",
          indices_list_in[i].shape().dim_size(0), " got ",
          values_list_in[i].shape().dim_size(0), " at position ", i);
    }
  }

  // Validates shapes_list_in OpInputList
  if (shapes_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input shapes, got ",
                                   shapes_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    if (check_type && shapes_list_in[i].dtype() != DT_INT64) {
      return errors::InvalidArgument("Input shape should be of type ", DT_INT64,
                                     " but received ",
                                     shapes_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input shapes should be a vector but received shape ",
          shapes_list_in[i].shape().DebugString(), " at position ", i);
    }

    // Each sparse input must describe a 2-D [batch, features] tensor.
    if (shapes_list_in[i].vec<int64_t>().size() != 2) {
      return errors::InvalidArgument("shape should imply a 2D tensor, but got ",
                                     shapes_list_in[i].shape().DebugString(),
                                     " at position ", i);
    }
  }

  // Validates dense_list_in OpInputList
  for (int i = 0; i < dense_list_in.size(); ++i) {
    // Make sure to avoid the expected type to be string, but input values to be
    // int64.
    if (check_type && internal_type == DT_STRING &&
        dense_list_in[i].dtype() == DT_INT64) {
      return errors::InvalidArgument("Dense inputs should be of internal type ",
                                     internal_type, " but received ",
                                     dense_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Dense inputs should be a matrix but received shape ",
          dense_list_in[i].shape().DebugString(), " at position ", i);
    }
  }

  // Validates batch sizes. (Note: we do this after validating the input
  // shapes, because CalculateBatchSize() depends on inputs having valid
  // shapes).
  const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
  for (int i = 0; i < size; i++) {
    if (shapes_list_in[i].vec<int64_t>()(0) != batch_size) {
      return errors::InvalidArgument(
          "Expected batch size ", batch_size, " got ",
          shapes_list_in[i].vec<int64_t>()(0), " at position ", i);
    }
  }
  for (int i = 0; i < dense_list_in.size(); ++i) {
    if (dense_list_in[i].dim_size(0) != batch_size) {
      return errors::InvalidArgument("Expected batch size ", batch_size,
                                     " got ", dense_list_in[i].dim_size(0),
                                     " at dense tensor ", i);
    }
  }

  return OkStatus();
}
596
597 // Extracts data about the features and populates feature data.
ExtractFeatureData(const OpInputList & indices_list_in,int64_t batch_size,std::vector<std::vector<int64_t>> * feature_counts,std::vector<std::vector<int64_t>> * feature_start_indices)598 void ExtractFeatureData(
599 const OpInputList& indices_list_in, int64_t batch_size,
600 std::vector<std::vector<int64_t>>* feature_counts,
601 std::vector<std::vector<int64_t>>* feature_start_indices) {
602 gtl::InlinedVector<int64_t, 8> current_row(indices_list_in.size(), 0);
603 for (int b = 0; b < batch_size; b++) {
604 for (int i = 0; i < indices_list_in.size(); i++) {
605 const auto indices = indices_list_in[i].matrix<int64_t>();
606 int64_t feature_count = 0;
607 int64_t start_index = current_row[i];
608 // Loops until we reach next batch index for current feature column.
609 while (current_row[i] < indices_list_in[i].dim_size(0) &&
610 indices(current_row[i], 0) == b) {
611 feature_count++;
612 current_row[i]++;
613 }
614 (*feature_counts)[i].push_back(feature_count);
615 (*feature_start_indices)[i].push_back(start_index);
616 }
617 }
618 }
619
620 // Returns number of crosses for a given batch_index
621 template <typename InternalType>
CrossCountByBatchIndex(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,int batch_index)622 int64_t CrossCountByBatchIndex(
623 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
624 int batch_index) {
625 int64_t cross_count = 1;
626 for (int i = 0; i < columns.size(); i++) {
627 const auto feature_count = columns[i]->FeatureCount(batch_index);
628 // If one column is missing any feature, there won't be any cross.
629 if (feature_count == 0) {
630 return 0;
631 }
632 cross_count *= feature_count;
633 }
634 return cross_count;
635 }
636
637 // Generate the columns given the sparse and dense inputs.
638 template <typename InternalType>
639 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in)640 GenerateColumnsFromInput(const OpInputList& indices_list_in,
641 const OpInputList& values_list_in,
642 const OpInputList& shapes_list_in,
643 const OpInputList& dense_list_in) {
644 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
645 const int64_t batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
646 const int64_t number_of_columns = shapes_list_in.size();
647
648 std::vector<std::vector<int64_t>> feature_counts(number_of_columns,
649 std::vector<int64_t>());
650 std::vector<std::vector<int64_t>> feature_start_indices(
651 number_of_columns, std::vector<int64_t>());
652
653 ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
654 &feature_start_indices);
655
656 columns.reserve(values_list_in.size());
657 for (int i = 0; i < values_list_in.size(); ++i) {
658 columns.emplace_back(new SparseTensorColumn<InternalType>(
659 values_list_in[i], std::move(feature_counts[i]),
660 std::move(feature_start_indices[i])));
661 }
662 for (int i = 0; i < dense_list_in.size(); ++i) {
663 columns.emplace_back(new DenseTensorColumn<InternalType>(dense_list_in[i]));
664 }
665
666 return columns;
667 }
668
669 // Generate the columns given the sparse and dense inputs.
670 template <typename InternalType>
671 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateKeyedColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in,std::vector<int64_t> keys)672 GenerateKeyedColumnsFromInput(const OpInputList& indices_list_in,
673 const OpInputList& values_list_in,
674 const OpInputList& shapes_list_in,
675 const OpInputList& dense_list_in,
676 std::vector<int64_t> keys) {
677 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
678 const int64_t batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
679 const int64_t number_of_columns = shapes_list_in.size();
680
681 std::vector<std::vector<int64_t>> feature_counts(number_of_columns,
682 std::vector<int64_t>());
683 std::vector<std::vector<int64_t>> feature_start_indices(
684 number_of_columns, std::vector<int64_t>());
685
686 ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
687 &feature_start_indices);
688
689 columns.reserve(values_list_in.size());
690 for (int i = 0; i < values_list_in.size(); ++i) {
691 columns.emplace_back(new KeyedSparseTensorColumn<InternalType>(
692 values_list_in[i], std::move(feature_counts[i]),
693 std::move(feature_start_indices[i]), keys));
694 }
695 for (int i = 0; i < dense_list_in.size(); ++i) {
696 columns.emplace_back(
697 new KeyedDenseTensorColumn<InternalType>(dense_list_in[i], keys));
698 }
699
700 return columns;
701 }
702
// Allocates output tensors with proper size and sets the shape tensor of
// the output SparseTensor.
// It also fills output_start_indices, which contains the start row for each
// batch in the output SparseTensor (so shards can write independently).
// Precondition: output_start_indices has batch_size elements.
template <typename InternalType>
Status CreateOutputTensors(
    const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
    int64_t batch_size, OpKernelContext* context, Tensor** indices_out,
    Tensor** values_out, Tensor** shape_out,
    std::vector<int64_t>* output_start_indices) {
  // Calculates dimensions for output tensors.
  int64_t cross_count_total = 0;
  int64_t max_cross_count = 0;
  for (int64_t b = 0; b < batch_size; b++) {
    // For each input, sets starting indices in output SparseTensor
    (*output_start_indices)[b] = cross_count_total;
    const auto cross_count = CrossCountByBatchIndex(columns, b);
    // max_cross_count becomes dim 1 of the output dense shape below.
    max_cross_count = std::max(max_cross_count, cross_count);
    cross_count_total += cross_count;
  }

  // Allocates tensors: output 0 = indices [nnz, 2], output 1 = values [nnz],
  // output 2 = dense shape [2].
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({cross_count_total, 2}), indices_out));
  TF_RETURN_IF_ERROR(context->allocate_output(
      1, TensorShape({cross_count_total}), values_out));
  TF_RETURN_IF_ERROR(context->allocate_output(2, TensorShape({2}), shape_out));

  // Sets shape.
  auto shape_vec = (*shape_out)->vec<int64_t>();
  shape_vec(0) = batch_size;
  shape_vec(1) = max_cross_count;

  return OkStatus();
}
738
// Kernel that produces the sparse cross of its sparse and dense inputs,
// either as joined strings (HASHED_OUTPUT=false) or as bucketed fingerprints
// (HASHED_OUTPUT=true); work is sharded across batches.
template <bool HASHED_OUTPUT, typename InternalType>
class SparseCrossOp : public OpKernel {
 public:
  explicit SparseCrossOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    // Read signed_hash_key_ as int64 since uint64 attributes are not
    // supported by REGISTER_OP.
    int64_t signed_hash_key_;
    OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
    hash_key_ = static_cast<uint64>(signed_hash_key_);
    OP_REQUIRES_OK(context, context->GetAttr("internal_type", &internal_type_));
  }

  void Compute(OpKernelContext* context) override {
    // Gather the four input lists: sparse (indices/values/shapes) and dense.
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    // Reject malformed inputs before touching any data.
    DataType internal_type = internal_type_;
    OP_REQUIRES_OK(
        context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
                               dense_list_in, internal_type));

    // Wrap every input in a uniform ColumnInterface view.
    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
        GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in,
                                               shapes_list_in, dense_list_in);

    // Crosser type is chosen at compile time via CrossTraits.
    const tstring k_feature_separator = "_X_";
    typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser(
        columns, num_buckets_, hash_key_, k_feature_separator);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64_t batch_size =
        CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64_t> output_start_indices(batch_size);
    OP_REQUIRES_OK(
        context,
        CreateOutputTensors(columns, batch_size, context, &indices_out,
                            &values_out, &shape_out, &output_start_indices));

    typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
        output_start_indices, indices_out, values_out);
    // crosser/updater are captured by copy so each shard has its own
    // instance; columns is captured by reference (read-only here).
    auto do_work = [&columns, crosser, updater](int64_t begin, int64_t end) {
      for (int b = begin; b < end; b++) {
        // Enumerate every feature combination of batch b and emit it at the
        // batch's reserved rows (disjoint per batch, so shards don't race).
        ProductIterator<InternalType> product_iterator(columns, b);
        int64_t cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count,
                         crosser.Generate(b, permutation, false));
          cross_count++;
        }
      }
    };

    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }

 private:
  int64_t num_buckets_;
  uint64 hash_key_;
  DataType internal_type_;
};
813
814 class SparseCrossV2Op : public OpKernel {
815 public:
SparseCrossV2Op(OpKernelConstruction * context)816 explicit SparseCrossV2Op(OpKernelConstruction* context) : OpKernel(context) {}
817
Compute(OpKernelContext * context)818 void Compute(OpKernelContext* context) override {
819 OpInputList indices_list_in;
820 OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
821 OpInputList values_list_in;
822 OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
823 OpInputList shapes_list_in;
824 OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
825 OpInputList dense_list_in;
826 OP_REQUIRES_OK(context,
827 context->input_list("dense_inputs", &dense_list_in));
828
829 // Set internal_type to invalid_type so that the check will be ignored.
830 DataType internal_type = DT_INVALID;
831 OP_REQUIRES_OK(
832 context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
833 dense_list_in, internal_type));
834
835 const Tensor* sep_t;
836 OP_REQUIRES_OK(context, context->input("sep", &sep_t));
837 OP_REQUIRES(context, TensorShapeUtils::IsScalar(sep_t->shape()),
838 errors::InvalidArgument("Input separator should be a scalar. "
839 "Received: ",
840 sep_t->DebugString()));
841 const tstring separator = sep_t->scalar<tstring>()();
842
843 std::vector<std::unique_ptr<ColumnInterface<tstring>>> columns =
844 GenerateColumnsFromInput<tstring>(indices_list_in, values_list_in,
845 shapes_list_in, dense_list_in);
846 Tensor* indices_out;
847 Tensor* values_out;
848 Tensor* shape_out;
849 const int64_t batch_size =
850 CalculateBatchSize(shapes_list_in, dense_list_in);
851 std::vector<int64_t> output_start_indices(batch_size);
852 OP_REQUIRES_OK(
853 context,
854 CreateOutputTensors(columns, batch_size, context, &indices_out,
855 &values_out, &shape_out, &output_start_indices));
856 StringCrosser<tstring> crosser(columns, 0, 0, separator);
857 OutputUpdater<tstring> updater(output_start_indices, indices_out,
858 values_out);
859 auto do_work = [&columns, crosser, updater](int64_t begin, int64_t end) {
860 for (int b = begin; b < end; b++) {
861 ProductIterator<tstring> product_iterator(columns, b);
862 int64_t cross_count = 0;
863 while (product_iterator.HasNext()) {
864 const auto permutation = product_iterator.Next();
865 updater.Update(b, cross_count,
866 crosser.Generate(b, permutation, false));
867 cross_count++;
868 }
869 }
870 };
871
872 auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
873 // TODO(zakaria): optimize kCostPerUnit
874 const int kCostPerUnit = 5000 * indices_list_in.size();
875 Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
876 kCostPerUnit, do_work);
877 }
878 };
879
880 class SparseCrossHashedOp : public OpKernel {
881 public:
SparseCrossHashedOp(OpKernelConstruction * context)882 explicit SparseCrossHashedOp(OpKernelConstruction* context)
883 : OpKernel(context) {}
884
Compute(OpKernelContext * context)885 void Compute(OpKernelContext* context) override {
886 OpInputList indices_list_in;
887 OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
888 OpInputList values_list_in;
889 OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
890 OpInputList shapes_list_in;
891 OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
892 OpInputList dense_list_in;
893 OP_REQUIRES_OK(context,
894 context->input_list("dense_inputs", &dense_list_in));
895
896 // Set internal_type to invalid_type so that the check will be ignored.
897 DataType internal_type = DT_INVALID;
898 OP_REQUIRES_OK(
899 context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
900 dense_list_in, internal_type));
901
902 const Tensor* num_buckets_t;
903 OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t));
904 const int64_t num_buckets = num_buckets_t->scalar<int64_t>()();
905
906 const Tensor* strong_hash_t;
907 OP_REQUIRES_OK(context, context->input("strong_hash", &strong_hash_t));
908 const bool strong_hash = strong_hash_t->scalar<bool>()();
909
910 const Tensor* salt_t;
911 OP_REQUIRES_OK(context, context->input("salt", &salt_t));
912 const auto salt = salt_t->flat<int64_t>();
913 std::vector<int64_t> key_{salt(0), salt(1)};
914
915 std::vector<std::unique_ptr<ColumnInterface<int64_t>>> columns =
916 GenerateKeyedColumnsFromInput<int64_t>(indices_list_in, values_list_in,
917 shapes_list_in, dense_list_in,
918 key_);
919 Tensor* indices_out;
920 Tensor* values_out;
921 Tensor* shape_out;
922 const int64_t batch_size =
923 CalculateBatchSize(shapes_list_in, dense_list_in);
924 std::vector<int64_t> output_start_indices(batch_size);
925 OP_REQUIRES_OK(
926 context,
927 CreateOutputTensors(columns, batch_size, context, &indices_out,
928 &values_out, &shape_out, &output_start_indices));
929 const tstring unused_sep;
930 HashCrosserV2 crosser(columns, num_buckets, 0, unused_sep);
931 OutputUpdater<int64_t> updater(output_start_indices, indices_out,
932 values_out);
933 auto do_work = [&columns, crosser, updater, strong_hash](int64_t begin,
934 int64_t end) {
935 for (int b = begin; b < end; b++) {
936 ProductIterator<int64_t> product_iterator(columns, b);
937 int64_t cross_count = 0;
938 while (product_iterator.HasNext()) {
939 const auto permutation = product_iterator.Next();
940 updater.Update(b, cross_count,
941 crosser.Generate(b, permutation, strong_hash));
942 cross_count++;
943 }
944 }
945 };
946
947 auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
948 // TODO(zakaria): optimize kCostPerUnit
949 const int kCostPerUnit = 5000 * indices_list_in.size();
950 Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
951 kCostPerUnit, do_work);
952 }
953 };
954
955 REGISTER_KERNEL_BUILDER(Name("SparseCross")
956 .Device(DEVICE_CPU)
957 .TypeConstraint<tstring>("out_type")
958 .TypeConstraint<tstring>("internal_type"),
959 SparseCrossOp<false, StringPiece>);
960
961 REGISTER_KERNEL_BUILDER(Name("SparseCross")
962 .Device(DEVICE_CPU)
963 .TypeConstraint<tstring>("out_type")
964 .TypeConstraint<int64_t>("internal_type"),
965 SparseCrossOp<false, tstring>);
966
967 REGISTER_KERNEL_BUILDER(Name("SparseCross")
968 .Device(DEVICE_CPU)
969 .TypeConstraint<int64_t>("out_type")
970 .TypeConstraint<tstring>("internal_type"),
971 SparseCrossOp<true, int64>);
972
973 REGISTER_KERNEL_BUILDER(Name("SparseCross")
974 .Device(DEVICE_CPU)
975 .TypeConstraint<int64_t>("out_type")
976 .TypeConstraint<int64_t>("internal_type"),
977 SparseCrossOp<true, int64>);
978
979 REGISTER_KERNEL_BUILDER(Name("SparseCrossV2").Device(DEVICE_CPU),
980 SparseCrossV2Op);
981
982 REGISTER_KERNEL_BUILDER(Name("SparseCrossHashed").Device(DEVICE_CPU),
983 SparseCrossHashedOp);
984
985 } // namespace tensorflow
986