• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Wrappers for sparse cross operations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.contrib.framework import deprecated_arg_values
21from tensorflow.contrib.layers.ops import gen_sparse_feature_cross_op
22from tensorflow.contrib.util import loader
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.ops import math_ops
27from tensorflow.python.platform import resource_loader
28
# Load the `_sparse_feature_cross_op.so` op library; its on-disk location is
# resolved through `resource_loader.get_path_to_datafile`.
_sparse_feature_cross_op = loader.load_op_library(
    resource_loader.get_path_to_datafile("_sparse_feature_cross_op.so")
)

# Default hash key for the FingerprintCat64. `sparse_feature_cross` documents
# this as the future default value of its `hash_key` argument.
SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY = 0xDECAFCAFFE
34
35
@deprecated_arg_values(
    "2016-11-20",
    "The default behavior of sparse_feature_cross is changing, the default\n"
    "value for hash_key will change to SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY.\n"
    "From that point on sparse_feature_cross will always use FingerprintCat64\n"
    "to concatenate the feature fingerprints. And the underlying\n"
    "_sparse_feature_cross_op.sparse_feature_cross operation will be marked\n"
    "as deprecated.",
    hash_key=None)
def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0,
                         name=None, hash_key=None):
  """Crosses a list of Tensor or SparseTensor objects.

  See sparse_feature_cross_kernel.cc for more details.

  Sparse inputs are decomposed into their indices, values and dense shapes;
  dense inputs are passed through as-is. Any input whose dtype is not string
  is cast to int64 before being handed to the cross op.

  Args:
    inputs: List of `SparseTensor` or `Tensor` to be crossed.
    hashed_output: If true, returns the hash of the cross instead of the string.
      This will allow us avoiding string manipulations.
    num_buckets: It is used if hashed_output is true.
      output = hashed_value%num_buckets if num_buckets > 0 else hashed_value.
    name: A name prefix for the returned tensors (optional).
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crosses fingerprints on SparseFeatureCrossOp.
      The default value is None, but will become
      SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY after 2016-11-20 (optional).
      When it is None the legacy `sparse_feature_cross` op is used; any other
      value (including 0) selects `sparse_feature_cross_v2`.

  Returns:
    A `SparseTensor` with the crossed features.
    Return type is string if hashed_output=False, int64 otherwise.

  Raises:
    TypeError: If `inputs` is not a list, or if any of its elements is
      neither a `SparseTensor` nor a `Tensor`.
  """
  if not isinstance(inputs, list):
    raise TypeError("Inputs must be a list")
  if not all(isinstance(i, (sparse_tensor.SparseTensor, ops.Tensor))
             for i in inputs):
    raise TypeError("All inputs must be either SparseTensor or Tensor")

  # Split the inputs by kind; relative order within each kind is preserved.
  sparse_inputs = [i for i in inputs
                   if isinstance(i, sparse_tensor.SparseTensor)]
  dense_inputs = [i for i in inputs
                  if not isinstance(i, sparse_tensor.SparseTensor)]

  indices = [sp_input.indices for sp_input in sparse_inputs]
  values = [sp_input.values for sp_input in sparse_inputs]
  shapes = [sp_input.dense_shape for sp_input in sparse_inputs]
  out_type = dtypes.int64 if hashed_output else dtypes.string

  # The op works on string and int64 features only: if at least one input is
  # not a string the internal type is int64, otherwise it stays string.
  has_non_string = any(t.dtype != dtypes.string for t in values + dense_inputs)
  internal_type = dtypes.int64 if has_non_string else dtypes.string

  def _cast_unless_string(tensor):
    """Returns `tensor` unchanged if it is a string, else casts it to int64."""
    if tensor.dtype != dtypes.string:
      return math_ops.cast(tensor, dtypes.int64)
    return tensor

  # Sparse values are cast first, then dense inputs, so ops are created in
  # the same order as before.
  values = [_cast_unless_string(v) for v in values]
  dense_inputs = [_cast_unless_string(d) for d in dense_inputs]

  # Compare against None rather than relying on truthiness: an explicit
  # hash_key of 0 is still an explicit key and must not silently fall back to
  # the legacy op (the deprecation warning above also keys on None only).
  if hash_key is not None:
    indices_out, values_out, shape_out = (
        gen_sparse_feature_cross_op.sparse_feature_cross_v2(
            indices,
            values,
            shapes,
            dense_inputs,
            hashed_output,
            num_buckets,
            hash_key=hash_key,
            out_type=out_type,
            internal_type=internal_type,
            name=name))
  else:
    indices_out, values_out, shape_out = (
        gen_sparse_feature_cross_op.sparse_feature_cross(
            indices,
            values,
            shapes,
            dense_inputs,
            hashed_output,
            num_buckets,
            out_type=out_type,
            internal_type=internal_type,
            name=name))

  return sparse_tensor.SparseTensor(indices_out, values_out, shape_out)
123
124
# Declare both versions of the cross op as non-differentiable, in the same
# order as the original two standalone calls.
for _cross_op_type in ("SparseFeatureCross", "SparseFeatureCrossV2"):
  ops.NotDifferentiable(_cross_op_type)
# Drop the loop variable so the module namespace gains no extra name.
del _cross_op_type
129