1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for SparseTensorsMap.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.client import session 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import sparse_ops 30from tensorflow.python.ops import variables 31from tensorflow.python.platform import benchmark 32from tensorflow.python.platform import test 33 34# pylint: disable=protected-access 35add_sparse_to_tensors_map = sparse_ops._add_sparse_to_tensors_map 36add_many_sparse_to_tensors_map = sparse_ops._add_many_sparse_to_tensors_map 37take_many_sparse_from_tensors_map = ( 38 sparse_ops._take_many_sparse_from_tensors_map) 39 40# pylint: enable=protected-access 41 42 43class SparseTensorsMapTest(test.TestCase): 44 45 def _SparseTensorPlaceholder(self, dtype=None): 46 if dtype is None: 47 dtype = dtypes.int32 48 return sparse_tensor_lib.SparseTensor( 49 array_ops.placeholder(dtypes.int64), 50 array_ops.placeholder(dtype), array_ops.placeholder(dtypes.int64)) 51 52 def _SparseTensorValue_5x6(self, permutation): 53 ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], 54 [3, 3]]).astype(np.int64) 55 val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32) 56 57 ind = ind[permutation] 58 val = val[permutation] 59 60 shape = np.array([5, 6]).astype(np.int64) 61 return sparse_tensor_lib.SparseTensorValue(ind, val, shape) 62 63 def _SparseTensorValue_3x4(self, permutation): 64 ind = np.array([[0, 0], [1, 0], [1, 2], [1, 3], [2, 2], 65 [2, 3]]).astype(np.int64) 66 val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32) 67 68 ind = ind[permutation] 69 val = val[permutation] 70 71 shape = np.array([3, 4]).astype(np.int64) 72 return sparse_tensor_lib.SparseTensorValue(ind, val, shape) 73 74 def _SparseTensorValue_1x1x1(self): 75 ind = np.array([[0, 0, 0]]).astype(np.int64) 76 val = np.array([0]).astype(np.int32) 77 shape = np.array([3, 4, 5]).astype(np.int64) 78 return sparse_tensor_lib.SparseTensorValue(ind, val, shape) 79 80 @test_util.run_deprecated_v1 81 def testAddTakeMany(self): 82 with self.session(graph=ops.Graph(), use_gpu=False) as sess: 83 sp_input0 = self._SparseTensorValue_5x6(np.arange(6)) 84 sp_input1 = self._SparseTensorValue_3x4(np.arange(6)) 85 handle0 = add_sparse_to_tensors_map(sp_input0, shared_name="a") 86 handle1 = add_sparse_to_tensors_map(sp_input1, shared_name="a") 87 self.assertEqual(handle0.get_shape(), ()) 88 handles_concat = array_ops.stack([handle0, handle1]) 89 90 sp_out = take_many_sparse_from_tensors_map( 91 sparse_map_op=handle0.op, sparse_handles=handles_concat) 92 93 combined_indices, combined_values, combined_shape = self.evaluate(sp_out) 94 95 self.assertAllEqual(combined_indices[:6, 0], [0] * 6) # minibatch 0 96 self.assertAllEqual(combined_indices[:6, 1:], sp_input0[0]) 97 self.assertAllEqual(combined_indices[6:, 0], [1] * 6) # minibatch 1 98 self.assertAllEqual(combined_indices[6:, 1:], sp_input1[0]) 99 self.assertAllEqual(combined_values[:6], sp_input0[1]) 100 self.assertAllEqual(combined_values[6:], sp_input1[1]) 101 self.assertAllEqual(combined_shape, [2, 5, 6]) 102 103 @test_util.run_deprecated_v1 104 def testFeedAddTakeMany(self): 105 with self.session(use_gpu=False) as sess: 106 sp_input = self._SparseTensorPlaceholder() 107 input0_val = self._SparseTensorValue_5x6(np.arange(6)) 108 input1_val = self._SparseTensorValue_3x4(np.arange(6)) 109 handle = add_sparse_to_tensors_map(sp_input) 110 111 handle0_value = sess.run(handle, feed_dict={sp_input: input0_val}) 112 handle1_value = sess.run(handle, feed_dict={sp_input: input1_val}) 113 114 sparse_handles = ops.convert_to_tensor( 115 [handle0_value, handle1_value], dtype=dtypes.int64) 116 117 sp_roundtrip = take_many_sparse_from_tensors_map( 118 sparse_map_op=handle.op, sparse_handles=sparse_handles) 119 120 combined_indices, combined_values, combined_shape = self.evaluate( 121 sp_roundtrip) 122 123 self.assertAllEqual(combined_indices[:6, 0], [0] * 6) # minibatch 0 124 self.assertAllEqual(combined_indices[:6, 1:], input0_val[0]) 125 self.assertAllEqual(combined_indices[6:, 0], [1] * 6) # minibatch 1 126 self.assertAllEqual(combined_indices[6:, 1:], input1_val[0]) 127 self.assertAllEqual(combined_values[:6], input0_val[1]) 128 self.assertAllEqual(combined_values[6:], input1_val[1]) 129 self.assertAllEqual(combined_shape, [2, 5, 6]) 130 131 @test_util.run_deprecated_v1 132 def testAddManyTakeManyRoundTrip(self): 133 with self.session(use_gpu=False) as sess: 134 # N == 4 because shape_value == [4, 5] 135 indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64) 136 values_value = np.array([b"a", b"b", b"c"]) 137 shape_value = np.array([4, 5], dtype=np.int64) 138 sparse_tensor = self._SparseTensorPlaceholder(dtype=dtypes.string) 139 handles = add_many_sparse_to_tensors_map(sparse_tensor) 140 roundtrip = take_many_sparse_from_tensors_map( 141 sparse_map_op=handles.op, sparse_handles=handles) 142 handles_value, roundtrip_value = sess.run( 143 [handles, roundtrip], 144 feed_dict={ 145 sparse_tensor.indices: indices_value, 146 sparse_tensor.values: values_value, 147 sparse_tensor.dense_shape: shape_value 148 }) 149 self.assertEqual(handles_value.shape, (4,)) 150 self.assertAllEqual(roundtrip_value.indices, indices_value) 151 self.assertAllEqual(roundtrip_value.values, values_value) 152 self.assertAllEqual(roundtrip_value.dense_shape, shape_value) 153 154 @test_util.run_deprecated_v1 155 def testDeserializeFailsInconsistentRank(self): 156 with self.session(use_gpu=False) as sess: 157 sp_input = self._SparseTensorPlaceholder() 158 input0_val = self._SparseTensorValue_5x6(np.arange(6)) 159 input1_val = self._SparseTensorValue_1x1x1() 160 handle = add_sparse_to_tensors_map(sp_input) 161 162 handle0_value = sess.run(handle, feed_dict={sp_input: input0_val}) 163 handle1_value = sess.run(handle, feed_dict={sp_input: input1_val}) 164 165 handle_concat = ops.convert_to_tensor( 166 [handle0_value, handle1_value], dtype=dtypes.int64) 167 168 sp_roundtrip = take_many_sparse_from_tensors_map( 169 sparse_map_op=handle.op, sparse_handles=handle_concat) 170 171 with self.assertRaisesOpError( 172 r"Inconsistent rank across SparseTensors: rank prior to " 173 r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4"): 174 self.evaluate(sp_roundtrip) 175 176 @test_util.run_deprecated_v1 177 def testTakeManyFailsWrongInputOp(self): 178 with self.session(use_gpu=False) as sess: 179 input_val = self._SparseTensorValue_5x6(np.arange(6)) 180 handle = add_sparse_to_tensors_map(input_val) 181 handle_value = self.evaluate(handle) 182 bad_handle = handle_value + 10 183 sp_roundtrip = take_many_sparse_from_tensors_map( 184 sparse_map_op=handle.op, sparse_handles=[handle_value, bad_handle]) 185 186 with self.assertRaisesOpError(r"Unable to find SparseTensor: 10"): 187 self.evaluate(sp_roundtrip) 188 189 190class BenchmarkSparseTensorsMapVsSerialization(test.Benchmark): 191 192 def benchmarkVeryLarge2DFloatSparseTensor(self): 193 np.random.seed(127) 194 num_elements = 10000 195 batch_size = 64 196 indices_batch = np.random.randint( 197 batch_size, size=num_elements, dtype=np.int64) 198 indices_value = np.arange(num_elements, dtype=np.int64) 199 indices = np.asarray( 200 sorted(zip(indices_batch, indices_value)), dtype=np.int64) 201 values = ["feature_value_for_embedding_lookup"] * num_elements 202 shape = np.asarray([batch_size, num_elements], dtype=np.int64) 203 with session.Session(config=benchmark.benchmark_config()) as sess: 204 with ops.device("/cpu:0"): 205 indices = variables.Variable(indices) 206 values = variables.Variable(values) 207 shape = variables.Variable(shape) 208 st = sparse_tensor_lib.SparseTensor(indices, values, shape) 209 210 st_handles = add_many_sparse_to_tensors_map(st) 211 st_roundtrip = take_many_sparse_from_tensors_map( 212 sparse_map_op=st_handles.op, sparse_handles=st_handles) 213 st_roundtrip_op = st_roundtrip.values.op 214 215 st_serialized = sparse_ops.serialize_many_sparse(st) 216 st_deserialized = sparse_ops.deserialize_many_sparse( 217 st_serialized, dtype=values.dtype) 218 st_deserialized_op = st_deserialized.values.op 219 220 variables.global_variables_initializer().run() 221 222 st_roundtrip_values = self.evaluate(st_roundtrip) 223 st_deserialized_values = self.evaluate(st_deserialized) 224 np.testing.assert_equal(st_roundtrip_values.values, 225 st_deserialized_values.values) 226 np.testing.assert_equal(st_roundtrip_values.indices, 227 st_deserialized_values.indices) 228 np.testing.assert_equal(st_roundtrip_values.dense_shape, 229 st_deserialized_values.dense_shape) 230 231 self.run_op_benchmark( 232 sess, 233 st_roundtrip_op, 234 min_iters=2000, 235 name="benchmark_very_large_2d_float_st_tensor_maps") 236 self.run_op_benchmark( 237 sess, 238 st_deserialized_op, 239 min_iters=2000, 240 name="benchmark_very_large_2d_float_st_serialization") 241 242 243if __name__ == "__main__": 244 test.main() 245