1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Contains OP to generate sparse crosses.
17 #include <assert.h>
18 #include <limits>
19 #include <string>
20 #include <vector>
21
22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23 #include "tensorflow/core/framework/kernel_def_builder.h"
24 #include "tensorflow/core/framework/op_def_builder.h"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/lib/core/stringpiece.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/platform/fingerprint.h"
32 #include "tensorflow/core/util/work_sharder.h"
33
34 namespace tensorflow {
35
36 namespace {
37 // An interface that represents a column with batches.
38 template <typename InternalType>
39 class ColumnInterface {
40 public:
41 // Returns the number of features in the specified batch.
42 virtual int64 FeatureCount(int64 batch) const = 0;
43
44 // Returns the fingerprint of nth feature from the specified batch.
45 virtual InternalType Feature(int64 batch, int64 n) const = 0;
46
~ColumnInterface()47 virtual ~ColumnInterface() {}
48 };
49
50 // A column that is backed by a sparse tensor.
51 template <typename InternalType>
52 class SparseTensorColumn : public ColumnInterface<InternalType> {
53 public:
SparseTensorColumn(const Tensor & values,std::vector<int64> feature_counts,std::vector<int64> feature_start_indices)54 SparseTensorColumn(const Tensor& values, std::vector<int64> feature_counts,
55 std::vector<int64> feature_start_indices)
56 : values_(values),
57 feature_counts_(std::move(feature_counts)),
58 feature_start_indices_(std::move(feature_start_indices)) {
59 CHECK_EQ(feature_counts_.size(), feature_start_indices_.size());
60 }
61
FeatureCount(int64 batch) const62 int64 FeatureCount(int64 batch) const override {
63 return feature_counts_[batch];
64 }
65
66 InternalType Feature(int64 batch, int64 n) const override;
67
~SparseTensorColumn()68 ~SparseTensorColumn() override {}
69
70 private:
71 const Tensor& values_;
72 std::vector<int64> feature_counts_;
73 std::vector<int64> feature_start_indices_;
74 };
75
76 // InternalType is int64 only when using HashCrosser.
77 template <>
Feature(int64 batch,int64 n) const78 int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
79 const int64 start = feature_start_indices_[batch];
80 if (DT_STRING == values_.dtype())
81 return Fingerprint64(values_.vec<string>().data()[start + n]);
82 return values_.vec<int64>().data()[start + n];
83 }
84
85 // InternalType is string or StringPiece when using StringCrosser.
86 template <>
Feature(int64 batch,int64 n) const87 string SparseTensorColumn<string>::Feature(int64 batch, int64 n) const {
88 const int64 start = feature_start_indices_[batch];
89 if (DT_STRING == values_.dtype())
90 return values_.vec<string>().data()[start + n];
91 return std::to_string(values_.vec<int64>().data()[start + n]);
92 }
93
94 template <>
Feature(int64 batch,int64 n) const95 StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch,
96 int64 n) const {
97 const int64 start = feature_start_indices_[batch];
98 return values_.vec<string>().data()[start + n];
99 }
100
101 // A column that is backed by a dense tensor.
102 template <typename InternalType>
103 class DenseTensorColumn : public ColumnInterface<InternalType> {
104 public:
DenseTensorColumn(const Tensor & tensor)105 explicit DenseTensorColumn(const Tensor& tensor) : tensor_(tensor) {}
106
FeatureCount(int64 batch) const107 int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); }
108
109 InternalType Feature(int64 batch, int64 n) const override;
110
~DenseTensorColumn()111 ~DenseTensorColumn() override {}
112
113 private:
114 const Tensor& tensor_;
115 };
116
117 // InternalType is int64 only when using HashCrosser.
118 template <>
Feature(int64 batch,int64 n) const119 int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const {
120 if (DT_STRING == tensor_.dtype())
121 return Fingerprint64(tensor_.matrix<string>()(batch, n));
122 return tensor_.matrix<int64>()(batch, n);
123 }
124
125 // Internal type is string or StringPiece when using StringCrosser.
126 template <>
Feature(int64 batch,int64 n) const127 string DenseTensorColumn<string>::Feature(int64 batch, int64 n) const {
128 if (DT_STRING == tensor_.dtype()) return tensor_.matrix<string>()(batch, n);
129 return std::to_string(tensor_.matrix<int64>()(batch, n));
130 }
131
132 template <>
Feature(int64 batch,int64 n) const133 StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch,
134 int64 n) const {
135 return tensor_.matrix<string>()(batch, n);
136 }
137
138 // Updates Output tensors with sparse crosses.
139 template <typename OutType>
140 class OutputUpdater {
141 public:
OutputUpdater(const std::vector<int64> & output_start_indices,Tensor * indices_out,Tensor * values_out)142 OutputUpdater(const std::vector<int64>& output_start_indices,
143 Tensor* indices_out, Tensor* values_out)
144 : output_start_indices_(output_start_indices),
145 indices_out_(indices_out),
146 values_out_(values_out) {}
147
Update(const int64 batch_index,const int64 cross_count,const OutType & cross) const148 void Update(const int64 batch_index, const int64 cross_count,
149 const OutType& cross) const {
150 const int64 output_index = output_start_indices_[batch_index] + cross_count;
151
152 auto indices_matrix = indices_out_->matrix<int64>();
153 indices_matrix(output_index, 0) = batch_index;
154 indices_matrix(output_index, 1) = cross_count;
155
156 auto value_vec = values_out_->vec<OutType>();
157 value_vec(output_index) = cross;
158 }
159
160 private:
161 const std::vector<int64>& output_start_indices_;
162 Tensor* indices_out_;
163 Tensor* values_out_;
164 };
165
166 // Generates the sparse crosses as concatenation of strings.
167 template <typename InternalType>
168 class StringCrosser {
169 public:
StringCrosser(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,const int64 num_buckets_unused,const uint64 hash_key_unused)170 StringCrosser(const std::vector<
171 std::unique_ptr<ColumnInterface<InternalType>>>& columns,
172 const int64 num_buckets_unused, const uint64 hash_key_unused)
173 : columns_(columns) {}
174
Generate(const int64 batch_index,const std::vector<int> & permutation) const175 string Generate(const int64 batch_index,
176 const std::vector<int>& permutation) const {
177 static const auto k_feature_separator = "_X_";
178
179 gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size());
180 for (size_t i = 0; i < permutation.size(); i++) {
181 cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i]);
182 }
183 // TODO(zakaria): this will copy the string twice, might effect
184 // performance.
185 return str_util::Join(cross_vec, k_feature_separator);
186 }
187
188 private:
189 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
190 };
191
192 // Generates the sparse crosses as nested hash to avoid string manipulations.
193 class HashCrosser {
194 public:
HashCrosser(const std::vector<std::unique_ptr<ColumnInterface<int64>>> & columns,const int64 num_buckets,const uint64 hash_key_unused)195 HashCrosser(
196 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
197 const int64 num_buckets, const uint64 hash_key_unused)
198 : columns_(columns), num_buckets_(num_buckets) {}
199
Generate(const int64 batch_index,const std::vector<int> & permutation) const200 int64 Generate(const int64 batch_index,
201 const std::vector<int>& permutation) const {
202 // Seed is chosen based on third_party/tensorflow/core/lib/hash/hash.h
203 static const int64 kInitialHashSeed = 0xDECAFCAFFE;
204
205 uint64 hashed_output = kInitialHashSeed;
206 for (size_t i = 0; i < permutation.size(); ++i) {
207 int64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
208 hashed_output = HashCombine(hashed_output, hash_i);
209 }
210 if (num_buckets_ > 0) {
211 return hashed_output % num_buckets_;
212 } else {
213 // To prevent negative output we take modulo to max int64.
214 return hashed_output % std::numeric_limits<int64>::max();
215 }
216 }
217
218 private:
HashCombine(int64 a,int64 b)219 static int64 HashCombine(int64 a, int64 b) {
220 return a ^ (b + 0x9e3779b97f4a7800 + (a << 10) + (a >> 4));
221 }
222
223 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
224 const int64 num_buckets_;
225 };
226
227 // Generates the sparse crosses as nested hash to avoid string manipulations.
228 class HashCrosserV2 {
229 public:
HashCrosserV2(const std::vector<std::unique_ptr<ColumnInterface<int64>>> & columns,const int64 num_buckets,const uint64 hash_key)230 HashCrosserV2(
231 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns,
232 const int64 num_buckets, const uint64 hash_key)
233 : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {}
234
Generate(const int64 batch_index,const std::vector<int> & permutation) const235 int64 Generate(const int64 batch_index,
236 const std::vector<int>& permutation) const {
237 // Do the fingerprint concatenation on uint64.
238 uint64 hashed_output = hash_key_;
239 for (size_t i = 0; i < permutation.size(); ++i) {
240 uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i]);
241 hashed_output = FingerprintCat64(hashed_output, hash_i);
242 }
243 // The return value is int64 based on the number of buckets.
244 if (num_buckets_ > 0) {
245 return hashed_output % num_buckets_;
246 } else {
247 // To prevent negative output we take modulo to max int64.
248 return hashed_output % std::numeric_limits<int64>::max();
249 }
250 }
251
252 private:
253 const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_;
254 const int64 num_buckets_;
255 const uint64 hash_key_;
256 };
257
258 // ProductIterator generates cartesian products based on indices.
259 template <typename InternalType>
260 class ProductIterator {
261 public:
ProductIterator(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,int64 batch_index)262 explicit ProductIterator(
263 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
264 columns,
265 int64 batch_index)
266 : columns_(columns), batch_index_(batch_index) {
267 next_permutation_.resize(columns_.size(), 0);
268 // Sets has_next_ to false if any feature column has 0 features.
269 has_next_ = true;
270 for (size_t i = 0; i < columns_.size(); i++) {
271 if (columns_[i]->FeatureCount(batch_index_) == 0) {
272 has_next_ = false;
273 break;
274 }
275 }
276 }
277
Next()278 std::vector<int> Next() {
279 std::vector<int> permutation(next_permutation_);
280
281 // Generates next permutation, if available.
282 bool carry = true;
283 for (int i = next_permutation_.size() - 1; i >= 0; i--) {
284 if (carry) {
285 next_permutation_[i] = next_permutation_[i] + 1;
286 }
287 if (next_permutation_[i] == columns_[i]->FeatureCount(batch_index_)) {
288 next_permutation_[i] = 0;
289 } else {
290 carry = false;
291 break;
292 }
293 }
294 has_next_ = !carry;
295 return permutation;
296 }
297
HasNext()298 bool HasNext() { return has_next_; }
299
300 private:
301 bool has_next_;
302 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_;
303 const int64 batch_index_;
304 std::vector<int> next_permutation_;
305 };
306
307 template <bool HASHED_OUTPUT, typename InternalType, bool VERSION_2>
308 struct CrossTraits;
309
310 template <typename InternalType, bool VERSION_2>
311 struct CrossTraits<false, InternalType, VERSION_2> {
312 typedef StringCrosser<InternalType> Crosser;
313 typedef OutputUpdater<string> Updater;
314 };
315
316 template <>
317 struct CrossTraits<true, int64, false> {
318 typedef HashCrosser Crosser;
319 typedef OutputUpdater<int64> Updater;
320 };
321
322 template <>
323 struct CrossTraits<true, int64, true> {
324 typedef HashCrosserV2 Crosser;
325 typedef OutputUpdater<int64> Updater;
326 };
327 } // namespace
328
329 template <bool HASHED_OUTPUT, typename InternalType, bool VERSION_2>
330 class SparseFeatureCrossOp : public OpKernel {
331 public:
SparseFeatureCrossOp(OpKernelConstruction * context)332 explicit SparseFeatureCrossOp(OpKernelConstruction* context)
333 : OpKernel(context) {
334 OP_REQUIRES_OK(context, context->GetAttr("num_buckets", &num_buckets_));
335 if (VERSION_2) {
336 // Read signed_hash_key_ as int64 since uint64 attributes are not
337 // supported by REGISTER_OP.
338 int64 signed_hash_key_;
339 OP_REQUIRES_OK(context, context->GetAttr("hash_key", &signed_hash_key_));
340 hash_key_ = static_cast<uint64>(signed_hash_key_);
341 }
342 }
343
Compute(OpKernelContext * context)344 void Compute(OpKernelContext* context) override {
345 OpInputList indices_list_in;
346 OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in));
347 OpInputList values_list_in;
348 OP_REQUIRES_OK(context, context->input_list("values", &values_list_in));
349 OpInputList shapes_list_in;
350 OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in));
351 OpInputList dense_list_in;
352 OP_REQUIRES_OK(context, context->input_list("dense", &dense_list_in));
353
354 ValidateInput(context, indices_list_in, values_list_in, shapes_list_in,
355 dense_list_in);
356
357 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns =
358 GenerateColumnsFromInput(indices_list_in, values_list_in,
359 shapes_list_in, dense_list_in);
360
361 typename CrossTraits<HASHED_OUTPUT, InternalType, VERSION_2>::Crosser
362 crosser(columns, num_buckets_, hash_key_);
363 Tensor* indices_out;
364 Tensor* values_out;
365 Tensor* shape_out;
366 const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
367 std::vector<int64> output_start_indices(batch_size);
368 CreateOutputTensors(columns, batch_size, context, &indices_out, &values_out,
369 &shape_out, &output_start_indices);
370
371 typename CrossTraits<HASHED_OUTPUT, InternalType, VERSION_2>::Updater
372 updater(output_start_indices, indices_out, values_out);
373 auto do_work = [this, &columns, crosser, updater](int64 begin, int64 end) {
374 for (int b = begin; b < end; b++) {
375 ProductIterator<InternalType> product_iterator(columns, b);
376 int64 cross_count = 0;
377 while (product_iterator.HasNext()) {
378 const auto permutation = product_iterator.Next();
379 updater.Update(b, cross_count, crosser.Generate(b, permutation));
380 cross_count++;
381 }
382 }
383 };
384
385 auto* worker_threads = context->device()->tensorflow_cpu_worker_threads();
386 // TODO(zakaria): optimize kCostPerUnit
387 const int kCostPerUnit = 5000 * indices_list_in.size();
388 Shard(worker_threads->num_threads, worker_threads->workers, batch_size,
389 kCostPerUnit, do_work);
390 }
391
392 private:
393 // Validates input tensors.
ValidateInput(OpKernelContext * context,const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in)394 void ValidateInput(OpKernelContext* context,
395 const OpInputList& indices_list_in,
396 const OpInputList& values_list_in,
397 const OpInputList& shapes_list_in,
398 const OpInputList& dense_list_in) {
399 const auto size = indices_list_in.size();
400 // Validates indices_list_in OpInputList.
401 for (int i = 0; i < size; i++) {
402 OP_REQUIRES(
403 context, TensorShapeUtils::IsMatrix(indices_list_in[i].shape()),
404 errors::InvalidArgument(
405 "Input indices should be a matrix but received shape ",
406 indices_list_in[i].shape().DebugString(), " at position ", i));
407 OP_REQUIRES(
408 context, indices_list_in[i].shape().dim_size(1) == 2,
409 errors::InvalidArgument("Expected D2 of index to be 2 got ",
410 indices_list_in[i].shape().dim_size(1),
411 " at position ", i));
412 }
413
414 // Validates values_list_in OpInputList.
415 OP_REQUIRES(
416 context, values_list_in.size() == size,
417 errors::InvalidArgument("Expected ", size, " input values, got ",
418 values_list_in.size()));
419 for (int i = 0; i < size; i++) {
420 OP_REQUIRES(
421 context, TensorShapeUtils::IsVector(values_list_in[i].shape()),
422 errors::InvalidArgument(
423 "Input values should be a std::vector but received shape ",
424 values_list_in[i].shape().DebugString(), " at position ", i));
425 OP_REQUIRES(
426 context,
427 indices_list_in[i].shape().dim_size(0) ==
428 values_list_in[i].shape().dim_size(0),
429 errors::InvalidArgument(
430 "Expected size of values to be ",
431 indices_list_in[i].shape().dim_size(0), " got ",
432 values_list_in[i].shape().dim_size(0), " at position ", i));
433 }
434
435 // Validates shapes_list_in OpInputList
436 OP_REQUIRES(
437 context, shapes_list_in.size() == size,
438 errors::InvalidArgument("Expected ", size, " input shapes, got ",
439 shapes_list_in.size()));
440 const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
441 for (int i = 0; i < size; i++) {
442 OP_REQUIRES(
443 context, TensorShapeUtils::IsVector(shapes_list_in[i].shape()),
444 errors::InvalidArgument(
445 "Input shapes should be a std::vector but received shape ",
446 shapes_list_in[i].shape().DebugString(), " at position ", i));
447
448 OP_REQUIRES(
449 context, shapes_list_in[i].vec<int64>().size() == 2,
450 errors::InvalidArgument("shape should imply a 2D tensor, but got ",
451 shapes_list_in[i].shape().DebugString(),
452 " at position ", i));
453 OP_REQUIRES(context, shapes_list_in[i].vec<int64>()(0) == batch_size,
454 errors::InvalidArgument(
455 "Expected batch size ", batch_size, " got ",
456 shapes_list_in[i].vec<int64>()(0), " at position ", i));
457 }
458
459 // Validates dense_list_in OpInputList
460 for (int i = 0; i < dense_list_in.size(); ++i) {
461 OP_REQUIRES(
462 context, TensorShapeUtils::IsMatrix(dense_list_in[i].shape()),
463 errors::InvalidArgument(
464 "Dense inputs should be a matrix but received shape ",
465 indices_list_in[i].shape().DebugString(), " at position ", i));
466 OP_REQUIRES(context, dense_list_in[i].dim_size(0) == batch_size,
467 errors::InvalidArgument("Expected batch size ", batch_size,
468 " got ", dense_list_in[i].dim_size(0),
469 " at dense tensor ", i));
470 }
471 }
472
473 // Calculate the batch size from either the shapes input or the dense input.
CalculateBatchSize(const OpInputList & shapes_list_in,const OpInputList & dense_list_in)474 int64 CalculateBatchSize(const OpInputList& shapes_list_in,
475 const OpInputList& dense_list_in) {
476 if (shapes_list_in.size() > 0) {
477 return shapes_list_in[0].vec<int64>()(0);
478 }
479
480 if (dense_list_in.size() > 0) {
481 return dense_list_in[0].dim_size(0);
482 }
483
484 return 0;
485 }
486
487 // Generate the columns given the sparse and dense inputs.
488 std::vector<std::unique_ptr<ColumnInterface<InternalType>>>
GenerateColumnsFromInput(const OpInputList & indices_list_in,const OpInputList & values_list_in,const OpInputList & shapes_list_in,const OpInputList & dense_list_in)489 GenerateColumnsFromInput(const OpInputList& indices_list_in,
490 const OpInputList& values_list_in,
491 const OpInputList& shapes_list_in,
492 const OpInputList& dense_list_in) {
493 std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns;
494 const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in);
495 const int64 number_of_columns = shapes_list_in.size();
496
497 std::vector<std::vector<int64>> feature_counts(number_of_columns,
498 std::vector<int64>());
499 std::vector<std::vector<int64>> feature_start_indices(number_of_columns,
500 std::vector<int64>());
501
502 ExtractFeatureData(indices_list_in, batch_size, &feature_counts,
503 &feature_start_indices);
504
505 columns.reserve(values_list_in.size());
506 for (int i = 0; i < values_list_in.size(); ++i) {
507 columns.emplace_back(new SparseTensorColumn<InternalType>(
508 values_list_in[i], std::move(feature_counts[i]),
509 std::move(feature_start_indices[i])));
510 }
511 for (int i = 0; i < dense_list_in.size(); ++i) {
512 columns.emplace_back(
513 new DenseTensorColumn<InternalType>(dense_list_in[i]));
514 }
515
516 return columns;
517 }
518
519 // Extracts data about the features and populates feature data.
ExtractFeatureData(const OpInputList & indices_list_in,int64 batch_size,std::vector<std::vector<int64>> * feature_counts,std::vector<std::vector<int64>> * feature_start_indices)520 void ExtractFeatureData(
521 const OpInputList& indices_list_in, int64 batch_size,
522 std::vector<std::vector<int64>>* feature_counts,
523 std::vector<std::vector<int64>>* feature_start_indices) {
524 gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0);
525 for (int b = 0; b < batch_size; b++) {
526 for (int i = 0; i < indices_list_in.size(); i++) {
527 const auto indices = indices_list_in[i].matrix<int64>();
528 int64 feature_count = 0;
529 int64 start_index = current_row[i];
530 // Loops until we reach next batch index for current feature column.
531 while (current_row[i] < indices_list_in[i].dim_size(0) &&
532 indices(current_row[i], 0) == b) {
533 feature_count++;
534 current_row[i]++;
535 }
536 (*feature_counts)[i].push_back(feature_count);
537 (*feature_start_indices)[i].push_back(start_index);
538 }
539 }
540 }
541
542 // Allocates output tensors with proper size and sets the shape tensor of
543 // the output SparseTensor.
544 // It also output_start_indices which contains the start indices for each
545 // input in the output SparseTensor.
CreateOutputTensors(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,int64 batch_size,OpKernelContext * context,Tensor ** indices_out,Tensor ** values_out,Tensor ** shape_out,std::vector<int64> * output_start_indices)546 void CreateOutputTensors(
547 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
548 columns,
549 int64 batch_size, OpKernelContext* context, Tensor** indices_out,
550 Tensor** values_out, Tensor** shape_out,
551 std::vector<int64>* output_start_indices) {
552 // Calculates dimensions for output tensors.
553 int64 cross_count_total = 0;
554 int64 max_cross_count = 0;
555 for (int64 b = 0; b < batch_size; b++) {
556 // For each input, sets starting indices in output SparseTensor
557 (*output_start_indices)[b] = cross_count_total;
558 const auto cross_count = CrossCountByBatchIndex(columns, b);
559 max_cross_count = std::max(max_cross_count, cross_count);
560 cross_count_total += cross_count;
561 }
562
563 // Allocates tensors.
564 OP_REQUIRES_OK(context,
565 context->allocate_output(
566 0, TensorShape({cross_count_total, 2}), indices_out));
567 OP_REQUIRES_OK(context,
568 context->allocate_output(1, TensorShape({cross_count_total}),
569 values_out));
570 OP_REQUIRES_OK(context,
571 context->allocate_output(2, TensorShape({2}), shape_out));
572
573 // Sets shape.
574 auto shape_vec = (*shape_out)->vec<int64>();
575 shape_vec(0) = batch_size;
576 shape_vec(1) = max_cross_count;
577 }
578
579 // Returns number of crosses for a given batch_index
CrossCountByBatchIndex(const std::vector<std::unique_ptr<ColumnInterface<InternalType>>> & columns,int batch_index)580 int64 CrossCountByBatchIndex(
581 const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>&
582 columns,
583 int batch_index) {
584 int64 cross_count = 1;
585 for (size_t i = 0; i < columns.size(); i++) {
586 const auto feature_count = columns[i]->FeatureCount(batch_index);
587 // If one column is missing any feature, there won't be any cross.
588 if (feature_count == 0) {
589 return 0;
590 }
591 cross_count *= feature_count;
592 }
593 return cross_count;
594 }
595 int64 num_buckets_;
596 uint64 hash_key_;
597 };
598
599 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
600 .Device(DEVICE_CPU)
601 .TypeConstraint<string>("out_type")
602 .TypeConstraint<string>("internal_type"),
603 SparseFeatureCrossOp<false, StringPiece, false>);
604
605 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
606 .Device(DEVICE_CPU)
607 .TypeConstraint<string>("out_type")
608 .TypeConstraint<int64>("internal_type"),
609 SparseFeatureCrossOp<false, string, false>);
610
611 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
612 .Device(DEVICE_CPU)
613 .TypeConstraint<int64>("out_type")
614 .TypeConstraint<string>("internal_type"),
615 SparseFeatureCrossOp<true, int64, false>);
616
617 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross")
618 .Device(DEVICE_CPU)
619 .TypeConstraint<int64>("out_type")
620 .TypeConstraint<int64>("internal_type"),
621 SparseFeatureCrossOp<true, int64, false>);
622
623 // The following builders enable FingerprintCat64 concatenation for the
624 // crosses features.
625 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
626 .Device(DEVICE_CPU)
627 .TypeConstraint<string>("out_type")
628 .TypeConstraint<string>("internal_type"),
629 SparseFeatureCrossOp<false, StringPiece, true>);
630
631 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
632 .Device(DEVICE_CPU)
633 .TypeConstraint<string>("out_type")
634 .TypeConstraint<int64>("internal_type"),
635 SparseFeatureCrossOp<false, string, true>);
636
637 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
638 .Device(DEVICE_CPU)
639 .TypeConstraint<int64>("out_type")
640 .TypeConstraint<string>("internal_type"),
641 SparseFeatureCrossOp<true, int64, true>);
642
643 REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2")
644 .Device(DEVICE_CPU)
645 .TypeConstraint<int64>("out_type")
646 .TypeConstraint<int64>("internal_type"),
647 SparseFeatureCrossOp<true, int64, true>);
648
649 } // namespace tensorflow
650