• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Contains OP to generate sparse crosses.
17 #include <assert.h>
18 
19 #include <limits>
20 #include <string>
21 #include <vector>
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "tensorflow/core/framework/kernel_def_builder.h"
25 #include "tensorflow/core/framework/op_def_builder.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/platform/fingerprint.h"
34 #include "tensorflow/core/platform/strong_hash.h"
35 #include "tensorflow/core/util/work_sharder.h"
36 
37 namespace tensorflow {
38 
39 namespace {
40 // An interface that represents a column with batches.
41 template <typename InternalType>
42 class ColumnInterface {
43  public:
44   // Returns the number of features in the specified batch.
45   virtual int64 FeatureCount(int64_t batch) const = 0;
46 
47   // Returns the fingerprint of nth feature from the specified batch.
48   virtual InternalType Feature(int64_t batch, int64_t n,
49                                bool strong_hash) const = 0;
50 
~ColumnInterface()51   virtual ~ColumnInterface() {}
52 };
53 
54 // A column that is backed by a sparse tensor.
55 template <typename InternalType>
56 class SparseTensorColumn : public ColumnInterface<InternalType> {
57  public:
SparseTensorColumn(const Tensor & values,std::vector<int64> feature_counts,std::vector<int64> feature_start_indices)58   SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
59                      std::vector<int64> feature_start_indices)
60       : values_(values),
61         feature_counts_(std::move(feature_counts)),
62         feature_start_indices_(std::move(feature_start_indices)) {
63     CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
64   }
65 
FeatureCount(int64_t batch) const66   int64 FeatureCount(int64_t batch) const override {
67     return feature_counts_[batch];
68   }
69 
70   InternalType Feature(int64_t batch, int64_t n,
71                        bool strong_hash) const override;
72 
~SparseTensorColumn()73   ~SparseTensorColumn() override {}
74 
75  private:
76   const Tensor& values_;
77   std::vector<int64> feature_counts_;
78   std::vector<int64> feature_start_indices_;
79 };
80 
81 // A column that is backed by a sparse tensor.
82 template <typename InternalType>
83 class KeyedSparseTensorColumn : public ColumnInterface<InternalType> {
84  public:
KeyedSparseTensorColumn(const Tensor & values,std::vector<int64> feature_counts,std::vector<int64> feature_start_indices,std::vector<int64> key)85   KeyedSparseTensorColumn(const Tensor& values,
86                           std::vector<int64> feature_counts,
87                           std::vector<int64> feature_start_indices,
88                           std::vector<int64> key)
89       : values_(values),
90         feature_counts_(std::move(feature_counts)),
91         feature_start_indices_(std::move(feature_start_indices)) {
92     DCHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
93     std::memcpy(key_, key.data(), sizeof(key_));
94   }
95 
FeatureCount(int64_t batch) const96   int64 FeatureCount(int64_t batch) const override {
97     return feature_counts_[batch];
98   }
99 
100   InternalType Feature(int64_t batch, int64_t n,
101                        bool strong_hash) const override;
102 
~KeyedSparseTensorColumn()103   ~KeyedSparseTensorColumn() override {}
104 
105  private:
106   const Tensor& values_;
107   tensorflow::uint64 key_[2];
108   std::vector<int64> feature_counts_;
109   std::vector<int64> feature_start_indices_;
110 };
111 
112 // InternalType is int64 only when using HashCrosser.
113 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const114 int64 SparseTensorColumn<int64>::Feature(int64_t batch, int64_t n,
115                                          bool strong_hash) const {
116   const int64_t start = feature_start_indices_[batch];
117   if (DT_STRING == values_.dtype())
118     return Fingerprint64(values_.vec<tstring>().data()[start + n]);
119   return values_.vec<int64>().data()[start + n];
120 }
121 
template <>
int64 KeyedSparseTensorColumn<int64>::Feature(int64_t batch, int64_t n,
                                              bool strong_hash) const {
  const int64_t start = feature_start_indices_[batch];
  // strong_hash selects the keyed hash seeded by key_; otherwise the value
  // is reduced with the unkeyed Fingerprint64.
  if (strong_hash) {
    if (DT_STRING == values_.dtype()) {
      return StrongKeyedHash(key_, values_.vec<tstring>()(start + n));
    }
    // NOTE(review): sizeof(values_.dtype()) is the size of the DataType
    // enum, not of an int64 element, so only that many low-order bytes of
    // the value are hashed. Changing it would change every emitted hash --
    // confirm compatibility requirements before "fixing".
    return StrongKeyedHash(
        key_, {reinterpret_cast<const char*>(&values_.vec<int64>()(start + n)),
               sizeof(values_.dtype())});
  }
  if (DT_STRING == values_.dtype())
    return Fingerprint64(values_.vec<tstring>()(start + n));
  return Fingerprint64(
      {reinterpret_cast<const char*>(&values_.vec<int64>()(start + n)),
       sizeof(values_.dtype())});
}
140 
141 // InternalType is string or StringPiece when using StringCrosser.
142 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const143 tstring SparseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
144                                              bool strong_hash) const {
145   const int64_t start = feature_start_indices_[batch];
146   if (DT_STRING == values_.dtype())
147     return values_.vec<tstring>().data()[start + n];
148   return std::to_string(values_.vec<int64>().data()[start + n]);
149 }
150 
151 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const152 tstring KeyedSparseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
153                                                   bool strong_hash) const {
154   const int64_t start = feature_start_indices_[batch];
155   if (DT_STRING == values_.dtype())
156     return values_.vec<tstring>().data()[start + n];
157   return std::to_string(values_.vec<int64>().data()[start + n]);
158 }
159 
160 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const161 StringPiece SparseTensorColumn<StringPiece>::Feature(int64_t batch, int64_t n,
162                                                      bool strong_hash) const {
163   const int64_t start = feature_start_indices_[batch];
164   return values_.vec<tstring>().data()[start + n];
165 }
166 
167 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const168 StringPiece KeyedSparseTensorColumn<StringPiece>::Feature(
169     int64_t batch, int64_t n, bool strong_hash) const {
170   const int64_t start = feature_start_indices_[batch];
171   return values_.vec<tstring>().data()[start + n];
172 }
173 
174 // A column that is backed by a dense tensor.
175 template <typename InternalType>
176 class DenseTensorColumn : public ColumnInterface<InternalType> {
177  public:
DenseTensorColumn(const Tensor & tensor)178   explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}
179 
FeatureCount(int64_t batch) const180   int64 FeatureCount(int64_t batch) const override {
181     return tensor_.dim_size(1);
182   }
183 
184   InternalType Feature(int64_t batch, int64_t n,
185                        bool strong_hash) const override;
186 
~DenseTensorColumn()187   ~DenseTensorColumn() override {}
188 
189  private:
190   const Tensor& tensor_;
191 };
192 
193 // A column that is backed by a dense tensor.
194 template <typename InternalType>
195 class KeyedDenseTensorColumn : public ColumnInterface<InternalType> {
196  public:
KeyedDenseTensorColumn(const Tensor & tensor,std::vector<int64> key)197   explicit KeyedDenseTensorColumn(const Tensor& tensor, std::vector<int64> key)
198       : tensor_(tensor) {
199     std::memcpy(key_, key.data(), sizeof(key_));
200   }
201 
FeatureCount(int64_t batch) const202   int64 FeatureCount(int64_t batch) const override {
203     return tensor_.dim_size(1);
204   }
205 
206   InternalType Feature(int64_t batch, int64_t n,
207                        bool strong_hash) const override;
208 
~KeyedDenseTensorColumn()209   ~KeyedDenseTensorColumn() override {}
210 
211  private:
212   const Tensor& tensor_;
213   tensorflow::uint64 key_[2];
214 };
215 
216 // InternalType is int64 only when using HashCrosser.
217 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const218 int64 DenseTensorColumn<int64>::Feature(int64_t batch, int64_t n,
219                                         bool strong_hash) const {
220   if (DT_STRING == tensor_.dtype())
221     return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
222   return tensor_.matrix<int64>()(batch, n);
223 }
224 
225 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const226 int64 KeyedDenseTensorColumn<int64>::Feature(int64_t batch, int64_t n,
227                                              bool strong_hash) const {
228   if (strong_hash) {
229     if (DT_STRING == tensor_.dtype()) {
230       return StrongKeyedHash(key_, tensor_.matrix<tstring>()(batch, n));
231     }
232     return StrongKeyedHash(
233         key_, {reinterpret_cast<const char*>(tensor_.matrix<int64>()(batch, n)),
234                sizeof(tensor_.dtype())});
235   }
236   if (DT_STRING == tensor_.dtype())
237     return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
238   return tensor_.matrix<int64>()(batch, n);
239 }
240 
241 // Internal type is string or StringPiece when using StringCrosser.
242 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const243 tstring DenseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
244                                             bool strong_hash) const {
245   if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
246   return std::to_string(tensor_.matrix<int64>()(batch, n));
247 }
248 
249 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const250 tstring KeyedDenseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
251                                                  bool strong_hash) const {
252   if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
253   return std::to_string(tensor_.matrix<int64>()(batch, n));
254 }
255 
256 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const257 StringPiece DenseTensorColumn<StringPiece>::Feature(int64_t batch, int64_t n,
258                                                     bool strong_hash) const {
259   return tensor_.matrix<tstring>()(batch, n);
260 }
261 
262 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const263 StringPiece KeyedDenseTensorColumn<StringPiece>::Feature(
264     int64_t batch, int64_t n, bool strong_hash) const {
265   return tensor_.matrix<tstring>()(batch, n);
266 }
267 
268 // Updates Output tensors with sparse crosses.
269 template <typename OutType>
270 class OutputUpdater {
271  public:
OutputUpdater(const std::vector<int64> & output_start_indices,Tensor * indices_out,Tensor * values_out)272   OutputUpdater(const std::vector<int64>& output_start_indices,
273                 Tensor* indices_out, Tensor* values_out)
274       : output_start_indices_(output_start_indices),
275         indices_out_(indices_out),
276         values_out_(values_out) {}
277 
Update(const int64_t batch_index,const int64_t cross_count,const OutType & cross) const278   void Update(const int64_t batch_index, const int64_t cross_count,
279               const OutType& cross) const {
280     const int64_t output_index =
281         output_start_indices_[batch_index] + cross_count;
282 
283     auto indices_matrix = indices_out_->matrix<int64>();
284     indices_matrix(output_index, 0) = batch_index;
285     indices_matrix(output_index, 1) = cross_count;
286 
287     auto value_vec = values_out_->vec<OutType>();
288     value_vec(output_index) = cross;
289   }
290 
291  private:
292   const std::vector<int64>& output_start_indices_;
293   Tensor* indices_out_;
294   Tensor* values_out_;
295 };
296 
297 // Generates the sparse crosses as concatenation of strings.
298 template <typename InternalType>
299 class StringCrosser {
300  public:
StringCrosser(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,const int64_t num_buckets_unused,const uint64 hash_key_unused,const tstring k_feature_separator)301   StringCrosser(const std::vector<
302                     std::unique_ptr<ColumnInterface<InternalType>>>& columns,
303                 const int64_t num_buckets_unused, const uint64 hash_key_unused,
304                 const tstring k_feature_separator)
305       : columns_(columns), k_feature_separator_(k_feature_separator) {}
306 
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool unused_strong_hash) const307   string Generate(const int64_t batch_index,
308                   const std::vector<int>& permutation,
309                   bool unused_strong_hash) const {
310     gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
311     for (int i = 0; i < permutation.size(); i++) {
312       cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i], false);
313     }
314     // TODO(zakaria): this will copy the string twice, might effect
315     // performance.
316     return absl::StrJoin(cross_vec, k_feature_separator_);
317   }
318 
319  private:
320   const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
321   const tstring k_feature_separator_;
322 };
323 
324 // Generates the sparse crosses as nested hash to avoid string manipulations.
325 class HashCrosser {
326  public:
HashCrosser(const std::vector<std::unique_ptr<ColumnInterface<int64>>> & columns,const int64_t num_buckets,const uint64 hash_key,const tstring k_feature_separator_unused)327   HashCrosser(
328       const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
329       const int64_t num_buckets, const uint64 hash_key,
330       const tstring k_feature_separator_unused)
331       : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}
332 
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool unused_strong_hash) const333   int64 Generate(const int64_t batch_index, const std::vector<int>& permutation,
334                  bool unused_strong_hash) const {
335     // Do the fingerprint concatenation on uint64.
336     uint64 hashed_output = hash_key_;
337     for (size_t i = 0; i < permutation.size(); ++i) {
338       uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i], false);
339       hashed_output = FingerprintCat64(hashed_output, hash_i);
340     }
341     // The return value is int64 based on the number of buckets.
342     if (num_buckets_ > 0) {
343       return hashed_output % num_buckets_;
344     } else {
345       // To prevent negative output we take modulo to max int64.
346       return hashed_output % std::numeric_limits<int64>::max();
347     }
348   }
349 
350  private:
351   const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
352   const int64 num_buckets_;
353   const uint64 hash_key_;
354 };
355 
356 // Generates the sparse crosses as nested hash to avoid string manipulations.
357 class HashCrosserV2 {
358  public:
HashCrosserV2(const std::vector<std::unique_ptr<ColumnInterface<int64>>> & columns,const int64_t num_buckets,const uint64 hash_key_unused,const tstring k_feature_separator_unused)359   HashCrosserV2(
360       const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
361       const int64_t num_buckets, const uint64 hash_key_unused,
362       const tstring k_feature_separator_unused)
363       : columns_(columns), num_buckets_(num_buckets) {}
364 
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool strong_hash) const365   int64 Generate(const int64_t batch_index, const std::vector<int>& permutation,
366                  bool strong_hash) const {
367     // Do the fingerprint concatenation on uint64.
368     uint64 hashed_output =
369         columns_[0]->Feature(batch_index, permutation[0], strong_hash);
370     for (size_t i = 1; i < permutation.size(); ++i) {
371       uint64 hash_i =
372           columns_[i]->Feature(batch_index, permutation[i], strong_hash);
373       hashed_output = FingerprintCat64(hashed_output, hash_i);
374     }
375     // The return value is int64 based on the number of buckets.
376     if (num_buckets_ > 0) {
377       return hashed_output % num_buckets_;
378     } else {
379       // To prevent negative output we take modulo to max int64.
380       return hashed_output % std::numeric_limits<int64>::max();
381     }
382   }
383 
384  private:
385   const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
386   const int64 num_buckets_;
387 };
388 
// ProductIterator generates cartesian products based on indices.
template <typename InternalType>
class ProductIterator {
 public:
  explicit ProductIterator(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64_t batch_index)
      : columns_(columns), batch_index_(batch_index) {
    next_permutation_.resize(columns_.size(), 0);
    // Sets has_next_ to false if any feature column has 0 features.
    has_next_ = true;
    for (int i = 0; i < columns_.size(); i++) {
      if (columns_[i]->FeatureCount(batch_index_) == 0) {
        has_next_ = false;
        break;
      }
    }
  }

