/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Contains OP to generate sparse crosses.
#include <assert.h>

#include <limits>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/strong_hash.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

namespace {
// An interface that represents a column with batches.
template <typename InternalType>
class ColumnInterface {
 public:
  // Returns the number of features in the specified batch.
  virtual int64 FeatureCount(int64 batch) const = 0;

  // Returns the fingerprint of the nth feature from the specified batch.
  virtual InternalType Feature(int64 batch, int64 n,
                               bool strong_hash) const = 0;

  virtual ~ColumnInterface() {}
};
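
// Illustrative sketch (not used by this file's logic): a crosser walks a
// column one batch at a time, e.g.
//   for (int64 n = 0; n < column->FeatureCount(batch); ++n) {
//     auto feature = column->Feature(batch, n, /*strong_hash=*/false);
//   }
// Concrete implementations below are backed by sparse or dense tensors.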

// A column that is backed by a sparse tensor.
template <typename InternalType>
class SparseTensorColumn : public ColumnInterface<InternalType> {
 public:
  SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
                     std::vector<int64> feature_start_indices)
      : values_(values),
        feature_counts_(std::move(feature_counts)),
        feature_start_indices_(std::move(feature_start_indices)) {
    CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
  }

  int64 FeatureCount(int64 batch) const override {
    return feature_counts_[batch];
  }

  InternalType Feature(int64 batch, int64 n, bool strong_hash) const override;

  ~SparseTensorColumn() override {}

 private:
  const Tensor& values_;
  std::vector<int64> feature_counts_;
  std::vector<int64> feature_start_indices_;
};

// A column that is backed by a sparse tensor, with a key for strong hashing.
template <typename InternalType>
class KeyedSparseTensorColumn : public ColumnInterface<InternalType> {
 public:
  KeyedSparseTensorColumn(const Tensor& values,
                          std::vector<int64> feature_counts,
                          std::vector<int64> feature_start_indices,
                          std::vector<int64> key)
      : values_(values),
        feature_counts_(std::move(feature_counts)),
        feature_start_indices_(std::move(feature_start_indices)) {
    DCHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
    std::memcpy(key_, key.data(), sizeof(key_));
  }

  int64 FeatureCount(int64 batch) const override {
    return feature_counts_[batch];
  }

  InternalType Feature(int64 batch, int64 n, bool strong_hash) const override;

  ~KeyedSparseTensorColumn() override {}

 private:
  const Tensor& values_;
  tensorflow::uint64 key_[2];
  std::vector<int64> feature_counts_;
  std::vector<int64> feature_start_indices_;
};
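
// Note: key_ holds a 128-bit key (two uint64s copied from the supplied int64
// values) that seeds StrongKeyedHash in the keyed columns, so callers are
// expected to pass at least two elements in `key` (the hashed op below passes
// the two salt values).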

// InternalType is int64 only when using HashCrosser.
template <>
int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n,
                                         bool strong_hash) const {
  const int64 start = feature_start_indices_[batch];
  if (DT_STRING == values_.dtype())
    return Fingerprint64(values_.vec<tstring>().data()[start + n]);
  return values_.vec<int64>().data()[start + n];
}

template <>
int64 KeyedSparseTensorColumn<int64>::Feature(int64 batch, int64 n,
                                              bool strong_hash) const {
  const int64 start = feature_start_indices_[batch];
  if (strong_hash) {
    if (DT_STRING == values_.dtype()) {
      return StrongKeyedHash(key_, values_.vec<tstring>()(start + n));
    }
    return StrongKeyedHash(
        key_, {reinterpret_cast<const char*>(&values_.vec<int64>()(start + n)),
               sizeof(values_.dtype())});
  }
  if (DT_STRING == values_.dtype())
    return Fingerprint64(values_.vec<tstring>()(start + n));
  return Fingerprint64(
      {reinterpret_cast<const char*>(&values_.vec<int64>()(start + n)),
       sizeof(values_.dtype())});
}

// InternalType is string or StringPiece when using StringCrosser.
template <>
tstring SparseTensorColumn<tstring>::Feature(int64 batch, int64 n,
                                             bool strong_hash) const {
  const int64 start = feature_start_indices_[batch];
  if (DT_STRING == values_.dtype())
    return values_.vec<tstring>().data()[start + n];
  return std::to_string(values_.vec<int64>().data()[start + n]);
}

template <>
tstring KeyedSparseTensorColumn<tstring>::Feature(int64 batch, int64 n,
                                                  bool strong_hash) const {
  const int64 start = feature_start_indices_[batch];
  if (DT_STRING == values_.dtype())
    return values_.vec<tstring>().data()[start + n];
  return std::to_string(values_.vec<int64>().data()[start + n]);
}

template <>
StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch, int64 n,
                                                     bool strong_hash) const {
  const int64 start = feature_start_indices_[batch];
  return values_.vec<tstring>().data()[start + n];
}

template <>
StringPiece KeyedSparseTensorColumn<StringPiece>::Feature(
    int64 batch, int64 n, bool strong_hash) const {
  const int64 start = feature_start_indices_[batch];
  return values_.vec<tstring>().data()[start + n];
}

// A column that is backed by a dense tensor.
template <typename InternalType>
class DenseTensorColumn : public ColumnInterface<InternalType> {
 public:
  explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}

  int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }

  InternalType Feature(int64 batch, int64 n, bool strong_hash) const override;

  ~DenseTensorColumn() override {}

 private:
  const Tensor& tensor_;
};

// A column that is backed by a dense tensor, with a key for strong hashing.
template <typename InternalType>
class KeyedDenseTensorColumn : public ColumnInterface<InternalType> {
 public:
  explicit KeyedDenseTensorColumn(const Tensor& tensor, std::vector<int64> key)
      : tensor_(tensor) {
    std::memcpy(key_, key.data(), sizeof(key_));
  }

  int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }

  InternalType Feature(int64 batch, int64 n, bool strong_hash) const override;

  ~KeyedDenseTensorColumn() override {}

 private:
  const Tensor& tensor_;
  tensorflow::uint64 key_[2];
};

// InternalType is int64 only when using HashCrosser.
template <>
int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n,
                                        bool strong_hash) const {
  if (DT_STRING == tensor_.dtype())
    return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
  return tensor_.matrix<int64>()(batch, n);
}

template <>
int64 KeyedDenseTensorColumn<int64>::Feature(int64 batch, int64 n,
                                             bool strong_hash) const {
  if (strong_hash) {
    if (DT_STRING == tensor_.dtype()) {
      return StrongKeyedHash(key_, tensor_.matrix<tstring>()(batch, n));
    }
    return StrongKeyedHash(
        key_,
        {reinterpret_cast<const char*>(&tensor_.matrix<int64>()(batch, n)),
         sizeof(tensor_.dtype())});
  }
  if (DT_STRING == tensor_.dtype())
    return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
  return tensor_.matrix<int64>()(batch, n);
}

// Internal type is string or StringPiece when using StringCrosser.
template <>
tstring DenseTensorColumn<tstring>::Feature(int64 batch, int64 n,
                                            bool strong_hash) const {
  if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
  return std::to_string(tensor_.matrix<int64>()(batch, n));
}

template <>
tstring KeyedDenseTensorColumn<tstring>::Feature(int64 batch, int64 n,
                                                 bool strong_hash) const {
  if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
  return std::to_string(tensor_.matrix<int64>()(batch, n));
}

template <>
StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch, int64 n,
                                                    bool strong_hash) const {
  return tensor_.matrix<tstring>()(batch, n);
}

template <>
StringPiece KeyedDenseTensorColumn<StringPiece>::Feature(
    int64 batch, int64 n, bool strong_hash) const {
  return tensor_.matrix<tstring>()(batch, n);
}

// Updates Output tensors with sparse crosses.
template <typename OutType>
class OutputUpdater {
 public:
  OutputUpdater(const std::vector<int64>& output_start_indices,
                Tensor* indices_out, Tensor* values_out)
      : output_start_indices_(output_start_indices),
        indices_out_(indices_out),
        values_out_(values_out) {}

  void Update(const int64 batch_index, const int64 cross_count,
              const OutType& cross) const {
    const int64 output_index = output_start_indices_[batch_index] + cross_count;

    auto indices_matrix = indices_out_->matrix<int64>();
    indices_matrix(output_index, 0) = batch_index;
    indices_matrix(output_index, 1) = cross_count;

    auto value_vec = values_out_->vec<OutType>();
    value_vec(output_index) = cross;
  }

 private:
  const std::vector<int64>& output_start_indices_;
  Tensor* indices_out_;
  Tensor* values_out_;
};
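
// Illustrative note: each call to Update writes one row of the output
// SparseTensor in COO form, e.g. a cross emitted as the 3rd result for batch
// index 1 lands at row output_start_indices[1] + 2 with
//   indices_out row = [1, 2]  (batch index, position within the batch)
//   values_out[row] = the cross value.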

// Generates the sparse crosses as concatenation of strings.
template <typename InternalType>
class StringCrosser {
 public:
  StringCrosser(const std::vector<
                    std::unique_ptr<ColumnInterface<InternalType>>>& columns,
                const int64 num_buckets_unused, const uint64 hash_key_unused,
                const tstring k_feature_separator)
      : columns_(columns), k_feature_separator_(k_feature_separator) {}

  string Generate(const int64 batch_index, const std::vector<int>& permutation,
                  bool unused_strong_hash) const {
    gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
    for (int i = 0; i < permutation.size(); i++) {
      cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i], false);
    }
    // TODO(zakaria): this will copy the string twice, might affect
    // performance.
    return absl::StrJoin(cross_vec, k_feature_separator_);
  }

 private:
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
  const tstring k_feature_separator_;
};
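
// Illustrative sketch: with two columns whose features for a batch are
// {"a", "b"} and {"x"}, and the default "_X_" separator used by SparseCrossOp
// below, StringCrosser::Generate yields the crosses "a_X_x" and "b_X_x" as the
// product iterator walks the permutations.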

// Generates the sparse crosses as nested hash to avoid string manipulations.
class HashCrosser {
 public:
  HashCrosser(
      const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
      const int64 num_buckets, const uint64 hash_key,
      const tstring k_feature_separator_unused)
      : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}

  int64 Generate(const int64 batch_index, const std::vector<int>& permutation,
                 bool unused_strong_hash) const {
    // Do the fingerprint concatenation on uint64.
    uint64 hashed_output = hash_key_;
    for (size_t i = 0; i < permutation.size(); ++i) {
      uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i], false);
      hashed_output = FingerprintCat64(hashed_output, hash_i);
    }
    // The return value is int64 based on the number of buckets.
    if (num_buckets_ > 0) {
      return hashed_output % num_buckets_;
    } else {
      // To prevent negative output we take modulo to max int64.
      return hashed_output % std::numeric_limits<int64>::max();
    }
  }

 private:
  const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
  const int64 num_buckets_;
  const uint64 hash_key_;
};
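
// Illustrative sketch: the cross of three features is the left fold
//   FingerprintCat64(FingerprintCat64(FingerprintCat64(hash_key, h0), h1), h2)
// where each h_i is the per-column value returned by Feature(), so the result
// depends on both the hash key and the column order.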

// Generates the sparse crosses as nested hash to avoid string manipulations.
class HashCrosserV2 {
 public:
  HashCrosserV2(
      const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
      const int64 num_buckets, const uint64 hash_key_unused,
      const tstring k_feature_separator_unused)
      : columns_(columns), num_buckets_(num_buckets) {}

  int64 Generate(const int64 batch_index, const std::vector<int>& permutation,
                 bool strong_hash) const {
    // Do the fingerprint concatenation on uint64.
    uint64 hashed_output =
        columns_[0]->Feature(batch_index, permutation[0], strong_hash);
    for (size_t i = 1; i < permutation.size(); ++i) {
      uint64 hash_i =
          columns_[i]->Feature(batch_index, permutation[i], strong_hash);
      hashed_output = FingerprintCat64(hashed_output, hash_i);
    }
    // The return value is int64 based on the number of buckets.
    if (num_buckets_ > 0) {
      return hashed_output % num_buckets_;
    } else {
      // To prevent negative output we take modulo to max int64.
      return hashed_output % std::numeric_limits<int64>::max();
    }
  }

 private:
  const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
  const int64 num_buckets_;
};
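
// Note: unlike HashCrosser, HashCrosserV2 seeds the fold with the first
// column's feature hash rather than with a hash key, and it forwards
// strong_hash to the columns, which lets SparseCrossHashed switch the keyed
// columns over to StrongKeyedHash with the salt-derived key.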

// ProductIterator generates cartesian products based on indices.
template <typename InternalType>
class ProductIterator {
 public:
  explicit ProductIterator(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64 batch_index)
      : columns_(columns), batch_index_(batch_index) {
    next_permutation_.resize(columns_.size(), 0);
    // Sets has_next_ to false if any feature column has 0 features.
    has_next_ = true;
    for (int i = 0; i < columns_.size(); i++) {
      if (columns_[i]->FeatureCount(batch_index_) == 0) {
        has_next_ = false;
        break;
      }
    }
  }

  std::vector<int> Next() {
    std::vector<int> permutation(next_permutation_);

    // Generates next permutation, if available.
    bool carry = true;
    for (int i = next_permutation_.size() - 1; i >= 0; i--) {
      if (carry) {
        next_permutation_[i] = next_permutation_[i] + 1;
      }
      if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
        next_permutation_[i] = 0;
      } else {
        carry = false;
        break;
      }
    }
    has_next_ = !carry;
    return permutation;
  }

  bool HasNext() { return has_next_; }

 private:
  bool has_next_;
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
  const int64 batch_index_;
  std::vector<int> next_permutation_;
};
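
// Illustrative sketch: Next() advances next_permutation_ like an odometer
// whose digit i wraps at columns_[i]->FeatureCount(batch). For two columns
// with 2 and 3 features it enumerates
//   {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2}
// and HasNext() turns false once the leftmost digit carries out.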

template <bool HASHED_OUTPUT, typename InternalType>
struct CrossTraits;

template <typename InternalType>
struct CrossTraits<false, InternalType> {
  typedef StringCrosser<InternalType> Crosser;
  typedef StringCrosser<InternalType> CrosserV2;
  typedef OutputUpdater<tstring> Updater;
};

template <>
struct CrossTraits<true, int64> {
  typedef HashCrosser Crosser;
  typedef HashCrosserV2 CrosserV2;
  typedef OutputUpdater<int64> Updater;
};
}  // namespace

// Calculate the batch size from either the shapes input or the dense input.
int64 CalculateBatchSize(const OpInputList& shapes_list_in,
                         const OpInputList& dense_list_in) {
  if (shapes_list_in.size() > 0) {
    return shapes_list_in[0].vec<int64>()(0);
  }

  if (dense_list_in.size() > 0) {
    return dense_list_in[0].dim_size(0);
  }

  return 0;
}

// Validates input tensors.
Status ValidateInput(const OpInputList& indices_list_in,
                     const OpInputList& values_list_in,
                     const OpInputList& shapes_list_in,
                     const OpInputList& dense_list_in) {
  const auto size = indices_list_in.size();
  // Validates indices_list_in OpInputList.
  for (int i = 0; i < size; i++) {
    if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input indices should be a matrix but received shape ",
          indices_list_in[i].shape().DebugString(), " at position ", i);
    }
    if (indices_list_in[i].shape().dim_size(1) != 2) {
      return errors::InvalidArgument("Expected D2 of index to be 2 got ",
                                     indices_list_in[i].shape().dim_size(1),
                                     " at position ", i);
    }
  }

  // Validates values_list_in OpInputList.
  if (values_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input values, got ",
                                   values_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input values should be a vector but received shape ",
          values_list_in[i].shape().DebugString(), " at position ", i);
    }
    if (indices_list_in[i].shape().dim_size(0) !=
        values_list_in[i].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Expected size of values to be ",
          indices_list_in[i].shape().dim_size(0), " got ",
          values_list_in[i].shape().dim_size(0), " at position ", i);
    }
  }

  // Validates shapes_list_in OpInputList
  if (shapes_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input shapes, got ",
                                   shapes_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input shapes should be a vector but received shape ",
          shapes_list_in[i].shape().DebugString(), " at position ", i);
    }

    if (shapes_list_in[i].vec<int64>().size() != 2) {
      return errors::InvalidArgument("shape should imply a 2D tensor, but got ",
                                     shapes_list_in[i].shape().DebugString(),
                                     " at position ", i);
    }
  }

  // Validates dense_list_in OpInputList
  for (int i = 0; i < dense_list_in.size(); ++i) {
    if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Dense inputs should be a matrix but received shape ",
          dense_list_in[i].shape().DebugString(), " at position ", i);
    }
  }

  // Validates batch sizes.  (Note: we do this after validating the input
  // shapes, because CalculateBatchSize() depends on inputs having valid
  // shapes).
  const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
  for (int i = 0; i < size; i++) {
    if (shapes_list_in[i].vec<int64>()(0) != batch_size) {
      return errors::InvalidArgument("Expected batch size ", batch_size,
                                     " got ", shapes_list_in[i].vec<int64>()(0),
                                     " at position ", i);
    }
  }
  for (int i = 0; i < dense_list_in.size(); ++i) {
    if (dense_list_in[i].dim_size(0) != batch_size) {
      return errors::InvalidArgument("Expected batch size ", batch_size,
                                     " got ", dense_list_in[i].dim_size(0),
                                     " at dense tensor ", i);
    }
  }

  return Status::OK();
}

// Extracts data about the features and populates feature data.
void ExtractFeatureData(
    const OpInputList& indices_list_in, int64 batch_size,
    std::vector<std::vector<int64>>* feature_counts,
    std::vector<std::vector<int64>>* feature_start_indices) {
  gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0);
  for (int b = 0; b < batch_size; b++) {
    for (int i = 0; i < indices_list_in.size(); i++) {
      const auto indices = indices_list_in[i].matrix<int64>();
      int64 feature_count = 0;
      int64 start_index = current_row[i];
      // Loops until we reach next batch index for current feature column.
      while (current_row[i] < indices_list_in[i].dim_size(0) &&
             indices(current_row[i], 0) == b) {
        feature_count++;
        current_row[i]++;
      }
      (*feature_counts)[i].push_back(feature_count);
      (*feature_start_indices)[i].push_back(start_index);
    }
  }
}
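
// Illustrative sketch: for a single sparse column with indices
//   [[0, 0], [0, 1], [2, 0]]  (COO rows: batch, position)
// and batch_size 3, ExtractFeatureData produces
//   feature_counts[0]        = {2, 0, 1}
//   feature_start_indices[0] = {0, 2, 2}
// i.e. batch 0 owns values [0, 2), batch 1 owns none, batch 2 owns value 2.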

// Returns number of crosses for a given batch_index
template <typename InternalType>
int64 CrossCountByBatchIndex(
    const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
    int batch_index) {
  int64 cross_count = 1;
  for (int i = 0; i < columns.size(); i++) {
    const auto feature_count = columns[i]->FeatureCount(batch_index);
    // If one column is missing any feature, there won't be any cross.
    if (feature_count == 0) {
      return 0;
    }
    cross_count *= feature_count;
  }
  return cross_count;
}
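
// Illustrative sketch: the cross count is the product of per-column feature
// counts for that batch, e.g. columns with 2, 3 and 1 features yield
// 2 * 3 * 1 = 6 crosses, and any empty column collapses the count to 0.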

// Generate the columns given the sparse and dense inputs.
template <typename InternalType>
std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateColumnsFromInput(const OpInputList& indices_list_in,
                         const OpInputList& values_list_in,
                         const OpInputList& shapes_list_in,
                         const OpInputList& dense_list_in) {
  std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
  const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
  const int64 number_of_columns = shapes_list_in.size();

  std::vector<std::vector<int64>> feature_counts(number_of_columns,
                                                 std::vector<int64>());
  std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
                                                        std::vector<int64>());

  ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
                     &feature_start_indices);

  columns.reserve(values_list_in.size());
  for (int i = 0; i < values_list_in.size(); ++i) {
    columns.emplace_back(new SparseTensorColumn<InternalType>(
        values_list_in[i], std::move(feature_counts[i]),
        std::move(feature_start_indices[i])));
  }
  for (int i = 0; i < dense_list_in.size(); ++i) {
    columns.emplace_back(new DenseTensorColumn<InternalType>(dense_list_in[i]));
  }

  return columns;
}

// Generate the keyed columns given the sparse and dense inputs.
template <typename InternalType>
std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateKeyedColumnsFromInput(const OpInputList& indices_list_in,
                              const OpInputList& values_list_in,
                              const OpInputList& shapes_list_in,
                              const OpInputList& dense_list_in,
                              std::vector<int64> keys) {
  std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
  const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
  const int64 number_of_columns = shapes_list_in.size();

  std::vector<std::vector<int64>> feature_counts(number_of_columns,
                                                 std::vector<int64>());
  std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
                                                        std::vector<int64>());

  ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
                     &feature_start_indices);

  columns.reserve(values_list_in.size());
  for (int i = 0; i < values_list_in.size(); ++i) {
    columns.emplace_back(new KeyedSparseTensorColumn<InternalType>(
        values_list_in[i], std::move(feature_counts[i]),
        std::move(feature_start_indices[i]), keys));
  }
  for (int i = 0; i < dense_list_in.size(); ++i) {
    columns.emplace_back(
        new KeyedDenseTensorColumn<InternalType>(dense_list_in[i], keys));
  }

  return columns;
}

// Allocates output tensors with proper size and sets the shape tensor of
// the output SparseTensor.
// It also fills output_start_indices, which contains the start index of each
// batch in the output SparseTensor.
template <typename InternalType>
Status CreateOutputTensors(
    const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
    int64 batch_size, OpKernelContext* context, Tensor** indices_out,
    Tensor** values_out, Tensor** shape_out,
    std::vector<int64>* output_start_indices) {
  // Calculates dimensions for output tensors.
  int64 cross_count_total = 0;
  int64 max_cross_count = 0;
  for (int64 b = 0; b < batch_size; b++) {
    // For each batch, sets the starting index in the output SparseTensor.
    (*output_start_indices)[b] = cross_count_total;
    const auto cross_count = CrossCountByBatchIndex(columns, b);
    max_cross_count = std::max(max_cross_count, cross_count);
    cross_count_total += cross_count;
  }

  // Allocates tensors.
  TF_RETURN_IF_ERROR(context->allocate_output(
      0, TensorShape({cross_count_total, 2}), indices_out));
  TF_RETURN_IF_ERROR(context->allocate_output(
      1, TensorShape({cross_count_total}), values_out));
  TF_RETURN_IF_ERROR(context->allocate_output(2, TensorShape({2}), shape_out));

  // Sets shape.
  auto shape_vec = (*shape_out)->vec<int64>();
  shape_vec(0) = batch_size;
  shape_vec(1) = max_cross_count;

  return Status::OK();
}
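
// Illustrative sketch: for per-batch cross counts {2, 0, 3} the outputs are
//   indices: shape [5, 2], values: shape [5], dense shape: [3, 3]
// since cross_count_total = 5, batch_size = 3 and max_cross_count = 3, with
// output_start_indices = {0, 2, 2}.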

template <bool HASHED_OUTPUT, typename InternalType>
class SparseCrossOp : public OpKernel {
 public:
  explicit SparseCrossOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    // Read signed_hash_key_ as int64 since uint64 attributes are not
    // supported by REGISTER_OP.
    int64 signed_hash_key_;
    OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
    hash_key_ = static_cast<uint64>(signed_hash_key_);
  }

  void Compute(OpKernelContext* context) override {
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
                                          shapes_list_in, dense_list_in));

    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
        GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in,
                                               shapes_list_in, dense_list_in);

    const tstring k_feature_separator = "_X_";
    typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser(
        columns, num_buckets_, hash_key_, k_feature_separator);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64> output_start_indices(batch_size);
    OP_REQUIRES_OK(
        context,
        CreateOutputTensors(columns, batch_size, context, &indices_out,
                            &values_out, &shape_out, &output_start_indices));

    typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
        output_start_indices, indices_out, values_out);
    auto do_work = [&columns, crosser, updater](int64 begin, int64 end) {
      for (int b = begin; b < end; b++) {
        ProductIterator<InternalType> product_iterator(columns, b);
        int64 cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count,
                         crosser.Generate(b, permutation, false));
          cross_count++;
        }
      }
    };

    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }

 private:
  int64 num_buckets_;
  uint64 hash_key_;
};
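
// Illustrative sketch (assumed example, not exercised here): crossing a sparse
// string column with values {"a", "b"} in batch 0 and a dense column with one
// value "x" per batch, the string variant emits the SparseTensor rows
//   [0, 0] -> "a_X_x"
//   [0, 1] -> "b_X_x"
// while the hashed variant maps each cross to an int64 bucket instead. Batches
// are independent, so the work above is sharded across batch indices.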

class SparseCrossV2Op : public OpKernel {
 public:
  explicit SparseCrossV2Op(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
                                          shapes_list_in, dense_list_in));

    const Tensor* sep_t;
    OP_REQUIRES_OK(context, context->input("sep", &sep_t));
    const tstring separator = sep_t->scalar<tstring>()();

    std::vector<std::unique_ptr<ColumnInterface<tstring>>> columns =
        GenerateColumnsFromInput<tstring>(indices_list_in, values_list_in,
                                          shapes_list_in, dense_list_in);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64> output_start_indices(batch_size);
    OP_REQUIRES_OK(
        context,
        CreateOutputTensors(columns, batch_size, context, &indices_out,
                            &values_out, &shape_out, &output_start_indices));
    StringCrosser<tstring> crosser(columns, 0, 0, separator);
    OutputUpdater<tstring> updater(output_start_indices, indices_out,
                                   values_out);
    auto do_work = [&columns, crosser, updater](int64 begin, int64 end) {
      for (int b = begin; b < end; b++) {
        ProductIterator<tstring> product_iterator(columns, b);
        int64 cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count,
                         crosser.Generate(b, permutation, false));
          cross_count++;
        }
      }
    };

    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }
};

class SparseCrossHashedOp : public OpKernel {
 public:
  explicit SparseCrossHashedOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in,
                                          shapes_list_in, dense_list_in));

    const Tensor* num_buckets_t;
    OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t));
    const int64 num_buckets = num_buckets_t->scalar<int64>()();

    const Tensor* strong_hash_t;
    OP_REQUIRES_OK(context, context->input("strong_hash", &strong_hash_t));
    const bool strong_hash = strong_hash_t->scalar<bool>()();

    const Tensor* salt_t;
    OP_REQUIRES_OK(context, context->input("salt", &salt_t));
    const auto salt = salt_t->flat<int64>();
    std::vector<int64> key_{salt(0), salt(1)};

    std::vector<std::unique_ptr<ColumnInterface<int64>>> columns =
        GenerateKeyedColumnsFromInput<int64>(indices_list_in, values_list_in,
                                             shapes_list_in, dense_list_in,
                                             key_);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64> output_start_indices(batch_size);
    OP_REQUIRES_OK(
        context,
        CreateOutputTensors(columns, batch_size, context, &indices_out,
                            &values_out, &shape_out, &output_start_indices));
    const tstring unused_sep;
    HashCrosserV2 crosser(columns, num_buckets, 0, unused_sep);
    OutputUpdater<int64> updater(output_start_indices, indices_out, values_out);
    auto do_work = [&columns, crosser, updater, strong_hash](int64 begin,
                                                             int64 end) {
      for (int b = begin; b < end; b++) {
        ProductIterator<int64> product_iterator(columns, b);
        int64 cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count,
                         crosser.Generate(b, permutation, strong_hash));
          cross_count++;
        }
      }
    };

    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }
};

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type")
                            .TypeConstraint<tstring>("internal_type"),
                        SparseCrossOp<false, StringPiece>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseCrossOp<false, tstring>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<tstring>("internal_type"),
                        SparseCrossOp<true, int64>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseCrossOp<true, int64>);

REGISTER_KERNEL_BUILDER(Name("SparseCrossV2").Device(DEVICE_CPU),
                        SparseCrossV2Op);

REGISTER_KERNEL_BUILDER(Name("SparseCrossHashed").Device(DEVICE_CPU),
                        SparseCrossHashedOp);

}  // namespace tensorflow