# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrappers for sparse cross operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.layers.ops import gen_sparse_feature_cross_op
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import resource_loader

# Load the custom-op shared library that lives next to this module. This runs
# at import time.
_sparse_feature_cross_op = loader.load_op_library(
    resource_loader.get_path_to_datafile("_sparse_feature_cross_op.so"))

# Default hash key for the FingerprintCat64.
SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY = 0xDECAFCAFFE


@deprecated_arg_values(
    "2016-11-20",
    "The default behavior of sparse_feature_cross is changing, the default\n"
    "value for hash_key will change to SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY.\n"
    "From that point on sparse_feature_cross will always use FingerprintCat64\n"
    "to concatenate the feature fingerprints. And the underlying\n"
    "_sparse_feature_cross_op.sparse_feature_cross operation will be marked\n"
    "as deprecated.",
    hash_key=None)
def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0,
                         name=None, hash_key=None):
  """Crosses a list of Tensor or SparseTensor objects.

  See sparse_feature_cross_kernel.cc for more details.

  Args:
    inputs: List of `SparseTensor` or `Tensor` to be crossed.
    hashed_output: If true, returns the hash of the cross instead of the
      string. This avoids string manipulations.
    num_buckets: It is used if hashed_output is true.
      output = hashed_value % num_buckets if num_buckets > 0 else hashed_value.
    name: A name prefix for the returned tensors (optional).
    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
      function to combine the crosses' fingerprints on SparseFeatureCrossOp.
      When None (the default), the legacy SparseFeatureCross op, which takes
      no hash key, is used instead. Any other value, including 0, selects
      SparseFeatureCrossV2. The default is scheduled to become
      SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY after 2016-11-20 (optional).

  Returns:
    A `SparseTensor` with the crossed features.
    Return type is string if hashed_output=False, int64 otherwise.

  Raises:
    TypeError: If `inputs` is not a list, or if any element of it is neither
      a `SparseTensor` nor a `Tensor`.
  """
  if not isinstance(inputs, list):
    raise TypeError("Inputs must be a list")
  if not all(isinstance(i, (sparse_tensor.SparseTensor, ops.Tensor))
             for i in inputs):
    raise TypeError("All inputs must be SparseTensors or Tensors")

  sparse_inputs = [i for i in inputs
                   if isinstance(i, sparse_tensor.SparseTensor)]
  dense_inputs = [i for i in inputs
                  if not isinstance(i, sparse_tensor.SparseTensor)]

  indices = [sp_input.indices for sp_input in sparse_inputs]
  values = [sp_input.values for sp_input in sparse_inputs]
  shapes = [sp_input.dense_shape for sp_input in sparse_inputs]
  out_type = dtypes.int64 if hashed_output else dtypes.string

  # The op is told to work in int64 as soon as any feature (sparse values or
  # dense tensor) is non-string. The dtype check runs on the original tensors,
  # before casting.
  internal_type = dtypes.string
  if any(t.dtype != dtypes.string for t in values + dense_inputs):
    internal_type = dtypes.int64

  # Every non-string feature is cast to int64. String features pass through
  # untouched.
  values = [math_ops.cast(v, dtypes.int64) if v.dtype != dtypes.string else v
            for v in values]
  dense_inputs = [
      math_ops.cast(d, dtypes.int64) if d.dtype != dtypes.string else d
      for d in dense_inputs
  ]

  # Compare against None explicitly: a caller passing hash_key=0 asked for a
  # key and should get the V2 op, not a silent fallback to the legacy op.
  if hash_key is not None:
    indices_out, values_out, shape_out = (
        gen_sparse_feature_cross_op.sparse_feature_cross_v2(
            indices,
            values,
            shapes,
            dense_inputs,
            hashed_output,
            num_buckets,
            hash_key=hash_key,
            out_type=out_type,
            internal_type=internal_type,
            name=name))
  else:
    indices_out, values_out, shape_out = (
        gen_sparse_feature_cross_op.sparse_feature_cross(
            indices,
            values,
            shapes,
            dense_inputs,
            hashed_output,
            num_buckets,
            out_type=out_type,
            internal_type=internal_type,
            name=name))

  return sparse_tensor.SparseTensor(indices_out, values_out, shape_out)


# Register both op variants as having no gradient.
ops.NotDifferentiable("SparseFeatureCross")


ops.NotDifferentiable("SparseFeatureCrossV2")