  // Returns the current permutation and advances to the next one,
  // odometer-style: the rightmost column increments first and wraps
  // into a carry on the column to its left.
  std::vector<int> Next() {
    std::vector<int> permutation(next_permutation_);

    // Generates next permutation, if available.
    bool carry = true;
    for (int i = next_permutation_.size() - 1; i >= 0; i--) {
      if (carry) {
        next_permutation_[i] = next_permutation_[i] + 1;
      }
      if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
        next_permutation_[i] = 0;  // wrap; carry propagates to column i-1
      } else {
        carry = false;
        break;
      }
    }
    // A carry out of the leftmost column means the odometer wrapped all the
    // way around: the permutation just returned was the last one.
    has_next_ = !carry;
    return permutation;
  }

  bool HasNext() { return has_next_; }

 private:
  bool has_next_;
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
  const int64 batch_index_;
  std::vector<int> next_permutation_;
};
437 
// Maps the (hashed-output, internal-type) pair to the concrete crosser and
// output-updater types used by the op kernels below.
template <bool HASHED_OUTPUT, typename InternalType>
struct CrossTraits;

// String output: both the V1 and V2 crossers concatenate feature strings.
template <typename InternalType>
struct CrossTraits<false, InternalType> {
  typedef StringCrosser<InternalType> Crosser;
  typedef StringCrosser<InternalType> CrosserV2;
  typedef OutputUpdater<tstring> Updater;
};

