• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Contains OP to generate sparse crosses.
#include <assert.h>

#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/strong_hash.h"
#include "tensorflow/core/util/work_sharder.h"
38 
39 namespace tensorflow {
40 
41 namespace {
// An interface that represents a column with batches.
template <typename InternalType>
class ColumnInterface {
 public:
  // Returns the number of features in the specified batch.
  virtual int64_t FeatureCount(int64_t batch) const = 0;

  // Returns the fingerprint of nth feature from the specified batch.
  virtual InternalType Feature(int64_t batch, int64_t n,
                               bool strong_hash) const = 0;

  // Virtual so concrete columns can be destroyed through a base pointer.
  virtual ~ColumnInterface() = default;
};
55 
56 // A column that is backed by a sparse tensor.
57 template <typename InternalType>
58 class SparseTensorColumn : public ColumnInterface<InternalType> {
59  public:
SparseTensorColumn(const Tensor & values,std::vector<int64_t> feature_counts,std::vector<int64_t> feature_start_indices)60   SparseTensorColumn(const Tensor& values, std::vector<int64_t> feature_counts,
61                      std::vector<int64_t> feature_start_indices)
62       : values_(values),
63         feature_counts_(std::move(feature_counts)),
64         feature_start_indices_(std::move(feature_start_indices)) {
65     CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
66   }
67 
FeatureCount(int64_t batch) const68   int64_t FeatureCount(int64_t batch) const override {
69     return feature_counts_[batch];
70   }
71 
72   InternalType Feature(int64_t batch, int64_t n,
73                        bool strong_hash) const override;
74 
~SparseTensorColumn()75   ~SparseTensorColumn() override {}
76 
77  private:
78   const Tensor& values_;
79   std::vector<int64_t> feature_counts_;
80   std::vector<int64_t> feature_start_indices_;
81 };
82 
83 // A column that is backed by a sparse tensor.
84 template <typename InternalType>
85 class KeyedSparseTensorColumn : public ColumnInterface<InternalType> {
86  public:
KeyedSparseTensorColumn(const Tensor & values,std::vector<int64_t> feature_counts,std::vector<int64_t> feature_start_indices,std::vector<int64_t> key)87   KeyedSparseTensorColumn(const Tensor& values,
88                           std::vector<int64_t> feature_counts,
89                           std::vector<int64_t> feature_start_indices,
90                           std::vector<int64_t> key)
91       : values_(values),
92         feature_counts_(std::move(feature_counts)),
93         feature_start_indices_(std::move(feature_start_indices)) {
94     DCHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
95     std::memcpy(key_, key.data(), sizeof(key_));
96   }
97 
FeatureCount(int64_t batch) const98   int64_t FeatureCount(int64_t batch) const override {
99     return feature_counts_[batch];
100   }
101 
102   InternalType Feature(int64_t batch, int64_t n,
103                        bool strong_hash) const override;
104 
~KeyedSparseTensorColumn()105   ~KeyedSparseTensorColumn() override {}
106 
107  private:
108   const Tensor& values_;
109   tensorflow::uint64 key_[2];
110   std::vector<int64_t> feature_counts_;
111   std::vector<int64_t> feature_start_indices_;
112 };
113 
114 // InternalType is int64 only when using HashCrosser.
115 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const116 int64_t SparseTensorColumn<int64_t>::Feature(int64_t batch, int64_t n,
117                                              bool strong_hash) const {
118   const int64_t start = feature_start_indices_[batch];
119   if (DT_STRING == values_.dtype())
120     return Fingerprint64(values_.vec<tstring>().data()[start + n]);
121   return values_.vec<int64_t>().data()[start + n];
122 }
123 
// Returns the hash of the nth feature of `batch`.  With strong_hash set the
// siphash-style StrongKeyedHash(key_, ...) is used; otherwise Fingerprint64.
template <>
int64_t KeyedSparseTensorColumn<int64_t>::Feature(int64_t batch, int64_t n,
                                                  bool strong_hash) const {
  const int64_t start = feature_start_indices_[batch];
  if (strong_hash) {
    if (DT_STRING == values_.dtype()) {
      return StrongKeyedHash(key_, values_.vec<tstring>()(start + n));
    }
    // NOTE(review): sizeof(values_.dtype()) is sizeof(DataType) (the enum),
    // not sizeof(int64_t), so only part of the value's bytes are hashed.
    // Changing it would alter every emitted hash — confirm before "fixing".
    return StrongKeyedHash(
        key_,
        {reinterpret_cast<const char*>(&values_.vec<int64_t>()(start + n)),
         sizeof(values_.dtype())});
  }
  if (DT_STRING == values_.dtype())
    return Fingerprint64(values_.vec<tstring>()(start + n));
  return Fingerprint64(
      {reinterpret_cast<const char*>(&values_.vec<int64_t>()(start + n)),
       sizeof(values_.dtype())});
}
143 
144 // InternalType is string or StringPiece when using StringCrosser.
145 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const146 tstring SparseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
147                                              bool strong_hash) const {
148   const int64_t start = feature_start_indices_[batch];
149   if (DT_STRING == values_.dtype())
150     return values_.vec<tstring>().data()[start + n];
151   return std::to_string(values_.vec<int64_t>().data()[start + n]);
152 }
153 
154 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const155 tstring KeyedSparseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
156                                                   bool strong_hash) const {
157   const int64_t start = feature_start_indices_[batch];
158   if (DT_STRING == values_.dtype())
159     return values_.vec<tstring>().data()[start + n];
160   return std::to_string(values_.vec<int64_t>().data()[start + n]);
161 }
162 
163 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const164 StringPiece SparseTensorColumn<StringPiece>::Feature(int64_t batch, int64_t n,
165                                                      bool strong_hash) const {
166   const int64_t start = feature_start_indices_[batch];
167   return values_.vec<tstring>().data()[start + n];
168 }
169 
170 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const171 StringPiece KeyedSparseTensorColumn<StringPiece>::Feature(
172     int64_t batch, int64_t n, bool strong_hash) const {
173   const int64_t start = feature_start_indices_[batch];
174   return values_.vec<tstring>().data()[start + n];
175 }
176 
177 // A column that is backed by a dense tensor.
178 template <typename InternalType>
179 class DenseTensorColumn : public ColumnInterface<InternalType> {
180  public:
DenseTensorColumn(const Tensor & tensor)181   explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}
182 
FeatureCount(int64_t batch) const183   int64_t FeatureCount(int64_t batch) const override {
184     return tensor_.dim_size(1);
185   }
186 
187   InternalType Feature(int64_t batch, int64_t n,
188                        bool strong_hash) const override;
189 
~DenseTensorColumn()190   ~DenseTensorColumn() override {}
191 
192  private:
193   const Tensor& tensor_;
194 };
195 
196 // A column that is backed by a dense tensor.
197 template <typename InternalType>
198 class KeyedDenseTensorColumn : public ColumnInterface<InternalType> {
199  public:
KeyedDenseTensorColumn(const Tensor & tensor,std::vector<int64_t> key)200   explicit KeyedDenseTensorColumn(const Tensor& tensor,
201                                   std::vector<int64_t> key)
202       : tensor_(tensor) {
203     std::memcpy(key_, key.data(), sizeof(key_));
204   }
205 
FeatureCount(int64_t batch) const206   int64_t FeatureCount(int64_t batch) const override {
207     return tensor_.dim_size(1);
208   }
209 
210   InternalType Feature(int64_t batch, int64_t n,
211                        bool strong_hash) const override;
212 
~KeyedDenseTensorColumn()213   ~KeyedDenseTensorColumn() override {}
214 
215  private:
216   const Tensor& tensor_;
217   tensorflow::uint64 key_[2];
218 };
219 
220 // InternalType is int64 only when using HashCrosser.
221 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const222 int64_t DenseTensorColumn<int64_t>::Feature(int64_t batch, int64_t n,
223                                             bool strong_hash) const {
224   if (DT_STRING == tensor_.dtype())
225     return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
226   return tensor_.matrix<int64_t>()(batch, n);
227 }
228 
229 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const230 int64_t KeyedDenseTensorColumn<int64_t>::Feature(int64_t batch, int64_t n,
231                                                  bool strong_hash) const {
232   if (strong_hash) {
233     if (DT_STRING == tensor_.dtype()) {
234       return StrongKeyedHash(key_, tensor_.matrix<tstring>()(batch, n));
235     }
236     return StrongKeyedHash(
237         key_,
238         {reinterpret_cast<const char*>(tensor_.matrix<int64_t>()(batch, n)),
239          sizeof(tensor_.dtype())});
240   }
241   if (DT_STRING == tensor_.dtype())
242     return Fingerprint64(tensor_.matrix<tstring>()(batch, n));
243   return tensor_.matrix<int64_t>()(batch, n);
244 }
245 
246 // Internal type is string or StringPiece when using StringCrosser.
247 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const248 tstring DenseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
249                                             bool strong_hash) const {
250   if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
251   return std::to_string(tensor_.matrix<int64_t>()(batch, n));
252 }
253 
254 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const255 tstring KeyedDenseTensorColumn<tstring>::Feature(int64_t batch, int64_t n,
256                                                  bool strong_hash) const {
257   if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n);
258   return std::to_string(tensor_.matrix<int64_t>()(batch, n));
259 }
260 
261 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const262 StringPiece DenseTensorColumn<StringPiece>::Feature(int64_t batch, int64_t n,
263                                                     bool strong_hash) const {
264   return tensor_.matrix<tstring>()(batch, n);
265 }
266 
267 template <>
Feature(int64_t batch,int64_t n,bool strong_hash) const268 StringPiece KeyedDenseTensorColumn<StringPiece>::Feature(
269     int64_t batch, int64_t n, bool strong_hash) const {
270   return tensor_.matrix<tstring>()(batch, n);
271 }
272 
273 // Updates Output tensors with sparse crosses.
274 template <typename OutType>
275 class OutputUpdater {
276  public:
OutputUpdater(const std::vector<int64_t> & output_start_indices,Tensor * indices_out,Tensor * values_out)277   OutputUpdater(const std::vector<int64_t>& output_start_indices,
278                 Tensor* indices_out, Tensor* values_out)
279       : output_start_indices_(output_start_indices),
280         indices_out_(indices_out),
281         values_out_(values_out) {}
282 
Update(const int64_t batch_index,const int64_t cross_count,const OutType & cross) const283   void Update(const int64_t batch_index, const int64_t cross_count,
284               const OutType& cross) const {
285     const int64_t output_index =
286         output_start_indices_[batch_index] + cross_count;
287 
288     auto indices_matrix = indices_out_->matrix<int64_t>();
289     indices_matrix(output_index, 0) = batch_index;
290     indices_matrix(output_index, 1) = cross_count;
291 
292     auto value_vec = values_out_->vec<OutType>();
293     value_vec(output_index) = cross;
294   }
295 
296  private:
297   const std::vector<int64_t>& output_start_indices_;
298   Tensor* indices_out_;
299   Tensor* values_out_;
300 };
301 
302 // Generates the sparse crosses as concatenation of strings.
303 template <typename InternalType>
304 class StringCrosser {
305  public:
StringCrosser(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,const int64_t num_buckets_unused,const uint64 hash_key_unused,const tstring k_feature_separator)306   StringCrosser(const std::vector<
307                     std::unique_ptr<ColumnInterface<InternalType>>>& columns,
308                 const int64_t num_buckets_unused, const uint64 hash_key_unused,
309                 const tstring k_feature_separator)
310       : columns_(columns), k_feature_separator_(k_feature_separator) {}
311 
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool unused_strong_hash) const312   string Generate(const int64_t batch_index,
313                   const std::vector<int>& permutation,
314                   bool unused_strong_hash) const {
315     gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
316     for (int i = 0; i < permutation.size(); i++) {
317       cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i], false);
318     }
319     // TODO(zakaria): this will copy the string twice, might effect
320     // performance.
321     return absl::StrJoin(cross_vec, k_feature_separator_);
322   }
323 
324  private:
325   const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
326   const tstring k_feature_separator_;
327 };
328 
329 // Generates the sparse crosses as nested hash to avoid string manipulations.
330 class HashCrosser {
331  public:
HashCrosser(const std::vector<std::unique_ptr<ColumnInterface<int64_t>>> & columns,const int64_t num_buckets,const uint64 hash_key,const tstring k_feature_separator_unused)332   HashCrosser(
333       const std::vector<std::unique_ptr<ColumnInterface<int64_t>>>& columns,
334       const int64_t num_buckets, const uint64 hash_key,
335       const tstring k_feature_separator_unused)
336       : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}
337 
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool unused_strong_hash) const338   int64_t Generate(const int64_t batch_index,
339                    const std::vector<int>& permutation,
340                    bool unused_strong_hash) const {
341     // Do the fingerprint concatenation on uint64.
342     uint64 hashed_output = hash_key_;
343     for (size_t i = 0; i < permutation.size(); ++i) {
344       uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i], false);
345       hashed_output = FingerprintCat64(hashed_output, hash_i);
346     }
347     // The return value is int64 based on the number of buckets.
348     if (num_buckets_ > 0) {
349       return hashed_output % num_buckets_;
350     } else {
351       // To prevent negative output we take modulo to max int64.
352       return hashed_output % std::numeric_limits<int64_t>::max();
353     }
354   }
355 
356  private:
357   const std::vector<std::unique_ptr<ColumnInterface<int64_t>>>& columns_;
358   const int64_t num_buckets_;
359   const uint64 hash_key_;
360 };
361 
362 // Generates the sparse crosses as nested hash to avoid string manipulations.
363 class HashCrosserV2 {
364  public:
HashCrosserV2(const std::vector<std::unique_ptr<ColumnInterface<int64_t>>> & columns,const int64_t num_buckets,const uint64 hash_key_unused,const tstring k_feature_separator_unused)365   HashCrosserV2(
366       const std::vector<std::unique_ptr<ColumnInterface<int64_t>>>& columns,
367       const int64_t num_buckets, const uint64 hash_key_unused,
368       const tstring k_feature_separator_unused)
369       : columns_(columns), num_buckets_(num_buckets) {}
370 
Generate(const int64_t batch_index,const std::vector<int> & permutation,bool strong_hash) const371   int64_t Generate(const int64_t batch_index,
372                    const std::vector<int>& permutation,
373                    bool strong_hash) const {
374     // Do the fingerprint concatenation on uint64.
375     uint64 hashed_output =
376         columns_[0]->Feature(batch_index, permutation[0], strong_hash);
377     for (size_t i = 1; i < permutation.size(); ++i) {
378       uint64 hash_i =
379           columns_[i]->Feature(batch_index, permutation[i], strong_hash);
380       hashed_output = FingerprintCat64(hashed_output, hash_i);
381     }
382     // The return value is int64 based on the number of buckets.
383     if (num_buckets_ > 0) {
384       return hashed_output % num_buckets_;
385     } else {
386       // To prevent negative output we take modulo to max int64.
387       return hashed_output % std::numeric_limits<int64_t>::max();
388     }
389   }
390 
391  private:
392   const std::vector<std::unique_ptr<ColumnInterface<int64_t>>>& columns_;
393   const int64_t num_buckets_;
394 };
395 
// ProductIterator generates cartesian products based on indices.
// It yields, one per Next() call, every tuple of per-column feature indices
// for a single batch: {0..FeatureCount(col0)-1} x {0..FeatureCount(col1)-1} x ...
template <typename InternalType>
class ProductIterator {
 public:
  explicit ProductIterator(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64_t batch_index)
      : columns_(columns), batch_index_(batch_index) {
    next_permutation_.resize(columns_.size(), 0);
    // Sets has_next_ to false if any feature column has 0 features.
    has_next_ = true;
    for (int i = 0; i < columns_.size(); i++) {
      if (columns_[i]->FeatureCount(batch_index_) == 0) {
        has_next_ = false;
        break;
      }
    }
  }

