/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Contains OP to generate sparse crosses.
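// A cross takes one feature from each input column and combines them; e.g.
// crossing the columns {"a", "b"} and {"c"} yields the features "a_X_c" and
// "b_X_c" (or their fingerprints when hashed output is requested).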
#include <assert.h>
#include <limits>
#include <string>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

namespace {
// An interface that represents a column with batches.
template <typename InternalType>
class ColumnInterface {
 public:
  // Returns the number of features in the specified batch.
  virtual int64 FeatureCount(int64 batch) const = 0;

  // Returns the fingerprint of the nth feature from the specified batch.
  virtual InternalType Feature(int64 batch, int64 n) const = 0;

  virtual ~ColumnInterface() {}
};

// A column that is backed by a sparse tensor.
template <typename InternalType>
class SparseTensorColumn : public ColumnInterface<InternalType> {
 public:
  SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
                     std::vector<int64> feature_start_indices)
      : values_(values),
        feature_counts_(std::move(feature_counts)),
        feature_start_indices_(std::move(feature_start_indices)) {
    CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
  }

  int64 FeatureCount(int64 batch) const override {
    return feature_counts_[batch];
  }

  InternalType Feature(int64 batch, int64 n) const override;

  ~SparseTensorColumn() override {}

 private:
  const Tensor& values_;
  std::vector<int64> feature_counts_;
  std::vector<int64> feature_start_indices_;
};

// InternalType is int64 only when using HashCrosser.
template <>
int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
  const int64 start = feature_start_indices_[batch];
  if (DT_STRING == values_.dtype())
    return Fingerprint64(values_.vec<string>().data()[start + n]);
  return values_.vec<int64>().data()[start + n];
}

// InternalType is string or StringPiece when using StringCrosser.
template <>
string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
  const int64 start = feature_start_indices_[batch];
  if (DT_STRING == values_.dtype())
    return values_.vec<string>().data()[start + n];
  return std::to_string(values_.vec<int64>().data()[start + n]);
}

template <>
StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
                                                     int64 n) const {
  const int64 start = feature_start_indices_[batch];
  return values_.vec<string>().data()[start + n];
}

// A column that is backed by a dense tensor.
template <typename InternalType>
class DenseTensorColumn : public ColumnInterface<InternalType> {
 public:
  explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}

  int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }

  InternalType Feature(int64 batch, int64 n) const override;

  ~DenseTensorColumn() override {}

 private:
  const Tensor& tensor_;
};

// InternalType is int64 only when using HashCrosser.
template <>
int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
  if (DT_STRING == tensor_.dtype())
    return Fingerprint64(tensor_.matrix<string>()(batch, n));
  return tensor_.matrix<int64>()(batch, n);
}

// Internal type is string or StringPiece when using StringCrosser.
template <>
string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
  if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
  return std::to_string(tensor_.matrix<int64>()(batch, n));
}

template <>
StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
                                                    int64 n) const {
  return tensor_.matrix<string>()(batch, n);
}

// Updates Output tensors with sparse crosses.
template <typename OutType>
class OutputUpdater {
 public:
  OutputUpdater(const std::vector<int64>& output_start_indices,
                Tensor* indices_out, Tensor* values_out)
      : output_start_indices_(output_start_indices),
        indices_out_(indices_out),
        values_out_(values_out) {}

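  // Writes a single cross: its (batch_index, cross_count) coordinate goes
  // into indices_out_ and the cross value into values_out_.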
  void Update(const int64 batch_index, const int64 cross_count,
              const OutType& cross) const {
    const int64 output_index = output_start_indices_[batch_index] + cross_count;

    auto indices_matrix = indices_out_->matrix<int64>();
    indices_matrix(output_index, 0) = batch_index;
    indices_matrix(output_index, 1) = cross_count;

    auto value_vec = values_out_->vec<OutType>();
    value_vec(output_index) = cross;
  }

 private:
  const std::vector<int64>& output_start_indices_;
  Tensor* indices_out_;
  Tensor* values_out_;
};

// Generates the sparse crosses as concatenation of strings.
template <typename InternalType>
class StringCrosser {
 public:
  StringCrosser(const std::vector<
                    std::unique_ptr<ColumnInterface<InternalType>>>& columns,
                const int64 num_buckets_unused, const uint64 hash_key_unused)
      : columns_(columns) {}

  string Generate(const int64 batch_index,
                  const std::vector<int>& permutation) const {
    static const auto k_feature_separator = "_X_";

    gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
    for (size_t i = 0; i < permutation.size(); i++) {
      cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i]);
    }
    // TODO(zakaria): this will copy the string twice, might affect
    // performance.
    return str_util::Join(cross_vec, k_feature_separator);
  }

 private:
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
};

// Generates the sparse crosses as a nested hash to avoid string manipulations.
class HashCrosser {
 public:
  HashCrosser(
      const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
      const int64 num_buckets, const uint64 hash_key_unused)
      : columns_(columns), num_buckets_(num_buckets) {}

  int64 Generate(const int64 batch_index,
                 const std::vector<int>& permutation) const {
    // Seed is chosen based on third_party/tensorflow/core/lib/hash/hash.h
    static const int64 kInitialHashSeed = 0xDECAFCAFFE;

    uint64 hashed_output = kInitialHashSeed;
    for (size_t i = 0; i < permutation.size(); ++i) {
      int64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
      hashed_output = HashCombine(hashed_output, hash_i);
    }
    if (num_buckets_ > 0) {
      return hashed_output % num_buckets_;
    } else {
      // To prevent negative output we take the result modulo the max int64.
      return hashed_output % std::numeric_limits<int64>::max();
    }
  }

 private:
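  // Mixes two hash values with a shift-and-xor hash_combine-style formula.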
  static int64 HashCombine(int64 a, int64 b) {
    return a ^ (b + 0x9e3779b97f4a7800 + (a << 10) + (a >> 4));
  }

  const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
  const int64 num_buckets_;
};

// Generates the sparse crosses as a nested hash to avoid string manipulations.
class HashCrosserV2 {
 public:
  HashCrosserV2(
      const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
      const int64 num_buckets, const uint64 hash_key)
      : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}

  int64 Generate(const int64 batch_index,
                 const std::vector<int>& permutation) const {
    // Do the fingerprint concatenation on uint64.
    uint64 hashed_output = hash_key_;
    for (size_t i = 0; i < permutation.size(); ++i) {
      uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
      hashed_output = FingerprintCat64(hashed_output, hash_i);
    }
    // The return value is int64 based on the number of buckets.
    if (num_buckets_ > 0) {
      return hashed_output % num_buckets_;
    } else {
      // To prevent negative output we take the result modulo the max int64.
      return hashed_output % std::numeric_limits<int64>::max();
    }
  }

 private:
  const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
  const int64 num_buckets_;
  const uint64 hash_key_;
};

// ProductIterator generates cartesian products based on indices.
template <typename InternalType>
class ProductIterator {
 public:
  explicit ProductIterator(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64 batch_index)
      : columns_(columns), batch_index_(batch_index) {
    next_permutation_.resize(columns_.size(), 0);
    // Sets has_next_ to false if any feature column has 0 features.
    has_next_ = true;
    for (size_t i = 0; i < columns_.size(); i++) {
      if (columns_[i]->FeatureCount(batch_index_) == 0) {
        has_next_ = false;
        break;
      }
    }
  }

  std::vector<int> Next() {
    std::vector<int> permutation(next_permutation_);

    // Generates next permutation, if available.
    bool carry = true;
    for (int i = next_permutation_.size() - 1; i >= 0; i--) {
      if (carry) {
        next_permutation_[i] = next_permutation_[i] + 1;
      }
      if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
        next_permutation_[i] = 0;
      } else {
        carry = false;
        break;
      }
    }
    has_next_ = !carry;
    return permutation;
  }

  bool HasNext() { return has_next_; }

 private:
  bool has_next_;
  const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
  const int64 batch_index_;
  std::vector<int> next_permutation_;
};
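// For example, with per-column feature counts {2, 3} the iterator yields the
// permutations {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2}, advancing the last
// column fastest.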

template <bool HASHED_OUTPUT, typename InternalType, bool VERSION_2>
struct CrossTraits;

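// The specializations below pick the Crosser/Updater pair: string
// concatenation when the output is not hashed, otherwise HashCrosser or
// HashCrosserV2 (FingerprintCat64-based), selected by VERSION_2.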
template <typename InternalType, bool VERSION_2>
struct CrossTraits<false, InternalType, VERSION_2> {
  typedef StringCrosser<InternalType> Crosser;
  typedef OutputUpdater<string> Updater;
};

template <>
struct CrossTraits<true, int64, false> {
  typedef HashCrosser Crosser;
  typedef OutputUpdater<int64> Updater;
};

template <>
struct CrossTraits<true, int64, true> {
  typedef HashCrosserV2 Crosser;
  typedef OutputUpdater<int64> Updater;
};
}  // namespace

template <bool HASHED_OUTPUT, typename InternalType, bool VERSION_2>
class SparseFeatureCrossOp : public OpKernel {
 public:
  explicit SparseFeatureCrossOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
    if (VERSION_2) {
      // Read signed_hash_key_ as int64 since uint64 attributes are not
      // supported by REGISTER_OP.
      int64 signed_hash_key_;
      OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
      hash_key_ = static_cast<uint64>(signed_hash_key_);
    }
  }

  void Compute(OpKernelContext* context) override {
    OpInputList indices_list_in;
    OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
    OpInputList values_list_in;
    OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
    OpInputList shapes_list_in;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
    OpInputList dense_list_in;
    OP_REQUIRES_OK(context, context->input_list("dense", &dense_list_in));

    ValidateInput(context, indices_list_in, values_list_in, shapes_list_in,
                  dense_list_in);

    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
        GenerateColumnsFromInput(indices_list_in, values_list_in,
                                 shapes_list_in, dense_list_in);

    typename CrossTraits<HASHED_OUTPUT, InternalType, VERSION_2>::Crosser
        crosser(columns, num_buckets_, hash_key_);
    Tensor* indices_out;
    Tensor* values_out;
    Tensor* shape_out;
    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    std::vector<int64> output_start_indices(batch_size);
    CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out,
                        &shape_out, &output_start_indices);

    typename CrossTraits<HASHED_OUTPUT, InternalType, VERSION_2>::Updater
        updater(output_start_indices, indices_out, values_out);
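    // Each shard enumerates every cross of the feature columns for the batch
    // entries in [begin, end) and writes them at the precomputed offsets.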
    auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) {
      for (int b = begin; b < end; b++) {
        ProductIterator<InternalType> product_iterator(columns, b);
        int64 cross_count = 0;
        while (product_iterator.HasNext()) {
          const auto permutation = product_iterator.Next();
          updater.Update(b, cross_count, crosser.Generate(b, permutation));
          cross_count++;
        }
      }
    };

    auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
    // TODO(zakaria): optimize kCostPerUnit
    const int kCostPerUnit = 5000 * indices_list_in.size();
    Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
          kCostPerUnit, do_work);
  }

 private:
  // Validates input tensors.
  void ValidateInput(OpKernelContext* context,
                     const OpInputList& indices_list_in,
                     const OpInputList& values_list_in,
                     const OpInputList& shapes_list_in,
                     const OpInputList& dense_list_in) {
    const auto size = indices_list_in.size();
    // Validates indices_list_in OpInputList.
    for (int i = 0; i < size; i++) {
      OP_REQUIRES(
          context, TensorShapeUtils::IsMatrix(indices_list_in[i].shape()),
          errors::InvalidArgument(
              "Input indices should be a matrix but received shape ",
              indices_list_in[i].shape().DebugString(), " at position ", i));
      OP_REQUIRES(
          context, indices_list_in[i].shape().dim_size(1) == 2,
          errors::InvalidArgument("Expected D2 of index to be 2 got ",
                                  indices_list_in[i].shape().dim_size(1),
                                  " at position ", i));
    }

    // Validates values_list_in OpInputList.
    OP_REQUIRES(
        context, values_list_in.size() == size,
        errors::InvalidArgument("Expected ", size, " input values, got ",
                                values_list_in.size()));
    for (int i = 0; i < size; i++) {
      OP_REQUIRES(
          context, TensorShapeUtils::IsVector(values_list_in[i].shape()),
          errors::InvalidArgument(
              "Input values should be a std::vector but received shape ",
              values_list_in[i].shape().DebugString(), " at position ", i));
      OP_REQUIRES(
          context,
          indices_list_in[i].shape().dim_size(0) ==
              values_list_in[i].shape().dim_size(0),
          errors::InvalidArgument(
              "Expected size of values to be ",
              indices_list_in[i].shape().dim_size(0), " got ",
              values_list_in[i].shape().dim_size(0), " at position ", i));
    }

    // Validates shapes_list_in OpInputList
    OP_REQUIRES(
        context, shapes_list_in.size() == size,
        errors::InvalidArgument("Expected ", size, " input shapes, got ",
                                shapes_list_in.size()));
    const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    for (int i = 0; i < size; i++) {
      OP_REQUIRES(
          context, TensorShapeUtils::IsVector(shapes_list_in[i].shape()),
          errors::InvalidArgument(
              "Input shapes should be a std::vector but received shape ",
              shapes_list_in[i].shape().DebugString(), " at position ", i));

      OP_REQUIRES(
          context, shapes_list_in[i].vec<int64>().size() == 2,
          errors::InvalidArgument("shape should imply a 2D tensor, but got ",
                                  shapes_list_in[i].shape().DebugString(),
                                  " at position ", i));
      OP_REQUIRES(context, shapes_list_in[i].vec<int64>()(0) == batch_size,
                  errors::InvalidArgument(
                      "Expected batch size ", batch_size, " got ",
                      shapes_list_in[i].vec<int64>()(0), " at position ", i));
    }

    // Validates dense_list_in OpInputList
    for (int i = 0; i < dense_list_in.size(); ++i) {
      OP_REQUIRES(
          context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()),
          errors::InvalidArgument(
              "Dense inputs should be a matrix but received shape ",
              indices_list_in[i].shape().DebugString(), " at position ", i));
      OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size,
                  errors::InvalidArgument("Expected batch size ", batch_size,
                                          " got ", dense_list_in[i].dim_size(0),
                                          " at dense tensor ", i));
    }
  }

  // Calculate the batch size from either the shapes input or the dense input.
  int64 CalculateBatchSize(const OpInputList& shapes_list_in,
                           const OpInputList& dense_list_in) {
    if (shapes_list_in.size() > 0) {
      return shapes_list_in[0].vec<int64>()(0);
    }

    if (dense_list_in.size() > 0) {
      return dense_list_in[0].dim_size(0);
    }

    return 0;
  }

  // Generate the columns given the sparse and dense inputs.
  std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
  GenerateColumnsFromInput(const OpInputList& indices_list_in,
                           const OpInputList& values_list_in,
                           const OpInputList& shapes_list_in,
                           const OpInputList& dense_list_in) {
    std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
    const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
    const int64 number_of_columns = shapes_list_in.size();

    std::vector<std::vector<int64>> feature_counts(number_of_columns,
                                                   std::vector<int64>());
    std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
                                                          std::vector<int64>());

    ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
                       &feature_start_indices);

    columns.reserve(values_list_in.size());
    for (int i = 0; i < values_list_in.size(); ++i) {
      columns.emplace_back(new SparseTensorColumn<InternalType>(
          values_list_in[i], std::move(feature_counts[i]),
          std::move(feature_start_indices[i])));
    }
    for (int i = 0; i < dense_list_in.size(); ++i) {
      columns.emplace_back(
          new DenseTensorColumn<InternalType>(dense_list_in[i]));
    }

    return columns;
  }

  // Extracts data about the features and populates feature data.
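  // Assumes the indices of each sparse input are sorted by batch (row-major),
  // so the rows belonging to a batch entry are contiguous.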
  void ExtractFeatureData(
      const OpInputList& indices_list_in, int64 batch_size,
      std::vector<std::vector<int64>>* feature_counts,
      std::vector<std::vector<int64>>* feature_start_indices) {
    gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0);
    for (int b = 0; b < batch_size; b++) {
      for (int i = 0; i < indices_list_in.size(); i++) {
        const auto indices = indices_list_in[i].matrix<int64>();
        int64 feature_count = 0;
        int64 start_index = current_row[i];
        // Loops until we reach next batch index for current feature column.
        while (current_row[i] < indices_list_in[i].dim_size(0) &&
               indices(current_row[i], 0) == b) {
          feature_count++;
          current_row[i]++;
        }
        (*feature_counts)[i].push_back(feature_count);
        (*feature_start_indices)[i].push_back(start_index);
      }
    }
  }

  // Allocates output tensors with proper size and sets the shape tensor of
  // the output SparseTensor.
  // It also fills output_start_indices, which contains the start indices for
  // each input in the output SparseTensor.
  void CreateOutputTensors(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int64 batch_size, OpKernelContext* context, Tensor** indices_out,
      Tensor** values_out, Tensor** shape_out,
      std::vector<int64>* output_start_indices) {
    // Calculates dimensions for output tensors.
    int64 cross_count_total = 0;
    int64 max_cross_count = 0;
    for (int64 b = 0; b < batch_size; b++) {
      // For each input, sets starting indices in output SparseTensor
      (*output_start_indices)[b] = cross_count_total;
      const auto cross_count = CrossCountByBatchIndex(columns, b);
      max_cross_count = std::max(max_cross_count, cross_count);
      cross_count_total += cross_count;
    }

    // Allocates tensors.
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({cross_count_total, 2}), indices_out));
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, TensorShape({cross_count_total}),
                                            values_out));
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, TensorShape({2}), shape_out));

    // Sets shape.
    auto shape_vec = (*shape_out)->vec<int64>();
    shape_vec(0) = batch_size;
    shape_vec(1) = max_cross_count;
  }

  // Returns number of crosses for a given batch_index
  int64 CrossCountByBatchIndex(
      const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
          columns,
      int batch_index) {
    int64 cross_count = 1;
    for (size_t i = 0; i < columns.size(); i++) {
      const auto feature_count = columns[i]->FeatureCount(batch_index);
      // If one column is missing any feature, there won't be any cross.
      if (feature_count == 0) {
        return 0;
      }
      cross_count *= feature_count;
    }
    return cross_count;
  }
  int64 num_buckets_;
  uint64 hash_key_;
};

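// Kernel registrations: out_type selects between string output (concatenated
// crosses) and int64 output (hashed crosses); the SparseFeatureCrossV2
// registrations use the FingerprintCat64-based crosser.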
REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<string>("out_type")
                            .TypeConstraint<string>("internal_type"),
                        SparseFeatureCrossOp<false, StringPiece, false>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<string>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseFeatureCrossOp<false, string, false>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<string>("internal_type"),
                        SparseFeatureCrossOp<true, int64, false>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseFeatureCrossOp<true, int64, false>);

// The following builders enable FingerprintCat64 concatenation for the
// crossed features.
REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<string>("out_type")
                            .TypeConstraint<string>("internal_type"),
                        SparseFeatureCrossOp<false, StringPiece, true>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<string>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseFeatureCrossOp<false, string, true>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<string>("internal_type"),
                        SparseFeatureCrossOp<true, int64, true>);

REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<int64>("out_type")
                            .TypeConstraint<int64>("internal_type"),
                        SparseFeatureCrossOp<true, int64, true>);

}  // namespace tensorflow