1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for utilities working with arbitrarily nested structures.""" 16 17import functools 18 19from absl.testing import parameterized 20 21from tensorflow.python.data.kernel_tests import test_base 22from tensorflow.python.data.util import nest 23from tensorflow.python.data.util import sparse 24from tensorflow.python.framework import combinations 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import sparse_tensor 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.platform import test 31 32 33# NOTE(vikoth18): Arguments of parameterized tests are lifted into lambdas to make 34# sure they are not executed before the (eager- or graph-mode) test environment 35# has been set up. 
#


def _named_combinations(cases, arg_names):
  """Turns a list of named test cases into parameterized-test combinations.

  Args:
    cases: A sequence of tuples `(name, value_1, ..., value_k)`.
    arg_names: The `k` keyword-argument names the test method receives, in the
      same order as the values inside each case.

  Returns:
    A list of combinations (as produced by `combinations.combine`), one per
    case. Callable values are wrapped in `combinations.NamedObject` named
    `"<arg_name>.<case name>"` so generated test names stay readable and
    stable; non-callable values (e.g. expected booleans) are passed as-is.

  Raises:
    ValueError: If a case does not provide exactly one value per arg name.
  """

  def reduce_fn(acc, case):
    name, values = case[0], case[1:]
    if len(values) != len(arg_names):
      raise ValueError("Test case {} has {} values but {} are expected.".format(
          name, len(values), len(arg_names)))
    kwargs = {}
    for arg_name, value in zip(arg_names, values):
      if callable(value):
        value = combinations.NamedObject("{}.{}".format(arg_name, name), value)
      kwargs[arg_name] = value
    return acc + combinations.combine(**kwargs)

  return functools.reduce(reduce_fn, cases, [])


# NOTE: Single-element tuples below carry an explicit trailing comma. Without
# it `(x)` is just `x`, and the "nested" case would silently duplicate the
# un-nested one instead of exercising nesting.


def _sparse_vector():
  """Returns a 1-D `SparseTensor` with a single entry (built lazily)."""
  return sparse_tensor.SparseTensor(indices=[[0]], values=[1], dense_shape=[1])


def _sparse_matrix():
  """Returns a 1x1 `SparseTensor` holding the value 1 at index [0, 0]."""
  return sparse_tensor.SparseTensor(
      indices=[[0, 0]], values=[1], dense_shape=[1, 1])


def _test_any_sparse_combinations():
  """Cases for `sparse.any_sparse`: (name, classes_fn, expected)."""
  cases = [
      ("TestCase_0", lambda: (), False),
      ("TestCase_1", lambda: ops.Tensor, False),
      ("TestCase_2", lambda: (((ops.Tensor,),),), False),
      ("TestCase_3", lambda: (ops.Tensor, ops.Tensor), False),
      ("TestCase_4", lambda: (ops.Tensor, sparse_tensor.SparseTensor), True),
      ("TestCase_5",
       lambda: (sparse_tensor.SparseTensor, sparse_tensor.SparseTensor), True),
      ("TestCase_6", lambda: (((sparse_tensor.SparseTensor,),),), True),
      # Sparse component first, so the result is checked to be independent
      # of the position of the sparse component.
      ("TestCase_7", lambda: (sparse_tensor.SparseTensor, ops.Tensor), True),
  ]
  return _named_combinations(cases, ("classes_fn", "expected"))


def _test_as_dense_shapes_combinations():
  """Cases for `sparse.as_dense_shapes`: (name, types, classes, expected)."""
  cases = [
      ("TestCase_0", lambda: (), lambda: (), lambda: ()),
      ("TestCase_1", lambda: tensor_shape.TensorShape([]),
       lambda: ops.Tensor, lambda: tensor_shape.TensorShape([])),
      ("TestCase_2", lambda: tensor_shape.TensorShape([]),
       lambda: sparse_tensor.SparseTensor,
       lambda: tensor_shape.unknown_shape()),  # pylint: disable=unnecessary-lambda
      ("TestCase_3", lambda: (tensor_shape.TensorShape([]),),
       lambda: (ops.Tensor,), lambda: (tensor_shape.TensorShape([]),)),
      ("TestCase_4", lambda: (tensor_shape.TensorShape([]),),
       lambda: (sparse_tensor.SparseTensor,),
       lambda: (tensor_shape.unknown_shape(),)),
      ("TestCase_5", lambda: (tensor_shape.TensorShape([]), ()),
       lambda: (ops.Tensor, ()), lambda: (tensor_shape.TensorShape([]), ())),
      ("TestCase_6", lambda: ((), tensor_shape.TensorShape([])),
       lambda: ((), ops.Tensor), lambda: ((), tensor_shape.TensorShape([]))),
      ("TestCase_7", lambda: (tensor_shape.TensorShape([]), ()),
       lambda: (sparse_tensor.SparseTensor, ()),
       lambda: (tensor_shape.unknown_shape(), ())),
      ("TestCase_8", lambda: ((), tensor_shape.TensorShape([])),
       lambda: ((), sparse_tensor.SparseTensor),
       lambda: ((), tensor_shape.unknown_shape())),
      ("TestCase_9",
       lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([])),
       lambda: (ops.Tensor, (), ops.Tensor),
       lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([]))),
      ("TestCase_10",
       lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([])),
       lambda: (sparse_tensor.SparseTensor, (), sparse_tensor.SparseTensor),
       lambda: (tensor_shape.unknown_shape(), (),
                tensor_shape.unknown_shape())),
      ("TestCase_11", lambda: ((), tensor_shape.TensorShape([]), ()),
       lambda: ((), ops.Tensor, ()),
       lambda: ((), tensor_shape.TensorShape([]), ())),
      ("TestCase_12", lambda: ((), tensor_shape.TensorShape([]), ()),
       lambda: ((), sparse_tensor.SparseTensor, ()),
       lambda: ((), tensor_shape.unknown_shape(), ())),
  ]
  return _named_combinations(cases, ("types_fn", "classes_fn", "expected_fn"))


def _test_as_dense_types_combinations():
  """Cases for `sparse.as_dense_types`: (name, types, classes, expected)."""
  cases = [
      ("TestCase_0", lambda: (), lambda: (), lambda: ()),
      ("TestCase_1", lambda: dtypes.int32, lambda: ops.Tensor,
       lambda: dtypes.int32),
      ("TestCase_2", lambda: dtypes.int32, lambda: sparse_tensor.SparseTensor,
       lambda: dtypes.variant),
      ("TestCase_3", lambda: (dtypes.int32,), lambda: (ops.Tensor,),
       lambda: (dtypes.int32,)),
      ("TestCase_4", lambda: (dtypes.int32,),
       lambda: (sparse_tensor.SparseTensor,), lambda: (dtypes.variant,)),
      ("TestCase_5", lambda: (dtypes.int32, ()), lambda: (ops.Tensor, ()),
       lambda: (dtypes.int32, ())),
      ("TestCase_6", lambda: ((), dtypes.int32), lambda: ((), ops.Tensor),
       lambda: ((), dtypes.int32)),
      ("TestCase_7", lambda: (dtypes.int32, ()),
       lambda: (sparse_tensor.SparseTensor, ()), lambda: (dtypes.variant, ())),
      ("TestCase_8", lambda: ((), dtypes.int32),
       lambda: ((), sparse_tensor.SparseTensor), lambda: ((), dtypes.variant)),
      ("TestCase_9", lambda: (dtypes.int32, (), dtypes.int32),
       lambda: (ops.Tensor, (), ops.Tensor),
       lambda: (dtypes.int32, (), dtypes.int32)),
      ("TestCase_10", lambda: (dtypes.int32, (), dtypes.int32),
       lambda: (sparse_tensor.SparseTensor, (), sparse_tensor.SparseTensor),
       lambda: (dtypes.variant, (), dtypes.variant)),
      ("TestCase_11", lambda: ((), dtypes.int32, ()),
       lambda: ((), ops.Tensor, ()), lambda: ((), dtypes.int32, ())),
      ("TestCase_12", lambda: ((), dtypes.int32, ()),
       lambda: ((), sparse_tensor.SparseTensor, ()),
       lambda: ((), dtypes.variant, ())),
  ]
  return _named_combinations(cases, ("types_fn", "classes_fn", "expected_fn"))


def _test_get_classes_combinations():
  """Cases for `sparse.get_classes`: (name, tensors, expected classes)."""
  cases = [
      ("TestCase_0", lambda: (), lambda: ()),
      ("TestCase_1", _sparse_vector, lambda: sparse_tensor.SparseTensor),
      ("TestCase_2", lambda: constant_op.constant([1]), lambda: ops.Tensor),
      ("TestCase_3", lambda: (_sparse_vector(),),
       lambda: (sparse_tensor.SparseTensor,)),
      ("TestCase_4", lambda: (constant_op.constant([1]),),
       lambda: (ops.Tensor,)),
      ("TestCase_5", lambda: (_sparse_vector(), ()),
       lambda: (sparse_tensor.SparseTensor, ())),
      ("TestCase_6", lambda: ((), _sparse_vector()),
       lambda: ((), sparse_tensor.SparseTensor)),
      ("TestCase_7", lambda: (constant_op.constant([1]), ()),
       lambda: (ops.Tensor, ())),
      ("TestCase_8", lambda: ((), constant_op.constant([1])),
       lambda: ((), ops.Tensor)),
      ("TestCase_9",
       lambda: (_sparse_vector(), (), constant_op.constant([1])),
       lambda: (sparse_tensor.SparseTensor, (), ops.Tensor)),
      ("TestCase_10", lambda: ((), _sparse_vector(), ()),
       lambda: ((), sparse_tensor.SparseTensor, ())),
      ("TestCase_11", lambda: ((), constant_op.constant([1]), ()),
       lambda: ((), ops.Tensor, ())),
  ]
  return _named_combinations(cases, ("classes_fn", "expected_fn"))


def _serialization_cases():
  """Returns the `(name, input_fn)` cases shared by the serialization tests."""
  return [
      ("TestCase_0", lambda: ()),
      ("TestCase_1", _sparse_matrix),
      ("TestCase_2", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5])),
      ("TestCase_3", lambda: sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5])),
      ("TestCase_4", lambda: (_sparse_matrix(),)),
      ("TestCase_5", lambda: (_sparse_matrix(), ())),
      ("TestCase_6", lambda: ((), _sparse_matrix())),
  ]


def _test_serialize_deserialize_combinations():
  """Cases for the `serialize_sparse_tensors` round-trip test."""
  return _named_combinations(_serialization_cases(), ("input_fn",))


def _test_serialize_many_deserialize_combinations():
  """Cases for the `serialize_many_sparse_tensors` round-trip test."""
  return _named_combinations(_serialization_cases(), ("input_fn",))


class SparseTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for the `tf.data` sparse-tensor structure utilities."""

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_any_sparse_combinations()))
  def testAnySparse(self, classes_fn, expected):
    classes = classes_fn()
    self.assertEqual(sparse.any_sparse(classes), expected)

  def assertShapesEqual(self, a, b):
    """Asserts `a` and `b` are identically-structured, equal nests of shapes.

    Two shapes of unknown rank compare equal; otherwise ranks and all
    dimensions must match.
    """
    # Without this check `zip` would silently ignore extra or missing leaves.
    nest.assert_same_structure(a, b)
    for shape_a, shape_b in zip(nest.flatten(a), nest.flatten(b)):
      self.assertEqual(shape_a.ndims, shape_b.ndims)
      if shape_a.ndims is not None:
        self.assertEqual(shape_a.as_list(), shape_b.as_list())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_as_dense_shapes_combinations()))
  def testAsDenseShapes(self, types_fn, classes_fn, expected_fn):
    types = types_fn()
    classes = classes_fn()
    expected = expected_fn()
    self.assertShapesEqual(sparse.as_dense_shapes(types, classes), expected)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_as_dense_types_combinations()))
  def testAsDenseTypes(self, types_fn, classes_fn, expected_fn):
    types = types_fn()
    classes = classes_fn()
    expected = expected_fn()
    self.assertEqual(sparse.as_dense_types(types, classes), expected)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_get_classes_combinations()))
  def testGetClasses(self, classes_fn, expected_fn):
    classes = classes_fn()
    expected = expected_fn()
    self.assertEqual(sparse.get_classes(classes), expected)

  def assertSparseValuesEqual(self, a, b):
    """Asserts `a` equals `b`; SparseTensors are evaluated and compared."""
    if not isinstance(a, sparse_tensor.SparseTensor):
      self.assertNotIsInstance(b, sparse_tensor.SparseTensor)
      self.assertEqual(a, b)
      return
    self.assertIsInstance(b, sparse_tensor.SparseTensor)
    with self.cached_session():
      # Evaluate each tensor once, through the same API.
      a_value = self.evaluate(a)
      b_value = self.evaluate(b)
    self.assertAllEqual(a_value.indices, b_value.indices)
    self.assertAllEqual(a_value.values, b_value.values)
    self.assertAllEqual(a_value.dense_shape, b_value.dense_shape)

  def _assertRoundTrip(self, test_case, serialize_fn):
    """Checks `deserialize_sparse_tensors(serialize_fn(x))` reproduces `x`.

    Args:
      test_case: A (possibly nested) structure of `SparseTensor`s.
      serialize_fn: The `sparse` serialization function under test.
    """
    classes = sparse.get_classes(test_case)
    shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                classes)
    types = nest.map_structure(lambda _: dtypes.int32, classes)
    actual = sparse.deserialize_sparse_tensors(
        serialize_fn(test_case), types, shapes, classes)
    nest.assert_same_structure(test_case, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(test_case)):
      self.assertSparseValuesEqual(a, e)

  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         _test_serialize_deserialize_combinations()))
  def testSerializeDeserialize(self, input_fn):
    self._assertRoundTrip(input_fn(), sparse.serialize_sparse_tensors)

  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         _test_serialize_many_deserialize_combinations()))
  def testSerializeManyDeserialize(self, input_fn):
    self._assertRoundTrip(input_fn(), sparse.serialize_many_sparse_tensors)


if __name__ == "__main__":
  test.main()