  // Returns the current permutation and advances odometer-style: the
  // rightmost column increments first and wraps to 0, carrying into the
  // column to its left.  When the leftmost column wraps, the carry survives
  // the loop and HasNext() becomes false.
  std::vector<int> Next() {
    std::vector<int> permutation(next_permutation_);

    // Generates next permutation, if available.
    bool carry = true;
    for (int i = next_permutation_.size() - 1; i >= 0; i--) {
      if (carry) {
        next_permutation_[i] = next_permutation_[i] + 1;
      }
      if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
        next_permutation_[i] = 0;  // wrap; carry stays true
      } else {
        carry = false;
        break;
      }
    }
    has_next_ = !carry;
    return permutation;
  }

  bool HasNext() { return has_next_; }

 private:
  bool has_next_;
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
  const int64_t batch_index_;
  // The permutation that the next call to Next() will hand out.
  std::vector<int> next_permutation_;
};
444 
445 template <bool HASHED_OUTPUT, typename InternalType>
446 struct CrossTraits;
447 
448 template <typename InternalType>
449 struct CrossTraits<false, InternalType> {
450   typedef StringCrosser<InternalType> Crosser;
451   typedef StringCrosser<InternalType> CrosserV2;
452   typedef OutputUpdater<tstring> Updater;
453 };
454 
455 template <>
456 struct CrossTraits<true, int64_t> {
457   typedef HashCrosser Crosser;
458   typedef HashCrosserV2 CrosserV2;
459   typedef OutputUpdater<int64_t> Updater;
460 };
461 }  // namespace
462 
463 // Calculate the batch size from either the shapes input or the dense input.
CalculateBatchSize(const OpInputList & shapes_list_in,const OpInputList & dense_list_in)464 int64_t CalculateBatchSize(const OpInputList& shapes_list_in,
465                            const OpInputList& dense_list_in) {
466   if (shapes_list_in.size() > 0) {
467     return shapes_list_in[0].vec<int64_t>()(0);
468   }
469 
470   if (dense_list_in.size() > 0) {
471     return dense_list_in[0].dim_size(0);
472   }
473 
474   return 0;
475 }
476 
// Validates input tensors.
//
// For the parallel sparse-input lists (indices/values/shapes) and the dense
// list, checks: dtypes (only when `internal_type` != DT_INVALID, i.e. for
// SparseCrossOp — the other kernels pass DT_INVALID and skip type checks),
// ranks (indices: Nx2 matrix; values/shapes: vector; dense: matrix), that
// the three sparse lists have equal length, and that every input agrees on
// the batch size.  Returns the first violation as InvalidArgument, else
// OkStatus().
Status ValidateInput(const OpInputList& indices_list_in,
                     const OpInputList& values_list_in,
                     const OpInputList& shapes_list_in,
                     const OpInputList& dense_list_in,
                     const DataType& internal_type) {
  const auto size = indices_list_in.size();
  // Only perform internal_type check for SparseCrossOp.
  // Check if the internal_type is not invalid before doing so.
  bool check_type = internal_type != DT_INVALID;
  // Validates indices_list_in OpInputList.
  for (int i = 0; i < size; i++) {
    if (check_type && indices_list_in[i].dtype() != DT_INT64) {
      return errors::InvalidArgument("Input indices should be of type ",
                                     DT_INT64, " but received ",
                                     indices_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input indices should be a matrix but received shape ",
          indices_list_in[i].shape().DebugString(), " at position ", i);
    }
    // Each index row must be a (batch, position) pair.
    if (indices_list_in[i].shape().dim_size(1) != 2) {
      return errors::InvalidArgument("Expected D2 of index to be 2 got ",
                                     indices_list_in[i].shape().dim_size(1),
                                     " at position ", i);
    }
  }

  // Validates values_list_in OpInputList.
  if (values_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input values, got ",
                                   values_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    // Make sure to avoid the expected type to be string, but input values to be
    // int64.
    if (check_type && internal_type == DT_STRING &&
        values_list_in[i].dtype() == DT_INT64) {
      return errors::InvalidArgument("Input values should be of internal type ",
                                     internal_type, " but received ",
                                     values_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input values should be a vector but received shape ",
          values_list_in[i].shape().DebugString(), " at position ", i);
    }
    // One value per index row.
    if (indices_list_in[i].shape().dim_size(0) !=
        values_list_in[i].shape().dim_size(0)) {
      return errors::InvalidArgument(
          "Expected size of values to be ",
          indices_list_in[i].shape().dim_size(0), " got ",
          values_list_in[i].shape().dim_size(0), " at position ", i);
    }
  }

  // Validates shapes_list_in OpInputList
  if (shapes_list_in.size() != size) {
    return errors::InvalidArgument("Expected ", size, " input shapes, got ",
                                   shapes_list_in.size());
  }
  for (int i = 0; i < size; i++) {
    if (check_type && shapes_list_in[i].dtype() != DT_INT64) {
      return errors::InvalidArgument("Input shape should be of type ", DT_INT64,
                                     " but received ",
                                     shapes_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Input shapes should be a vector but received shape ",
          shapes_list_in[i].shape().DebugString(), " at position ", i);
    }

    // Each shape vector must describe a rank-2 SparseTensor.
    if (shapes_list_in[i].vec<int64_t>().size() != 2) {
      return errors::InvalidArgument("shape should imply a 2D tensor, but got ",
                                     shapes_list_in[i].shape().DebugString(),
                                     " at position ", i);
    }
  }

  // Validates dense_list_in OpInputList
  for (int i = 0; i < dense_list_in.size(); ++i) {
    // Make sure to avoid the expected type to be string, but input values to be
    // int64.
    if (check_type && internal_type == DT_STRING &&
        dense_list_in[i].dtype() == DT_INT64) {
      return errors::InvalidArgument("Dense inputs should be of internal type ",
                                     internal_type, " but received ",
                                     dense_list_in[i].dtype());
    }
    if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) {
      return errors::InvalidArgument(
          "Dense inputs should be a matrix but received shape ",
          dense_list_in[i].shape().DebugString(), " at position ", i);
    }
  }

  // Validates batch sizes.  (Note: we do this after validating the input
  // shapes, because CalculateBatchSize() depends on inputs having valid
  // shapes).
  const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
  for (int i = 0; i < size; i++) {
    if (shapes_list_in[i].vec<int64_t>()(0) != batch_size) {
      return errors::InvalidArgument(
          "Expected batch size ", batch_size, " got ",
          shapes_list_in[i].vec<int64_t>()(0), " at position ", i);
    }
  }
  for (int i = 0; i < dense_list_in.size(); ++i) {
    if (dense_list_in[i].dim_size(0) != batch_size) {
      return errors::InvalidArgument("Expected batch size ", batch_size,
                                     " got ", dense_list_in[i].dim_size(0),
                                     " at dense tensor ", i);
    }
  }

  return OkStatus();
}
596 
597 // Extracts data about the features and populates feature data.
ExtractFeatureData(const OpInputList & indices_list_in,int64_t batch_size,std::vector<std::vector<int64_t>> * feature_counts,std::vector<std::vector<int64_t>> * feature_start_indices)598 void ExtractFeatureData(
599     const OpInputList& indices_list_in, int64_t batch_size,
600     std::vector<std::vector<int64_t>>* feature_counts,
601     std::vector<std::vector<int64_t>>* feature_start_indices) {
602   gtl::InlinedVector<int64_t, 8> current_row(indices_list_in.size(), 0);
603   for (int b = 0; b < batch_size; b++) {
604     for (int i = 0; i < indices_list_in.size(); i++) {
605       const auto indices = indices_list_in[i].matrix<int64_t>();
606       int64_t feature_count = 0;
607       int64_t start_index = current_row[i];
608       // Loops until we reach next batch index for current feature column.
609       while (current_row[i] < indices_list_in[i].dim_size(0) &&
610              indices(current_row[i], 0) == b) {
611         feature_count++;
612         current_row[i]++;
613       }
614       (*feature_counts)[i].push_back(feature_count);
615       (*feature_start_indices)[i].push_back(start_index);
616     }
617   }
618 }
619 
620 // Returns number of crosses for a given batch_index
621 template <typename InternalType>
CrossCountByBatchIndex(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,int batch_index)622 int64_t CrossCountByBatchIndex(
623     const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
624     int batch_index) {
625   int64_t cross_count = 1;
626   for (int i = 0; i < columns.size(); i++) {
627     const auto feature_count = columns[i]->FeatureCount(batch_index);
628     // If one column is missing any feature, there won't be any cross.
629     if (feature_count == 0) {
630       return 0;
631     }
632     cross_count *= feature_count;
633   }
634   return cross_count;
635 }
636 
637 // Generate the columns given the sparse and dense inputs.
638 template <typename InternalType>
639 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in)640 GenerateColumnsFromInput(const OpInputList& indices_list_in,
641                          const OpInputList& values_list_in,
642                          const OpInputList& shapes_list_in,
643                          const OpInputList& dense_list_in) {
644   std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
645   const int64_t batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
646   const int64_t number_of_columns = shapes_list_in.size();
647 
648   std::vector<std::vector<int64_t>> feature_counts(number_of_columns,
649                                                    std::vector<int64_t>());
650   std::vector<std::vector<int64_t>> feature_start_indices(
651       number_of_columns, std::vector<int64_t>());
652 
653   ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
654                      &feature_start_indices);
655 
656   columns.reserve(values_list_in.size());
657   for (int i = 0; i < values_list_in.size(); ++i) {
658     columns.emplace_back(new SparseTensorColumn<InternalType>(
659         values_list_in[i], std::move(feature_counts[i]),
660         std::move(feature_start_indices[i])));
661   }
662   for (int i = 0; i < dense_list_in.size(); ++i) {
663     columns.emplace_back(new DenseTensorColumn<InternalType>(dense_list_in[i]));
664   }
665 
666   return columns;
667 }
668 
669 // Generate the columns given the sparse and dense inputs.
670 template <typename InternalType>
671 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateKeyedColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in,std::vector<int64_t> keys)672 GenerateKeyedColumnsFromInput(const OpInputList& indices_list_in,
673                               const OpInputList& values_list_in,
674                               const OpInputList& shapes_list_in,
675                               const OpInputList& dense_list_in,
676                               std::vector<int64_t> keys) {
677   std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
678   const int64_t batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
679   const int64_t number_of_columns = shapes_list_in.size();
680 
681   std::vector<std::vector<int64_t>> feature_counts(number_of_columns,
682                                                    std::vector<int64_t>());
683   std::vector<std::vector<int64_t>> feature_start_indices(
684       number_of_columns, std::vector<int64_t>());
685 
686   ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
687                      &feature_start_indices);
688 
689   columns.reserve(values_list_in.size());
690   for (int i = 0; i < values_list_in.size(); ++i) {
691     columns.emplace_back(new KeyedSparseTensorColumn<InternalType>(
692         values_list_in[i], std::move(feature_counts[i]),
693         std::move(feature_start_indices[i]), keys));
694   }
695   for (int i = 0; i < dense_list_in.size(); ++i) {
696     columns.emplace_back(
697         new KeyedDenseTensorColumn<InternalType>(dense_list_in[i], keys));
698   }
699 
700   return columns;
701 }
702 
703 // Allocates output tensors with proper size and sets the shape tensor of
704 // the output SparseTensor.
705 // It also output_start_indices which contains the start indices for each
706 // input in the output SparseTensor.
707 template <typename InternalType>
CreateOutputTensors(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,int64_t batch_size,OpKernelContext * context,Tensor ** indices_out,Tensor ** values_out,Tensor ** shape_out,std::vector<int64_t> * output_start_indices)708 Status CreateOutputTensors(
709     const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns,
710     int64_t batch_size, OpKernelContext* context, Tensor** indices_out,
711     Tensor** values_out, Tensor** shape_out,
712     std::vector<int64_t>* output_start_indices) {
713   // Calculates dimensions for output tensors.
714   int64_t cross_count_total = 0;
715   int64_t max_cross_count = 0;
716   for (int64_t b = 0; b < batch_size; b++) {
717     // For each input, sets starting indices in output SparseTensor
718     (*output_start_indices)[b] = cross_count_total;
719     const auto cross_count = CrossCountByBatchIndex(columns, b);
720     max_cross_count = std::max(max_cross_count, cross_count);
721     cross_count_total += cross_count;
722   }
723 
724   // Allocates tensors.
725   TF_RETURN_IF_ERROR(context->allocate_output(
726       0, TensorShape({cross_count_total, 2}), indices_out));
727   TF_RETURN_IF_ERROR(context->allocate_output(
728       1, TensorShape({cross_count_total}), values_out));
729   TF_RETURN_IF_ERROR(context->allocate_output(2, TensorShape({2}), shape_out));
730 
731   // Sets shape.
732   auto shape_vec = (*shape_out)->vec<int64_t>();
733   shape_vec(0) = batch_size;
734   shape_vec(1) = max_cross_count;
735 
736   return OkStatus();
737 }
738 
// Kernel for the "SparseCross" op: computes the feature cross of a list of
// sparse columns (given as parallel indices/values/shapes lists) and dense
// columns. HASHED_OUTPUT selects the hashed (int64) vs. string Crosser via
// CrossTraits; InternalType is the C++ type the column values are read as.
template <bool HASHED_OUTPUT, typename InternalType>
class SparseCrossOp : public OpKernel {
 public:
  explicit SparseCrossOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    // Read signed_hash_key_ as int64 since uint64 attributes are not
    // supported by REGISTER_OP.
    int64_t signed_hash_key_;
    OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
    hash_key_ = static_cast<uint64>(signed_hash_key_);
    OP_REQUIRES_OK(context, context->GetAttr("internal_type", &internal_type_));
  }

  void Compute(OpKernelContext* context) override {
    // Fetch the four input lists that together describe the columns.
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    // Validate the inputs (including the declared internal type) before
    // touching any data.
    DataType internal_type = internal_type_;
    OP_REQUIRES_OK(
        context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
                               dense_list_in, internal_type));

    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
        GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in,
                                               shapes_list_in, dense_list_in);

    // Separator used between feature values in string-output mode.
    const tstring k_feature_separator = "_X_";
    typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser(
        columns, num_buckets_, hash_key_, k_feature_separator);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64_t batch_size =
        CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64_t> output_start_indices(batch_size);
    OP_REQUIRES_OK(
        context,
        CreateOutputTensors(columns, batch_size, context, &indices_out,
                            &values_out, &shape_out, &output_start_indices));

    // Per batch entry: walk the cartesian product of the columns' features
    // and write each generated cross into the output SparseTensor.
    typename CrossTraits<HASHED_OUTPUT, InternalType>::Updater updater(
        output_start_indices, indices_out, values_out);
    auto do_work = [&columns, crosser, updater](int64_t begin, int64_t end) {
      for (int b = begin; b < end; b++) {
        ProductIterator<InternalType> product_iterator(columns, b);
        int64_t cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count,
                         crosser.Generate(b, permutation, false));
          cross_count++;
        }
      }
    };

    // Shard the per-batch work across the CPU worker threads.
    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }

 private:
  int64_t num_buckets_;     // From the "num_buckets" attr.
  uint64 hash_key_;         // From the "hash_key" attr, reinterpreted unsigned.
  DataType internal_type_;  // From the "internal_type" attr.
};
813 
814 class SparseCrossV2Op : public OpKernel {
815  public:
SparseCrossV2Op(OpKernelConstruction * context)816   explicit SparseCrossV2Op(OpKernelConstruction* context) : OpKernel(context) {}
817 
Compute(OpKernelContext * context)818   void Compute(OpKernelContext* context) override {
819     OpInputList indices_list_in;
820     OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
821     OpInputList values_list_in;
822     OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
823     OpInputList shapes_list_in;
824     OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
825     OpInputList dense_list_in;
826     OP_REQUIRES_OK(context,
827                    context->input_list("dense_inputs", &dense_list_in));
828 
829     // Set internal_type to invalid_type so that the check will be ignored.
830     DataType internal_type = DT_INVALID;
831     OP_REQUIRES_OK(
832         context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
833                                dense_list_in, internal_type));
834 
835     const Tensor* sep_t;
836     OP_REQUIRES_OK(context, context->input("sep", &sep_t));
837     OP_REQUIRES(context, TensorShapeUtils::IsScalar(sep_t->shape()),
838                 errors::InvalidArgument("Input separator should be a scalar. "
839                                         "Received: ",
840                                         sep_t->DebugString()));
841     const tstring separator = sep_t->scalar<tstring>()();
842 
843     std::vector<std::unique_ptr<ColumnInterface<tstring>>> columns =
844         GenerateColumnsFromInput<tstring>(indices_list_in, values_list_in,
845                                           shapes_list_in, dense_list_in);
846     Tensor* indices_out;
847     Tensor* values_out;
848     Tensor* shape_out;
849     const int64_t batch_size =
850         CalculateBatchSize(shapes_list_in, dense_list_in);
851     std::vector<int64_t> output_start_indices(batch_size);
852     OP_REQUIRES_OK(
853         context,
854         CreateOutputTensors(columns, batch_size, context, &indices_out,
855                             &values_out, &shape_out, &output_start_indices));
856     StringCrosser<tstring> crosser(columns, 0, 0, separator);
857     OutputUpdater<tstring> updater(output_start_indices, indices_out,
858                                    values_out);
859     auto do_work = [&columns, crosser, updater](int64_t begin, int64_t end) {
860       for (int b = begin; b < end; b++) {
861         ProductIterator<tstring> product_iterator(columns, b);
862         int64_t cross_count = 0;
863         while (product_iterator.HasNext()) {
864           const auto permutation = product_iterator.Next();
865           updater.Update(b, cross_count,
866                          crosser.Generate(b, permutation, false));
867           cross_count++;
868         }
869       }
870     };
871 
872     auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
873     // TODO(zakaria): optimize kCostPerUnit
874     const int kCostPerUnit = 5000 * indices_list_in.size();
875     Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
876           kCostPerUnit, do_work);
877   }
878 };
879 
// Kernel for "SparseCrossHashed": int64 feature cross where bucket count,
// strong-hash toggle, and a two-word salt all arrive as runtime inputs.
class SparseCrossHashedOp : public OpKernel {
 public:
  explicit SparseCrossHashedOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Fetch the four input lists that together describe the columns.
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context,
                   context->input_list("dense_inputs", &dense_list_in));

    // Set internal_type to invalid_type so that the check will be ignored.
    DataType internal_type = DT_INVALID;
    OP_REQUIRES_OK(
        context, ValidateInput(indices_list_in, values_list_in, shapes_list_in,
                               dense_list_in, internal_type));

    // Runtime scalar: number of hash buckets for the crossed values.
    const Tensor* num_buckets_t;
    OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t));
    const int64_t num_buckets = num_buckets_t->scalar<int64_t>()();

    // Runtime scalar: whether Generate should use the strong (keyed) hash.
    const Tensor* strong_hash_t;
    OP_REQUIRES_OK(context, context->input("strong_hash", &strong_hash_t));
    const bool strong_hash = strong_hash_t->scalar<bool>()();

    // Two-word salt that seeds the keyed columns.
    // NOTE(review): salt is read at flat indices 0 and 1 without a size
    // check here — presumably the op shape function guarantees 2 elements;
    // confirm against the op registration.
    const Tensor* salt_t;
    OP_REQUIRES_OK(context, context->input("salt", &salt_t));
    const auto salt = salt_t->flat<int64_t>();
    std::vector<int64_t> key_{salt(0), salt(1)};

    std::vector<std::unique_ptr<ColumnInterface<int64_t>>> columns =
        GenerateKeyedColumnsFromInput<int64_t>(indices_list_in, values_list_in,
                                               shapes_list_in, dense_list_in,
                                               key_);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64_t batch_size =
        CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64_t> output_start_indices(batch_size);
    OP_REQUIRES_OK(
        context,
        CreateOutputTensors(columns, batch_size, context, &indices_out,
                            &values_out, &shape_out, &output_start_indices));
    // Hashed output never concatenates strings, so the separator is unused.
    const tstring unused_sep;
    HashCrosserV2 crosser(columns, num_buckets, 0, unused_sep);
    OutputUpdater<int64_t> updater(output_start_indices, indices_out,
                                   values_out);
    // Per batch entry: walk the cartesian product of the columns' features
    // and write each generated cross into the output SparseTensor.
    auto do_work = [&columns, crosser, updater, strong_hash](int64_t begin,
                                                             int64_t end) {
      for (int b = begin; b < end; b++) {
        ProductIterator<int64_t> product_iterator(columns, b);
        int64_t cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count,
                         crosser.Generate(b, permutation, strong_hash));
          cross_count++;
        }
      }
    };

    // Shard the per-batch work across the CPU worker threads.
    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }
};
954 
// String output, string internal values: read values as zero-copy views.
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type")
                            .TypeConstraint<tstring>("internal_type"),
                        SparseCrossOp<false, StringPiece>);

// String output, int64 internal values: values are stringified for output.
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<tstring>("out_type")
                            .TypeConstraint<int64_t>("internal_type"),
                        SparseCrossOp<false, tstring>);

// Hashed (int64) output: both internal_type variants instantiate the int64
// column path, which fingerprints string features into int64.
REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("out_type")
                            .TypeConstraint<tstring>("internal_type"),
                        SparseCrossOp<true, int64>);

REGISTER_KERNEL_BUILDER(Name("SparseCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64_t>("out_type")
                            .TypeConstraint<int64_t>("internal_type"),
                        SparseCrossOp<true, int64>);

// V2/hashed variants take their configuration as runtime inputs, so no
// type constraints beyond the device are needed.
REGISTER_KERNEL_BUILDER(Name("SparseCrossV2").Device(DEVICE_CPU),
                        SparseCrossV2Op);

REGISTER_KERNEL_BUILDER(Name("SparseCrossHashed").Device(DEVICE_CPU),
                        SparseCrossHashedOp);
984 
985 }  // namespace tensorflow
986