• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for SparseTensorsMap."""
16
17import numpy as np
18
19from tensorflow.python.client import session
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import sparse_ops
26from tensorflow.python.ops import variables
27from tensorflow.python.platform import benchmark
28from tensorflow.python.platform import test
29
# pylint: disable=protected-access
# Short module-level aliases for the private SparseTensorsMap wrappers in
# sparse_ops, so the code below can call them without the leading underscore.
(add_sparse_to_tensors_map,
 add_many_sparse_to_tensors_map,
 take_many_sparse_from_tensors_map) = (
     sparse_ops._add_sparse_to_tensors_map,
     sparse_ops._add_many_sparse_to_tensors_map,
     sparse_ops._take_many_sparse_from_tensors_map,
 )

# pylint: enable=protected-access
37
38
class SparseTensorsMapTest(test.TestCase):
  """Round-trip and error tests for the SparseTensorsMap add/take ops."""

  def _SparseTensorPlaceholder(self, dtype=None):
    """Returns a SparseTensor whose indices, values and shape are placeholders.

    Args:
      dtype: dtype of the values placeholder. Defaults to `dtypes.int32`;
        the indices and dense-shape placeholders are always int64.

    Returns:
      A `SparseTensor` that can be fed with a `SparseTensorValue`.
    """
    if dtype is None:
      dtype = dtypes.int32
    return sparse_tensor_lib.SparseTensor(
        array_ops.placeholder(dtypes.int64),
        array_ops.placeholder(dtype), array_ops.placeholder(dtypes.int64))

  def _SparseTensorValue_2D(self, ind, shape, permutation):
    """Builds an int32 `SparseTensorValue` holding the shared six values.

    Args:
      ind: Six 2-D indices (one per value), as a nested list.
      shape: The dense shape, as a list of two ints.
      permutation: Index array applied to both `ind` and the values, so the
        entries can be presented in a different order.

    Returns:
      A `SparseTensorValue` with int64 indices and shape, and int32 values
      `[0, 10, 13, 14, 32, 33]` (before `permutation` is applied).
    """
    ind = np.array(ind).astype(np.int64)
    val = np.array([0, 10, 13, 14, 32, 33]).astype(np.int32)

    ind = ind[permutation]
    val = val[permutation]

    shape = np.array(shape).astype(np.int64)
    return sparse_tensor_lib.SparseTensorValue(ind, val, shape)

  def _SparseTensorValue_5x6(self, permutation):
    """Returns the 5x6 six-element fixture; see `_SparseTensorValue_2D`."""
    return self._SparseTensorValue_2D(
        [[0, 0], [1, 0], [1, 3], [1, 4], [3, 2], [3, 3]], [5, 6], permutation)

  def _SparseTensorValue_3x4(self, permutation):
    """Returns the 3x4 six-element fixture; see `_SparseTensorValue_2D`."""
    return self._SparseTensorValue_2D(
        [[0, 0], [1, 0], [1, 2], [1, 3], [2, 2], [2, 3]], [3, 4], permutation)

  def _SparseTensorValue_1x1x1(self):
    """Returns a rank-3 fixture with a single element (value 0) at [0, 0, 0].

    Despite the name, the dense shape is [3, 4, 5]; only the index and value
    arrays have length one.
    """
    ind = np.array([[0, 0, 0]]).astype(np.int64)
    val = np.array([0]).astype(np.int32)
    shape = np.array([3, 4, 5]).astype(np.int64)
    return sparse_tensor_lib.SparseTensorValue(ind, val, shape)

  @test_util.run_deprecated_v1
  def testAddTakeMany(self):
    """Adds two constant SparseTensors to a shared map and takes both back."""
    with self.session(graph=ops.Graph(), use_gpu=False):
      sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
      sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
      # Both add ops pass shared_name="a"; that is what lets handle1 be taken
      # through handle0's op below.
      handle0 = add_sparse_to_tensors_map(sp_input0, shared_name="a")
      handle1 = add_sparse_to_tensors_map(sp_input1, shared_name="a")
      self.assertEqual(handle0.get_shape(), ())
      handles_concat = array_ops.stack([handle0, handle1])

      sp_out = take_many_sparse_from_tensors_map(
          sparse_map_op=handle0.op, sparse_handles=handles_concat)

      combined_indices, combined_values, combined_shape = self.evaluate(sp_out)

      # The taken tensors are stacked along a new leading minibatch dimension.
      self.assertAllEqual(combined_indices[:6, 0], [0] * 6)  # minibatch 0
      self.assertAllEqual(combined_indices[:6, 1:], sp_input0[0])
      self.assertAllEqual(combined_indices[6:, 0], [1] * 6)  # minibatch 1
      self.assertAllEqual(combined_indices[6:, 1:], sp_input1[0])
      self.assertAllEqual(combined_values[:6], sp_input0[1])
      self.assertAllEqual(combined_values[6:], sp_input1[1])
      # Per-dimension maximum of the inputs: [2, max(5, 3), max(6, 4)].
      self.assertAllEqual(combined_shape, [2, 5, 6])

  @test_util.run_deprecated_v1
  def testFeedAddTakeMany(self):
    """Feeds two values through one add op, then takes both handles at once."""
    with self.session(use_gpu=False) as sess:
      sp_input = self._SparseTensorPlaceholder()
      input0_val = self._SparseTensorValue_5x6(np.arange(6))
      input1_val = self._SparseTensorValue_3x4(np.arange(6))
      handle = add_sparse_to_tensors_map(sp_input)

      # Each run of the same add op stores one tensor and yields its handle.
      handle0_value = sess.run(handle, feed_dict={sp_input: input0_val})
      handle1_value = sess.run(handle, feed_dict={sp_input: input1_val})

      sparse_handles = ops.convert_to_tensor(
          [handle0_value, handle1_value], dtype=dtypes.int64)

      sp_roundtrip = take_many_sparse_from_tensors_map(
          sparse_map_op=handle.op, sparse_handles=sparse_handles)

      combined_indices, combined_values, combined_shape = self.evaluate(
          sp_roundtrip)

      self.assertAllEqual(combined_indices[:6, 0], [0] * 6)  # minibatch 0
      self.assertAllEqual(combined_indices[:6, 1:], input0_val[0])
      self.assertAllEqual(combined_indices[6:, 0], [1] * 6)  # minibatch 1
      self.assertAllEqual(combined_indices[6:, 1:], input1_val[0])
      self.assertAllEqual(combined_values[:6], input0_val[1])
      self.assertAllEqual(combined_values[6:], input1_val[1])
      self.assertAllEqual(combined_shape, [2, 5, 6])

  @test_util.run_deprecated_v1
  def testAddManyTakeManyRoundTrip(self):
    """Splits a batched string SparseTensor into the map and reassembles it."""
    with self.session(use_gpu=False) as sess:
      # N == 4 because shape_value == [4, 5]
      indices_value = np.array([[0, 0], [0, 1], [2, 0]], dtype=np.int64)
      values_value = np.array([b"a", b"b", b"c"])
      shape_value = np.array([4, 5], dtype=np.int64)
      sparse_tensor = self._SparseTensorPlaceholder(dtype=dtypes.string)
      handles = add_many_sparse_to_tensors_map(sparse_tensor)
      roundtrip = take_many_sparse_from_tensors_map(
          sparse_map_op=handles.op, sparse_handles=handles)
      handles_value, roundtrip_value = sess.run(
          [handles, roundtrip],
          feed_dict={
              sparse_tensor.indices: indices_value,
              sparse_tensor.values: values_value,
              sparse_tensor.dense_shape: shape_value
          })
      # One handle per row of the leading dimension, including empty rows.
      self.assertEqual(handles_value.shape, (4,))
      self.assertAllEqual(roundtrip_value.indices, indices_value)
      self.assertAllEqual(roundtrip_value.values, values_value)
      self.assertAllEqual(roundtrip_value.dense_shape, shape_value)

  @test_util.run_deprecated_v1
  def testDeserializeFailsInconsistentRank(self):
    """Taking tensors of different ranks together raises an op error."""
    with self.session(use_gpu=False) as sess:
      sp_input = self._SparseTensorPlaceholder()
      input0_val = self._SparseTensorValue_5x6(np.arange(6))
      input1_val = self._SparseTensorValue_1x1x1()
      handle = add_sparse_to_tensors_map(sp_input)

      handle0_value = sess.run(handle, feed_dict={sp_input: input0_val})
      handle1_value = sess.run(handle, feed_dict={sp_input: input1_val})

      handle_concat = ops.convert_to_tensor(
          [handle0_value, handle1_value], dtype=dtypes.int64)

      sp_roundtrip = take_many_sparse_from_tensors_map(
          sparse_map_op=handle.op, sparse_handles=handle_concat)

      # Ranks in the message include the prepended minibatch dimension:
      # the 2-D input becomes rank 3 and the 3-D input becomes rank 4.
      with self.assertRaisesOpError(
          r"Inconsistent rank across SparseTensors: rank prior to "
          r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4"):
        self.evaluate(sp_roundtrip)

  @test_util.run_deprecated_v1
  def testTakeManyFailsWrongInputOp(self):
    """Taking a handle that was never added raises an op error."""
    with self.session(use_gpu=False):
      input_val = self._SparseTensorValue_5x6(np.arange(6))
      handle = add_sparse_to_tensors_map(input_val)
      handle_value = self.evaluate(handle)
      # NOTE(review): the expected "10" below assumes handle_value is 0 (the
      # first handle issued by a fresh map) -- confirm if the map is reused.
      bad_handle = handle_value + 10
      sp_roundtrip = take_many_sparse_from_tensors_map(
          sparse_map_op=handle.op, sparse_handles=[handle_value, bad_handle])

      with self.assertRaisesOpError(r"Unable to find SparseTensor: 10"):
        self.evaluate(sp_roundtrip)
184
185
class BenchmarkSparseTensorsMapVsSerialization(test.Benchmark):
  """Compares a SparseTensorsMap round trip with many-sparse serialization."""

  def benchmarkVeryLarge2DFloatSparseTensor(self):
    """Benchmarks both round-trip paths on a [64, 10000] SparseTensor.

    Despite "Float" in the name, the values are strings. Both paths are first
    checked to rebuild identical indices, values and dense shape; then the
    values op of each path is timed with `run_op_benchmark`.

    All evaluations use the explicit `sess` that is also handed to
    `run_op_benchmark`, rather than relying on a default-session lookup.
    """
    np.random.seed(127)
    num_elements = 10000
    batch_size = 64
    # Every element gets a distinct column; its row (batch entry) is random.
    indices_batch = np.random.randint(
        batch_size, size=num_elements, dtype=np.int64)
    indices_value = np.arange(num_elements, dtype=np.int64)
    # Sorting (row, column) tuples lexicographically yields row-major order.
    indices = np.asarray(
        sorted(zip(indices_batch, indices_value)), dtype=np.int64)
    values = ["feature_value_for_embedding_lookup"] * num_elements
    shape = np.asarray([batch_size, num_elements], dtype=np.int64)
    with session.Session(config=benchmark.benchmark_config()) as sess:
      with ops.device("/cpu:0"):
        indices = variables.Variable(indices)
        values = variables.Variable(values)
        shape = variables.Variable(shape)
        st = sparse_tensor_lib.SparseTensor(indices, values, shape)

        # Path 1: add every batch entry to a SparseTensorsMap and take them.
        st_handles = add_many_sparse_to_tensors_map(st)
        st_roundtrip = take_many_sparse_from_tensors_map(
            sparse_map_op=st_handles.op, sparse_handles=st_handles)
        st_roundtrip_op = st_roundtrip.values.op

        # Path 2: serialize every batch entry and deserialize them.
        st_serialized = sparse_ops.serialize_many_sparse(st)
        st_deserialized = sparse_ops.deserialize_many_sparse(
            st_serialized, dtype=values.dtype)
        st_deserialized_op = st_deserialized.values.op

        sess.run(variables.global_variables_initializer())

        # Sanity check: both paths must reproduce the same SparseTensor.
        st_roundtrip_values = sess.run(st_roundtrip)
        st_deserialized_values = sess.run(st_deserialized)
        np.testing.assert_equal(st_roundtrip_values.values,
                                st_deserialized_values.values)
        np.testing.assert_equal(st_roundtrip_values.indices,
                                st_deserialized_values.indices)
        np.testing.assert_equal(st_roundtrip_values.dense_shape,
                                st_deserialized_values.dense_shape)

        self.run_op_benchmark(
            sess,
            st_roundtrip_op,
            min_iters=2000,
            name="benchmark_very_large_2d_float_st_tensor_maps")
        self.run_op_benchmark(
            sess,
            st_deserialized_op,
            min_iters=2000,
            name="benchmark_very_large_2d_float_st_serialization")
237
238
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when run as a script.
  test.main()
241