// Hashed output is only defined for the int64 internal type.
template <>
struct CrossTraits<true, int64> {
  typedef HashCrosser Crosser;
  typedef HashCrosserV2 CrosserV2;
  typedef OutputUpdater<int64> Updater;
};
454 }  // namespace
455 
456 // Calculate the batch size from either the shapes input or the dense input.
CalculateBatchSize(const OpInputList & shapes_list_in,const OpInputList & dense_list_in)457 int64 CalculateBatchSize(const OpInputList& shapes_list_in,
458                          const OpInputList& dense_list_in) {
459   if (shapes_list_in.size() > 0) {
460     return shapes_list_in[0].vec<int64>()(0);
461   }
462 
463   if (dense_list_in.size() > 0) {
464     return dense_list_in[0].dim_size(0);
465   }
466 
467   return 0;
468 }
469 
// Validates input tensors.
// Checks dtypes and shapes of the sparse inputs (indices/values/shapes
// triples) and the dense inputs, and that all inputs agree on one batch
// size. Passing internal_type == DT_INVALID disables the dtype checks.
Status ValidateInput(const OpInputList& indices_list_in,
                     const OpInputList& values_list_in,
                     const OpInputList& shapes_list_in,
                     const OpInputList& dense_list_in,
                     const DataType& internal_type) {
  const auto size = indices_list_in.size();
  // Only perform internal_type check for SparseCrossOp.
  // Check if the internal_type is not invalid before doing so.
  bool check_type = internal_type != DT_INVALID;
  // Validates indices_list_in OpInputList.
  for (int i = 0; i < size; i++) {
    if (check_type && indices_list_in[i].dtype() != DT_INT64) {
      return errors::InvalidArgument("Input indices should be of type ",
                                     DT_INT64, " but received ",
                                     indices_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input indices should be a matrix but received shape ",
          indices_list_in[i].shape().DebugString(), " at position ", i);
    }
    // Sparse indices are (row, col) pairs, so the second dim must be 2.
    if (indices_list_in[i].shape().dim_size(1) != 2) {
      return errors::InvalidArgument("Expected D2 of index to be 2 got ",
                                     indices_list_in[i].shape().dim_size(1),
                                     " at position ", i);
    }
  }

  // Validates values_list_in OpInputList.
  if (values_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input values, got ",
                                   values_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    // Make sure to avoid the expected type to be string, but input values to be
    // int64.
    if (check_type && internal_type == DT_STRING &&
        values_list_in[i].dtype() == DT_INT64) {
      return errors::InvalidArgument("Input values should be of internal type ",
                                     internal_type, " but received ",
                                     values_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input values should be a vector but received shape ",
          values_list_in[i].shape().DebugString(), " at position ", i);
    }
    // Each sparse index row must have exactly one corresponding value.
    if (indices_list_in[i].shape().dim_size(0) !=
        values_list_in[i].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Expected size of values to be ",
          indices_list_in[i].shape().dim_size(0), " got ",
          values_list_in[i].shape().dim_size(0), " at position ", i);
    }
  }

  // Validates shapes_list_in OpInputList
  if (shapes_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input shapes, got ",
                                   shapes_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    if (check_type && shapes_list_in[i].dtype() != DT_INT64) {
      return errors::InvalidArgument("Input shape should be of type ", DT_INT64,
                                     " but received ",
                                     shapes_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input shapes should be a vector but received shape ",
          shapes_list_in[i].shape().DebugString(), " at position ", i);
    }

    // Each shape vector must describe a 2-D sparse tensor.
    if (shapes_list_in[i].vec<int64>().size() != 2) {
      return errors::InvalidArgument("shape should imply a 2D tensor, but got ",
                                     shapes_list_in[i].shape().DebugString(),
                                     " at position ", i);
    }
  }

  // Validates dense_list_in OpInputList
  for (int i = 0; i < dense_list_in.size(); ++i) {
    // Make sure to avoid the expected type to be string, but input values to be
    // int64.
    if (check_type && internal_type == DT_STRING &&
        dense_list_in[i].dtype() == DT_INT64) {
      return errors::InvalidArgument("Dense inputs should be of internal type ",
                                     internal_type, " but received ",
                                     dense_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Dense inputs should be a matrix but received shape ",
          dense_list_in[i].shape().DebugString(), " at position ", i);
    }
  }

  // Validates batch sizes.  (Note: we do this after validating the input
  // shapes, because CalculateBatchSize() depends on inputs having valid
  // shapes).
  const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
  for (int i = 0; i < size; i++) {
    if (shapes_list_in[i].vec<int64>()(0) != batch_size) {
      return errors::InvalidArgument("Expected batch size ", batch_size,
                                     " got ", shapes_list_in[i].vec<int64>()(0),
                                     " at position ", i);
    }
  }
  for (int i = 0; i < dense_list_in.size(); ++i) {
    if (dense_list_in[i].dim_size(0) != batch_size) {
      return errors::InvalidArgument("Expected batch size ", batch_size,
                                     " got ", dense_list_in[i].dim_size(0),
                                     " at dense tensor ", i);
    }
  }

  return Status::OK();
}
589 
590 // Extracts data about the features and populates feature data.
ExtractFeatureData(const OpInputList & indices_list_in,int64_t batch_size,std::vector<std::vector<int64>> * feature_counts,std::vector<std::vector<int64>> * feature_start_indices)591 void ExtractFeatureData(
592     const OpInputList& indices_list_in, int64_t batch_size,
593     std::vector<std::vector<int64>>* feature_counts,
594     std::vector<std::vector<int64>>* feature_start_indices) {
595   gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0);
596   for (int b = 0; b < batch_size; b++) {
597     for (int i = 0; i < indices_list_in.size(); i++) {
598       const auto indices = indices_list_in[i].matrix<int64>();
599       int64_t feature_count = 0;
600       int64_t start_index = current_row[i];
601       // Loops until we reach next batch index for current feature column.
602       while (current_row[i] < indices_list_in[i].dim_size(0) &&
603              indices(current_row[i], 0) == b) {
604         feature_count++;
605         current_row[i]++;
606       }
607       (*feature_counts)[i].push_back(feature_count);
608       (*feature_start_indices)[i].push_back(start_index);
609     }
610   }
611 }
612 
// Returns number of crosses for a given batch_index
// (the product of every column's feature count in that batch).
template <typename InternalType>
int64 CrossCountByBatchIndex(
    const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
    int batch_index) {
  int64_t cross_count = 1;
  for (int i = 0; i < columns.size(); i++) {
    const auto feature_count = columns[i]->FeatureCount(batch_index);
    // If one column is missing any feature, there won't be any cross.
    if (feature_count == 0) {
      return 0;
    }
    // NOTE(review): this running product has no overflow check; very large
    // feature counts could overflow int64 -- confirm callers bound sizes.
    cross_count *= feature_count;
  }
  return cross_count;
}
629 
630 // Generate the columns given the sparse and dense inputs.
631 template <typename InternalType>
632 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in)633 GenerateColumnsFromInput(const OpInputList& indices_list_in,
634                          const OpInputList& values_list_in,
635                          const OpInputList& shapes_list_in,
636                          const OpInputList& dense_list_in) {
637   std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
638   const int64_t batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
639   const int64_t number_of_columns = shapes_list_in.size();
640 
641   std::vector<std::vector<int64>> feature_counts(number_of_columns,
642                                                  std::vector<int64>());
643   std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
644                                                         std::vector<int64>());
645 
646   ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
647                      &feature_start_indices);
648 
649   columns.reserve(values_list_in.size());
650   for (int i = 0; i < values_list_in.size(); ++i) {
651     columns.emplace_back(new SparseTensorColumn<InternalType>(
652         values_list_in[i], std::move(feature_counts[i]),
653         std::move(feature_start_indices[i])));
654   }
655   for (int i = 0; i < dense_list_in.size(); ++i) {
656     columns.emplace_back(new DenseTensorColumn<InternalType>(dense_list_in[i]));
657   }
658 
659   return columns;
660 }
661 
662 // Generate the columns given the sparse and dense inputs.
663 template <typename InternalType>
664 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateKeyedColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in,std::vector<int64> keys)665 GenerateKeyedColumnsFromInput(const OpInputList& indices_list_in,
666                               const OpInputList& values_list_in,
667                               const OpInputList& shapes_list_in,
668                               const OpInputList& dense_list_in,
669                               std::vector<int64> keys) {
670   std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
671   const int64_t batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
672   const int64_t number_of_columns = shapes_list_in.size();
673 
674   std::vector<std::vector<int64>> feature_counts(number_of_columns,
675                                                  std::vector<int64>());
676   std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
677                                                         std::vector<int64>());
678 
679   ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
680                      &feature_start_indices);
681 
682   columns.reserve(values_list_in.size());
683   for (int i = 0; i < values_list_in.size(); ++i) {
684     columns.emplace_back(new KeyedSparseTensorColumn<InternalType>(
685         values_list_in[i], std::move(feature_counts[i]),
686         std::move(feature_start_indices[i]), keys));
687   }
688   for (int i = 0; i < dense_list_in.size(); ++i) {
689     columns.emplace_back(
690         new KeyedDenseTensorColumn<InternalType>(dense_list_in[i], keys));
691   }
692 
693   return columns;
694 }
695 
696 // Allocates output tensors with proper size and sets the shape tensor of
697 // the output SparseTensor.
698 // It also output_start_indices which contains the start indices for each
699 // input in the output SparseTensor.
700 template <typename InternalType>
CreateOutputTensors(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,int64_t batch_size,OpKernelContext * context,Tensor ** indices_out,Tensor ** values_out,Tensor ** shape_out,std::vector<int64> * output_start_indices)701 Status CreateOutputTensors(
702     const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
703     int64_t batch_size, OpKernelContext* context, Tensor** indices_out,
704     Tensor** values_out, Tensor** shape_out,
705     std::vector<int64>* output_start_indices) {
706   // Calculates dimensions for output tensors.
707   int64_t cross_count_total = 0;
708   int64_t max_cross_count = 0;
709   for (int64_t b = 0; b < batch_size; b++) {
710     // For each input, sets starting indices in output SparseTensor
711     (*output_start_indices)[b] = cross_count_total;
712     const auto cross_count = CrossCountByBatchIndex(columns, b);
713     max_cross_count = std::max(max_cross_count, cross_count);
714     cross_count_total += cross_count;
715   }
716 
717   // Allocates tensors.
718   TF_RETURN_IF_ERROR(context->allocate_output(
719       0, TensorShape({cross_count_total, 2}), indices_out));
720   TF_RETURN_IF_ERROR(context->allocate_output(
721       1, TensorShape({cross_count_total}), values_out));
722   TF_RETURN_IF_ERROR(context->allocate_output(2, TensorShape({2}), shape_out));
723 
724   // Sets shape.
725   auto shape_vec = (*shape_out)->vec<int64>();
726   shape_vec(0) = batch_size;
727   shape_vec(1) = max_cross_count;
728 
729   return Status::OK();
730 }
731 
// Kernel for the "SparseCross" op: computes the feature crosses of a list of
// sparse and dense inputs and emits the result as a SparseTensor. When
// HASHED_OUTPUT is true the crossed features are combined via the
// fingerprint-based Crosser (modulo num_buckets when non-zero); otherwise
// the string features are joined with the "_X_" separator.
template <bool HASHED_OUTPUT, typename InternalType>
class SparseCrossOp : public OpKernel {
 public:
  explicit SparseCrossOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    // Read signed_hash_key_ as int64 since uint64 attributes are not
    // supported by REGISTER_OP.
    int64_t signed_hash_key_;
    OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
    hash_key_ = static_cast<uint64>(signed_hash_key_);
    OP_REQUIRES_OK(context, context->GetAttr("internal_type", &internal_type_));
  }

  void Compute(OpKernelContext* context) override {
    // Gather the four input lists: the sparse components and the dense
    // tensors to be crossed.
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    // Validate shapes/ranks/dtypes of all inputs before touching their data.
    DataType internal_type = internal_type_;
    OP_REQUIRES_OK(
        context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
                               dense_list_in, internal_type));

    // Wrap every input (sparse and dense) in a uniform column interface.
    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
        GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in,
                                               shapes_list_in, dense_list_in);

    const tstring k_feature_separator = "_X_";
    typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser(
        columns, num_buckets_, hash_key_, k_feature_separator);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64_t batch_size =
        CalculateBatchSize(shapes_list_in, dense_list_in);
    // Per-batch starting rows in the output, filled by CreateOutputTensors.
    std::vector<int64> output_start_indices(batch_size);
    OP_REQUIRES_OK(
        context,
        CreateOutputTensors(columns, batch_size, context, &indices_out,
                            &values_out, &shape_out, &output_start_indices));

    typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
        output_start_indices, indices_out, values_out);
    // Each shard enumerates the full cross product for its batch range.
    // `crosser`/`updater` are captured by copy so shards share no mutable
    // state; `columns` is read-only and captured by reference.
    auto do_work = [&columns, crosser, updater](int64_t begin, int64_t end) {
      for (int b = begin; b < end; b++) {
        ProductIterator<InternalType> product_iterator(columns, b);
        int64_t cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count,
                         crosser.Generate(b, permutation, false));
          cross_count++;
        }
      }
    };

    // Parallelize across batches on the CPU worker pool.
    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }

 private:
  int64 num_buckets_;   // 0 means "no bucketing" for hashed output.
  uint64 hash_key_;     // Fingerprint seed, read from the signed attr.
  DataType internal_type_;
};
806 
807 class SparseCrossV2Op : public OpKernel {
808  public:
SparseCrossV2Op(OpKernelConstruction * context)809   explicit SparseCrossV2Op(OpKernelConstruction* context) : OpKernel(context) {}
810 
Compute(OpKernelContext * context)811   void Compute(OpKernelContext* context) override {
812     OpInputList indices_list_in;
813     OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
814     OpInputList values_list_in;
815     OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
816     OpInputList shapes_list_in;
817     OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
818     OpInputList dense_list_in;
819     OP_REQUIRES_OK(context,
820                    context->input_list("dense_inputs", &dense_list_in));
821 
822     // Set internal_type to invalid_type so that the check will be ignored.
823     DataType internal_type = DT_INVALID;
824     OP_REQUIRES_OK(
825         context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
826                                dense_list_in, internal_type));
827 
828     const Tensor* sep_t;
829     OP_REQUIRES_OK(context, context->input("sep", &sep_t));
830     const tstring separator = sep_t->scalar<tstring>()();
831 
832     std::vector<std::unique_ptr<ColumnInterface<tstring>>> columns =
833         GenerateColumnsFromInput<tstring>(indices_list_in, values_list_in,
834                                           shapes_list_in, dense_list_in);
835     Tensor* indices_out;
836     Tensor* values_out;
837     Tensor* shape_out;
838     const int64_t batch_size =
839         CalculateBatchSize(shapes_list_in, dense_list_in);
840     std::vector<int64> output_start_indices(batch_size);
841     OP_REQUIRES_OK(
842         context,
843         CreateOutputTensors(columns, batch_size, context, &indices_out,
844                             &values_out, &shape_out, &output_start_indices));
845     StringCrosser<tstring> crosser(columns, 0, 0, separator);
846     OutputUpdater<tstring> updater(output_start_indices, indices_out,
847                                    values_out);
848     auto do_work = [&columns, crosser, updater](int64_t begin, int64_t end) {
849       for (int b = begin; b < end; b++) {
850         ProductIterator<tstring> product_iterator(columns, b);
851         int64_t cross_count = 0;
852         while (product_iterator.HasNext()) {
853           const auto permutation = product_iterator.Next();
854           updater.Update(b, cross_count,
855                          crosser.Generate(b, permutation, false));
856           cross_count++;
857         }
858       }
859     };
860 
861     auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
862     // TODO(zakaria): optimize kCostPerUnit
863     const int kCostPerUnit = 5000 * indices_list_in.size();
864     Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
865           kCostPerUnit, do_work);
866   }
867 };
868 
869 class SparseCrossHashedOp : public OpKernel {
870  public:
SparseCrossHashedOp(OpKernelConstruction * context)871   explicit SparseCrossHashedOp(OpKernelConstruction* context)
872       : OpKernel(context) {}
873 
Compute(OpKernelContext * context)874   void Compute(OpKernelContext* context) override {
875     OpInputList indices_list_in;
876     OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
877     OpInputList values_list_in;
878     OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
879     OpInputList shapes_list_in;
880     OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
881     OpInputList dense_list_in;
882     OP_REQUIRES_OK(context,
883                    context->input_list("dense_inputs", &dense_list_in));
884 
885     // Set internal_type to invalid_type so that the check will be ignored.
886     DataType internal_type = DT_INVALID;
887     OP_REQUIRES_OK(
888         context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
889                                dense_list_in, internal_type));
890 
891     const Tensor* num_buckets_t;
892     OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t));
893     const int64_t num_buckets = num_buckets_t->scalar<int64>()();
894 
895     const Tensor* strong_hash_t;
896     OP_REQUIRES_OK(context, context->input("strong_hash", &strong_hash_t));
897     const bool strong_hash = strong_hash_t->scalar<bool>()();
898 
899     const Tensor* salt_t;
900     OP_REQUIRES_OK(context, context->input("salt", &salt_t));
901     const auto salt = salt_t->flat<int64>();
902     std::vector<int64> key_{salt(0), salt(1)};
903 
904     std::vector<std::unique_ptr<ColumnInterface<int64>>> columns =
905         GenerateKeyedColumnsFromInput<int64>(indices_list_in, values_list_in,
906                                              shapes_list_in, dense_list_in,
907                                              key_);
908     Tensor* indices_out;
909     Tensor* values_out;
910     Tensor* shape_out;
911     const int64_t batch_size =
912         CalculateBatchSize(shapes_list_in, dense_list_in);
913     std::vector<int64> output_start_indices(batch_size);
914     OP_REQUIRES_OK(
915         context,
916         CreateOutputTensors(columns, batch_size, context, &indices_out,
917                             &values_out, &shape_out, &output_start_indices));
918     const tstring unused_sep;
919     HashCrosserV2 crosser(columns, num_buckets, 0, unused_sep);
920     OutputUpdater<int64> updater(output_start_indices, indices_out, values_out);
921     auto do_work = [&columns, crosser, updater, strong_hash](int64_t begin,
922                                                              int64_t end) {
923       for (int b = begin; b < end; b++) {
924         ProductIterator<int64> product_iterator(columns, b);
925         int64_t cross_count = 0;
926         while (product_iterator.HasNext()) {
927           const auto permutation = product_iterator.Next();
928           updater.Update(b, cross_count,
929                          crosser.Generate(b, permutation, strong_hash));
930           cross_count++;
931         }
932       }
933     };
934 
935     auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
936     // TODO(zakaria): optimize kCostPerUnit
937     const int kCostPerUnit = 5000 * indices_list_in.size();
938     Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
939           kCostPerUnit, do_work);
940   }
941 };
942 
// Kernel registrations. For SparseCross, the "out_type"/"internal_type"
// attribute pair selects the template instantiation
// SparseCrossOp<HASHED_OUTPUT, InternalType>: string output uses the string
// crosser (StringPiece internally when the inputs are strings, tstring when
// the internal type is int64), and int64 output uses the hashed crosser.

// String output, string inputs: cross via StringPiece views.
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type")
                            .TypeConstraint<tstring>("internal_type"),
                        SparseCrossOp<false, StringPiece>);

// String output, int64 inputs: values are stringified before crossing.
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseCrossOp<false, tstring>);

// Hashed int64 output; internal_type string or int64 both use the
// fingerprint-based int64 crosser.
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<tstring>("internal_type"),
                        SparseCrossOp<true, int64>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseCrossOp<true, int64>);

// V2 takes the separator as a runtime input instead of an attribute.
REGISTER_KERNEL_BUILDER(Name("SparseCrossV2").Device(DEVICE_CPU),
                        SparseCrossV2Op);

// Hashed variant with runtime num_buckets, strong_hash flag and salt.
REGISTER_KERNEL_BUILDER(Name("SparseCrossHashed").Device(DEVICE_CPU),
                        SparseCrossHashedOp);
972 
973 }  // namespace tensorflow
974