1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for SparseConcat.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import sparse_tensor 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import sparse_ops 29from tensorflow.python.platform import test 30 31 32class SparseConcatTest(test.TestCase): 33 34 def _SparseTensor_UnknownShape(self, 35 ind_shape=None, 36 val_shape=None, 37 shape_shape=None): 38 return sparse_tensor.SparseTensor( 39 array_ops.placeholder( 40 dtypes.int64, shape=ind_shape), 41 array_ops.placeholder( 42 dtypes.float32, shape=val_shape), 43 array_ops.placeholder( 44 dtypes.int64, shape=shape_shape)) 45 46 def _SparseTensorValue_3x3(self): 47 # [ 1] 48 # [2 ] 49 # [3 4] 50 ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]]) 51 val = np.array([1, 2, 3, 4]) 52 shape = np.array([3, 3]) 53 return sparse_tensor.SparseTensorValue( 54 np.array(ind, np.int64), 55 np.array(val, np.float32), np.array(shape, np.int64)) 56 57 def _SparseTensor_3x3(self): 58 return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x3()) 59 60 def _SparseTensorValue_3x5(self): 61 # [ ] 62 # [ 1 ] 63 # [2 1 0] 64 ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]]) 65 val = np.array([1, 2, 1, 0]) 66 shape = np.array([3, 5]) 67 return sparse_tensor.SparseTensorValue( 68 np.array(ind, np.int64), 69 np.array(val, np.float32), np.array(shape, np.int64)) 70 71 def _SparseTensor_3x5(self): 72 return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x5()) 73 74 def _SparseTensor_3x2(self): 75 # [ ] 76 # [1 ] 77 # [2 ] 78 ind = np.array([[1, 0], [2, 0]]) 79 val = np.array([1, 2]) 80 shape = np.array([3, 2]) 81 return sparse_tensor.SparseTensor( 82 constant_op.constant(ind, dtypes.int64), 83 constant_op.constant(val, dtypes.float32), 84 constant_op.constant(shape, dtypes.int64)) 85 86 def _SparseTensor_2x3(self): 87 # [ 1 ] 88 # [1 2] 89 ind = np.array([[0, 1], [1, 0], [1, 2]]) 90 val = np.array([1, 1, 2]) 91 shape = np.array([2, 3]) 92 return sparse_tensor.SparseTensor( 93 constant_op.constant(ind, dtypes.int64), 94 constant_op.constant(val, dtypes.float32), 95 constant_op.constant(shape, dtypes.int64)) 96 97 def _SparseTensor_2x3x4(self): 98 ind = np.array([ 99 [0, 0, 1], 100 [0, 1, 0], [0, 1, 2], 101 [1, 0, 3], 102 [1, 1, 1], [1, 1, 3], 103 [1, 2, 2]]) 104 val = np.array([1, 10, 12, 103, 111, 113, 122]) 105 shape = np.array([2, 3, 4]) 106 return sparse_tensor.SparseTensor( 107 constant_op.constant(ind, dtypes.int64), 108 constant_op.constant(val, dtypes.float32), 109 constant_op.constant(shape, dtypes.int64)) 110 111 def _SparseTensor_NoNonZeros(self, dense_shape): 112 ind = np.empty(shape=(0, len(dense_shape))) 113 val = np.array([]) 114 shape = np.array(dense_shape) 115 return sparse_tensor.SparseTensor( 116 constant_op.constant(ind, dtypes.int64), 117 constant_op.constant(val, dtypes.float32), 118 constant_op.constant(shape, dtypes.int64)) 119 120 def _SparseTensor_String3x3(self): 121 # [ a] 122 # [b ] 123 # [c d] 124 ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]]) 125 val = np.array(["a", "b", "c", "d"]) 126 shape = np.array([3, 3]) 127 return sparse_tensor.SparseTensor( 128 constant_op.constant(ind, dtypes.int64), 129 constant_op.constant(val, dtypes.string), 130 constant_op.constant(shape, dtypes.int64)) 131 132 def _SparseTensor_String3x5(self): 133 # [ ] 134 # [ e ] 135 # [f g h] 136 ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]]) 137 val = np.array(["e", "f", "g", "h"]) 138 shape = np.array([3, 5]) 139 return sparse_tensor.SparseTensor( 140 constant_op.constant(ind, dtypes.int64), 141 constant_op.constant(val, dtypes.string), 142 constant_op.constant(shape, dtypes.int64)) 143 144 def testConcat1(self): 145 with self.session(use_gpu=False) as sess: 146 # concat(A): 147 # [ 1] 148 # [2 ] 149 # [3 4] 150 for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): 151 # Note that we ignore concat_dim in this case since we short-circuit the 152 # single-input case in python. 153 for concat_dim in (-2000, 1, 2000): 154 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a]) 155 156 self.assertEqual(sp_concat.indices.get_shape(), [4, 2]) 157 self.assertEqual(sp_concat.values.get_shape(), [4]) 158 self.assertEqual(sp_concat.dense_shape.get_shape(), [2]) 159 160 concat_out = self.evaluate(sp_concat) 161 162 self.assertAllEqual(concat_out.indices, 163 [[0, 2], [1, 0], [2, 0], [2, 2]]) 164 self.assertAllEqual(concat_out.values, [1, 2, 3, 4]) 165 self.assertAllEqual(concat_out.dense_shape, [3, 3]) 166 167 def testConcat2(self): 168 with self.session(use_gpu=False) as sess: 169 # concat(A, B): 170 # [ 1 ] 171 # [2 1 ] 172 # [3 4 2 1 0] 173 for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()): 174 for sp_b in (self._SparseTensorValue_3x5(), self._SparseTensor_3x5()): 175 for concat_dim in (-1, 1): 176 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b]) 177 178 self.assertEqual(sp_concat.indices.get_shape(), [8, 2]) 179 self.assertEqual(sp_concat.values.get_shape(), [8]) 180 self.assertEqual(sp_concat.dense_shape.get_shape(), [2]) 181 182 concat_out = self.evaluate(sp_concat) 183 184 self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4], 185 [2, 0], [2, 2], [2, 3], 186 [2, 6], [2, 7]]) 187 self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0]) 188 self.assertAllEqual(concat_out.dense_shape, [3, 8]) 189 190 def testConcatDim0(self): 191 with self.session(use_gpu=False) as sess: 192 # concat(A, D): 193 # [ 1] 194 # [2 ] 195 # [3 4] 196 # [ 1 ] 197 # [1 2] 198 sp_a = self._SparseTensor_3x3() 199 sp_d = self._SparseTensor_2x3() 200 201 for concat_dim in (-2, 0): 202 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_d]) 203 204 self.assertEqual(sp_concat.indices.get_shape(), [7, 2]) 205 self.assertEqual(sp_concat.values.get_shape(), [7]) 206 self.assertEqual(sp_concat.dense_shape.get_shape(), [2]) 207 208 concat_out = self.evaluate(sp_concat) 209 210 self.assertAllEqual( 211 concat_out.indices, 212 [[0, 2], [1, 0], [2, 0], [2, 2], [3, 1], [4, 0], [4, 2]]) 213 self.assertAllEqual(concat_out.values, np.array([1, 2, 3, 4, 1, 1, 2])) 214 self.assertAllEqual(concat_out.dense_shape, np.array([5, 3])) 215 216 def testConcat3(self): 217 with self.session(use_gpu=False) as sess: 218 # concat(A, B, C): 219 # [ 1 ] 220 # [2 1 1 ] 221 # [3 4 2 1 0 2 ] 222 sp_a = self._SparseTensor_3x3() 223 sp_b = self._SparseTensor_3x5() 224 sp_c = self._SparseTensor_3x2() 225 226 for concat_dim in (-1, 1): 227 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c]) 228 229 self.assertEqual(sp_concat.indices.get_shape(), [10, 2]) 230 self.assertEqual(sp_concat.values.get_shape(), [10]) 231 self.assertEqual(sp_concat.dense_shape.get_shape(), [2]) 232 233 concat_out = self.evaluate(sp_concat) 234 235 self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4], [1, 8], 236 [2, 0], [2, 2], [2, 3], [2, 6], 237 [2, 7], [2, 8]]) 238 self.assertAllEqual(concat_out.values, [1, 2, 1, 1, 3, 4, 2, 1, 0, 2]) 239 self.assertAllEqual(concat_out.dense_shape, [3, 10]) 240 241 def testConcatNoNonZeros(self): 242 sp_a = self._SparseTensor_NoNonZeros((2, 3, 4)) 243 sp_b = self._SparseTensor_NoNonZeros((2, 7, 4)) 244 sp_c = self._SparseTensor_NoNonZeros((2, 5, 4)) 245 246 with self.session(use_gpu=False) as sess: 247 concat_dim = 1 248 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c]) 249 250 self.assertEqual(sp_concat.indices.get_shape(), [0, 3]) 251 self.assertEqual(sp_concat.values.get_shape(), [0]) 252 self.assertEqual(sp_concat.dense_shape.get_shape(), [3]) 253 254 concat_out = self.evaluate(sp_concat) 255 256 self.assertEqual(concat_out.indices.shape, (0, 3)) 257 self.assertEqual(concat_out.values.shape, (0,)) 258 self.assertAllEqual(concat_out.dense_shape, [2, 15, 4]) 259 260 def testConcatSomeNoNonZeros(self): 261 sp_a = self._SparseTensor_NoNonZeros((2, 7, 4)) 262 sp_b = self._SparseTensor_2x3x4() 263 sp_c = self._SparseTensor_NoNonZeros((2, 5, 4)) 264 output_nnz = sp_b.indices.get_shape()[0] 265 266 with self.session(use_gpu=False) as sess: 267 concat_dim = 1 268 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c]) 269 270 self.assertEqual(sp_concat.indices.get_shape(), [output_nnz, 3]) 271 self.assertEqual(sp_concat.values.get_shape(), [output_nnz]) 272 self.assertEqual(sp_concat.dense_shape.get_shape(), [3]) 273 274 concat_out = self.evaluate(sp_concat) 275 276 self.assertAllEqual(concat_out.indices, 277 sp_b.indices + [0, sp_a.dense_shape[1], 0]) 278 self.assertAllEqual(concat_out.values, sp_b.values) 279 self.assertAllEqual(concat_out.dense_shape, [2, 15, 4]) 280 281 def testConcatNonNumeric(self): 282 with self.session(use_gpu=False) as sess: 283 # concat(A, B): 284 # [ a ] 285 # [b e ] 286 # [c d f g h] 287 sp_a = self._SparseTensor_String3x3() 288 sp_b = self._SparseTensor_String3x5() 289 290 for concat_dim in (-1, 1): 291 sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b]) 292 293 self.assertEqual(sp_concat.indices.get_shape(), [8, 2]) 294 self.assertEqual(sp_concat.values.get_shape(), [8]) 295 self.assertEqual(sp_concat.dense_shape.get_shape(), [2]) 296 297 concat_out = self.evaluate(sp_concat) 298 299 self.assertAllEqual( 300 concat_out.indices, 301 [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]]) 302 self.assertAllEqual(concat_out.values, 303 [b"a", b"b", b"e", b"c", b"d", b"f", b"g", b"h"]) 304 self.assertAllEqual(concat_out.dense_shape, [3, 8]) 305 306 @test_util.run_deprecated_v1 307 def testMismatchedRank(self): 308 with self.session(use_gpu=False): 309 sp_a = self._SparseTensor_3x3() 310 sp_e = self._SparseTensor_2x3x4() 311 312 # Rank mismatches can be caught at shape-inference time 313 for concat_dim in (-1, 1): 314 with self.assertRaises(ValueError): 315 sparse_ops.sparse_concat(concat_dim, [sp_a, sp_e]) 316 317 @test_util.run_deprecated_v1 318 def testMismatchedRankExpandNonconcatDim(self): 319 with self.session(use_gpu=False): 320 sp_a = self._SparseTensor_3x3() 321 sp_e = self._SparseTensor_2x3x4() 322 323 # Rank mismatches should be caught at shape-inference time, even for 324 # expand_nonconcat_dim=True. 325 for concat_dim in (-1, 1): 326 with self.assertRaises(ValueError): 327 sparse_ops.sparse_concat( 328 concat_dim, [sp_a, sp_e], expand_nonconcat_dim=True) 329 330 @test_util.run_deprecated_v1 331 def testMismatchedShapes(self): 332 with self.session(use_gpu=False) as sess: 333 sp_a = self._SparseTensor_3x3() 334 sp_b = self._SparseTensor_3x5() 335 sp_c = self._SparseTensor_3x2() 336 sp_d = self._SparseTensor_2x3() 337 for concat_dim in (-1, 1): 338 sp_concat = sparse_ops.sparse_concat(concat_dim, 339 [sp_a, sp_b, sp_c, sp_d]) 340 341 # Shape mismatches can only be caught when the op is run 342 with self.assertRaisesOpError("Input shapes must match"): 343 self.evaluate(sp_concat) 344 345 def testMismatchedShapesExpandNonconcatDim(self): 346 with self.session(use_gpu=False) as sess: 347 sp_a = self._SparseTensor_3x3() 348 sp_b = self._SparseTensor_3x5() 349 sp_c = self._SparseTensor_3x2() 350 sp_d = self._SparseTensor_2x3() 351 for concat_dim0 in (-2, 0): 352 for concat_dim1 in (-1, 1): 353 sp_concat_dim0 = sparse_ops.sparse_concat( 354 concat_dim0, [sp_a, sp_b, sp_c, sp_d], expand_nonconcat_dim=True) 355 sp_concat_dim1 = sparse_ops.sparse_concat( 356 concat_dim1, [sp_a, sp_b, sp_c, sp_d], expand_nonconcat_dim=True) 357 358 sp_concat_dim0_out = self.evaluate(sp_concat_dim0) 359 sp_concat_dim1_out = self.evaluate(sp_concat_dim1) 360 361 self.assertAllEqual(sp_concat_dim0_out.indices, 362 [[0, 2], [1, 0], [2, 0], [2, 2], [4, 1], [5, 0], 363 [5, 3], [5, 4], [7, 0], [8, 0], [9, 1], [10, 0], 364 [10, 2]]) 365 self.assertAllEqual(sp_concat_dim0_out.values, 366 [1, 2, 3, 4, 1, 2, 1, 0, 1, 2, 1, 1, 2]) 367 self.assertAllEqual(sp_concat_dim0_out.dense_shape, [11, 5]) 368 369 self.assertAllEqual(sp_concat_dim1_out.indices, 370 [[0, 2], [0, 11], [1, 0], [1, 4], [1, 8], [1, 10], 371 [1, 12], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7], 372 [2, 8]]) 373 self.assertAllEqual(sp_concat_dim1_out.values, 374 [1, 1, 2, 1, 1, 1, 2, 3, 4, 2, 1, 0, 2]) 375 self.assertAllEqual(sp_concat_dim1_out.dense_shape, [3, 13]) 376 377 @test_util.run_deprecated_v1 378 def testShapeInferenceUnknownShapes(self): 379 with self.session(use_gpu=False): 380 sp_inputs = [ 381 self._SparseTensor_UnknownShape(), 382 self._SparseTensor_UnknownShape(val_shape=[3]), 383 self._SparseTensor_UnknownShape(ind_shape=[1, 3]), 384 self._SparseTensor_UnknownShape(shape_shape=[3]) 385 ] 386 387 for concat_dim in (-2, 0): 388 sp_concat = sparse_ops.sparse_concat(concat_dim, sp_inputs) 389 390 self.assertEqual(sp_concat.indices.get_shape().as_list(), [None, 3]) 391 self.assertEqual(sp_concat.values.get_shape().as_list(), [None]) 392 self.assertEqual(sp_concat.dense_shape.get_shape(), [3]) 393 394 def testConcatShape(self): 395 # Test case for GitHub 21964. 396 x = sparse_tensor.SparseTensor( 397 indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2]) 398 y = sparse_tensor.SparseTensor( 399 indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2]) 400 z = sparse_ops.sparse_concat(-1, [x, y]) 401 self.assertEqual(z.get_shape().as_list(), [2, 4]) 402 403 404if __name__ == "__main__": 405 test.main() 406