• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for SparseConcat."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import sparse_ops
29from tensorflow.python.platform import test
30
31
class SparseConcatTest(test.TestCase):
  """Tests for `sparse_ops.sparse_concat`."""

  def _MakeSparseTensor(self, ind, val, shape, val_dtype=dtypes.float32):
    """Builds a constant SparseTensor from array-likes.

    Args:
      ind: Array-like of indices; converted to an int64 constant.
      val: Array-like of values; converted to a `val_dtype` constant.
      shape: Array-like dense shape; converted to an int64 constant.
      val_dtype: DType of the values tensor (float32 by default).

    Returns:
      A `sparse_tensor.SparseTensor` backed by `constant_op.constant` tensors.
    """
    return sparse_tensor.SparseTensor(
        constant_op.constant(ind, dtypes.int64),
        constant_op.constant(val, val_dtype),
        constant_op.constant(shape, dtypes.int64))

  def _SparseTensor_UnknownShape(self,
                                 ind_shape=None,
                                 val_shape=None,
                                 shape_shape=None):
    """Returns a SparseTensor whose components are placeholders.

    Args:
      ind_shape: Static shape of the int64 `indices` placeholder, or None for
        an unknown shape.
      val_shape: Static shape of the float32 `values` placeholder, or None.
      shape_shape: Static shape of the int64 `dense_shape` placeholder, or
        None.
    """
    return sparse_tensor.SparseTensor(
        array_ops.placeholder(dtypes.int64, shape=ind_shape),
        array_ops.placeholder(dtypes.float32, shape=val_shape),
        array_ops.placeholder(dtypes.int64, shape=shape_shape))

  def _SparseTensorValue_3x3(self):
    """Returns a 3x3 float32 SparseTensorValue with four non-zeros."""
    # [    1]
    # [2    ]
    # [3   4]
    ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]], dtype=np.int64)
    val = np.array([1, 2, 3, 4], dtype=np.float32)
    shape = np.array([3, 3], dtype=np.int64)
    return sparse_tensor.SparseTensorValue(ind, val, shape)

  def _SparseTensor_3x3(self):
    """Returns `_SparseTensorValue_3x3` as a SparseTensor."""
    return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x3())

  def _SparseTensorValue_3x5(self):
    """Returns a 3x5 float32 SparseTensorValue with four stored entries."""
    # [         ]
    # [  1      ]
    # [2     1 0]
    ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]], dtype=np.int64)
    val = np.array([1, 2, 1, 0], dtype=np.float32)
    shape = np.array([3, 5], dtype=np.int64)
    return sparse_tensor.SparseTensorValue(ind, val, shape)

  def _SparseTensor_3x5(self):
    """Returns `_SparseTensorValue_3x5` as a SparseTensor."""
    return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x5())

  def _SparseTensor_3x2(self):
    """Returns a 3x2 float32 SparseTensor with two non-zeros."""
    # [   ]
    # [1  ]
    # [2  ]
    return self._MakeSparseTensor(
        ind=[[1, 0], [2, 0]], val=[1, 2], shape=[3, 2])

  def _SparseTensor_2x3(self):
    """Returns a 2x3 float32 SparseTensor with three non-zeros."""
    # [  1  ]
    # [1   2]
    return self._MakeSparseTensor(
        ind=[[0, 1], [1, 0], [1, 2]], val=[1, 1, 2], shape=[2, 3])

  def _SparseTensor_2x3x4(self):
    """Returns a rank-3 (2x3x4) float32 SparseTensor with seven non-zeros."""
    ind = [
        [0, 0, 1],
        [0, 1, 0], [0, 1, 2],
        [1, 0, 3],
        [1, 1, 1], [1, 1, 3],
        [1, 2, 2]]
    val = [1, 10, 12, 103, 111, 113, 122]
    return self._MakeSparseTensor(ind=ind, val=val, shape=[2, 3, 4])

  def _SparseTensor_NoNonZeros(self, dense_shape):
    """Returns an empty float32 SparseTensor with the given dense shape."""
    rank = len(dense_shape)
    # Build the (0, rank) index array as int64 up front rather than relying on
    # an implicit float64 -> int64 cast during constant conversion.
    ind = np.empty(shape=(0, rank), dtype=np.int64)
    val = np.empty(shape=(0,), dtype=np.float32)
    return self._MakeSparseTensor(ind=ind, val=val, shape=dense_shape)

  def _SparseTensor_String3x3(self):
    """Returns a 3x3 string SparseTensor with four non-empty entries."""
    # [    a]
    # [b    ]
    # [c   d]
    return self._MakeSparseTensor(
        ind=[[0, 2], [1, 0], [2, 0], [2, 2]],
        val=np.array(["a", "b", "c", "d"]),
        shape=[3, 3],
        val_dtype=dtypes.string)

  def _SparseTensor_String3x5(self):
    """Returns a 3x5 string SparseTensor with four non-empty entries."""
    # [         ]
    # [  e      ]
    # [f     g h]
    return self._MakeSparseTensor(
        ind=[[1, 1], [2, 0], [2, 3], [2, 4]],
        val=np.array(["e", "f", "g", "h"]),
        shape=[3, 5],
        val_dtype=dtypes.string)

  def testConcat1(self):
    """Concatenating a single input returns it unchanged."""
    with self.session(use_gpu=False):
      # concat(A):
      # [    1]
      # [2    ]
      # [3   4]
      for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
        # Note that we ignore concat_dim in this case since we short-circuit the
        # single-input case in python.
        for concat_dim in (-2000, 1, 2000):
          sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a])

          self.assertEqual(sp_concat.indices.get_shape(), [4, 2])
          self.assertEqual(sp_concat.values.get_shape(), [4])
          self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

          concat_out = self.evaluate(sp_concat)

          self.assertAllEqual(concat_out.indices,
                              [[0, 2], [1, 0], [2, 0], [2, 2]])
          self.assertAllEqual(concat_out.values, [1, 2, 3, 4])
          self.assertAllEqual(concat_out.dense_shape, [3, 3])

  def testConcat2(self):
    """Concatenates two inputs (values or tensors) along the last axis."""
    with self.session(use_gpu=False):
      # concat(A, B):
      # [    1          ]
      # [2       1      ]
      # [3   4 2     1 0]
      for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
        for sp_b in (self._SparseTensorValue_3x5(), self._SparseTensor_3x5()):
          for concat_dim in (-1, 1):
            sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b])

            self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
            self.assertEqual(sp_concat.values.get_shape(), [8])
            self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

            concat_out = self.evaluate(sp_concat)

            self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4],
                                                     [2, 0], [2, 2], [2, 3],
                                                     [2, 6], [2, 7]])
            self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0])
            self.assertAllEqual(concat_out.dense_shape, [3, 8])

  def testConcatDim0(self):
    """Concatenates two inputs along the first axis."""
    with self.session(use_gpu=False):
      # concat(A, D):
      # [    1]
      # [2    ]
      # [3   4]
      # [  1  ]
      # [1   2]
      sp_a = self._SparseTensor_3x3()
      sp_d = self._SparseTensor_2x3()

      for concat_dim in (-2, 0):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_d])

        self.assertEqual(sp_concat.indices.get_shape(), [7, 2])
        self.assertEqual(sp_concat.values.get_shape(), [7])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(
            concat_out.indices,
            [[0, 2], [1, 0], [2, 0], [2, 2], [3, 1], [4, 0], [4, 2]])
        self.assertAllEqual(concat_out.values, np.array([1, 2, 3, 4, 1, 1, 2]))
        self.assertAllEqual(concat_out.dense_shape, np.array([5, 3]))

  def testConcat3(self):
    """Concatenates three inputs along the last axis."""
    with self.session(use_gpu=False):
      # concat(A, B, C):
      # [    1              ]
      # [2       1       1  ]
      # [3   4 2     1 0 2  ]
      sp_a = self._SparseTensor_3x3()
      sp_b = self._SparseTensor_3x5()
      sp_c = self._SparseTensor_3x2()

      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

        self.assertEqual(sp_concat.indices.get_shape(), [10, 2])
        self.assertEqual(sp_concat.values.get_shape(), [10])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4], [1, 8],
                                                 [2, 0], [2, 2], [2, 3], [2, 6],
                                                 [2, 7], [2, 8]])
        self.assertAllEqual(concat_out.values, [1, 2, 1, 1, 3, 4, 2, 1, 0, 2])
        self.assertAllEqual(concat_out.dense_shape, [3, 10])

  def testConcatNoNonZeros(self):
    """Concatenating only empty inputs yields an empty result."""
    sp_a = self._SparseTensor_NoNonZeros((2, 3, 4))
    sp_b = self._SparseTensor_NoNonZeros((2, 7, 4))
    sp_c = self._SparseTensor_NoNonZeros((2, 5, 4))

    with self.session(use_gpu=False):
      concat_dim = 1
      sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

      self.assertEqual(sp_concat.indices.get_shape(), [0, 3])
      self.assertEqual(sp_concat.values.get_shape(), [0])
      self.assertEqual(sp_concat.dense_shape.get_shape(), [3])

      concat_out = self.evaluate(sp_concat)

      self.assertEqual(concat_out.indices.shape, (0, 3))
      self.assertEqual(concat_out.values.shape, (0,))
      self.assertAllEqual(concat_out.dense_shape, [2, 15, 4])

  def testConcatSomeNoNonZeros(self):
    """Empty inputs contribute only their extent along concat_dim."""
    sp_a = self._SparseTensor_NoNonZeros((2, 7, 4))
    sp_b = self._SparseTensor_2x3x4()
    sp_c = self._SparseTensor_NoNonZeros((2, 5, 4))
    output_nnz = sp_b.indices.get_shape()[0]

    with self.session(use_gpu=False):
      concat_dim = 1
      sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

      self.assertEqual(sp_concat.indices.get_shape(), [output_nnz, 3])
      self.assertEqual(sp_concat.values.get_shape(), [output_nnz])
      self.assertEqual(sp_concat.dense_shape.get_shape(), [3])

      concat_out = self.evaluate(sp_concat)

      # sp_a precedes sp_b, so sp_b's entries are offset by sp_a's extent along
      # concat_dim. Evaluate the expectation explicitly so both sides of the
      # comparison are numpy arrays in graph mode as well as eager mode.
      expected_indices, expected_values = self.evaluate(
          [sp_b.indices + [0, sp_a.dense_shape[1], 0], sp_b.values])

      self.assertAllEqual(concat_out.indices, expected_indices)
      self.assertAllEqual(concat_out.values, expected_values)
      self.assertAllEqual(concat_out.dense_shape, [2, 15, 4])

  def testConcatNonNumeric(self):
    """Concatenation works for string-valued sparse tensors."""
    with self.session(use_gpu=False):
      # concat(A, B):
      # [    a          ]
      # [b       e      ]
      # [c   d f     g h]
      sp_a = self._SparseTensor_String3x3()
      sp_b = self._SparseTensor_String3x5()

      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b])

        self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
        self.assertEqual(sp_concat.values.get_shape(), [8])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(
            concat_out.indices,
            [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]])
        self.assertAllEqual(concat_out.values,
                            [b"a", b"b", b"e", b"c", b"d", b"f", b"g", b"h"])
        self.assertAllEqual(concat_out.dense_shape, [3, 8])

  @test_util.run_deprecated_v1
  def testMismatchedRank(self):
    """Inputs of different rank are rejected at graph-construction time."""
    with self.session(use_gpu=False):
      sp_a = self._SparseTensor_3x3()
      sp_e = self._SparseTensor_2x3x4()

      # Rank mismatches can be caught at shape-inference time
      for concat_dim in (-1, 1):
        with self.assertRaises(ValueError):
          sparse_ops.sparse_concat(concat_dim, [sp_a, sp_e])

  @test_util.run_deprecated_v1
  def testMismatchedRankExpandNonconcatDim(self):
    """Rank mismatches are rejected even with expand_nonconcat_dim=True."""
    with self.session(use_gpu=False):
      sp_a = self._SparseTensor_3x3()
      sp_e = self._SparseTensor_2x3x4()

      # Rank mismatches should be caught at shape-inference time, even for
      # expand_nonconcat_dim=True.
      for concat_dim in (-1, 1):
        with self.assertRaises(ValueError):
          sparse_ops.sparse_concat(
              concat_dim, [sp_a, sp_e], expand_nonconcat_dim=True)

  @test_util.run_deprecated_v1
  def testMismatchedShapes(self):
    """Mismatched non-concat dims fail when the op runs."""
    with self.session(use_gpu=False):
      sp_a = self._SparseTensor_3x3()
      sp_b = self._SparseTensor_3x5()
      sp_c = self._SparseTensor_3x2()
      sp_d = self._SparseTensor_2x3()
      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim,
                                             [sp_a, sp_b, sp_c, sp_d])

        # Shape mismatches can only be caught when the op is run
        with self.assertRaisesOpError("Input shapes must match"):
          self.evaluate(sp_concat)

  def testMismatchedShapesExpandNonconcatDim(self):
    """Mismatched non-concat dims are accepted with expand_nonconcat_dim."""
    with self.session(use_gpu=False):
      sp_inputs = [
          self._SparseTensor_3x3(),
          self._SparseTensor_3x5(),
          self._SparseTensor_3x2(),
          self._SparseTensor_2x3(),
      ]

      # Concatenate along rows: 3 + 3 + 3 + 2 = 11 rows, and the column
      # dimension becomes the widest input's (5).
      for concat_dim in (-2, 0):
        sp_concat = sparse_ops.sparse_concat(
            concat_dim, sp_inputs, expand_nonconcat_dim=True)
        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(concat_out.indices,
                            [[0, 2], [1, 0], [2, 0], [2, 2], [4, 1], [5, 0],
                             [5, 3], [5, 4], [7, 0], [8, 0], [9, 1], [10, 0],
                             [10, 2]])
        self.assertAllEqual(concat_out.values,
                            [1, 2, 3, 4, 1, 2, 1, 0, 1, 2, 1, 1, 2])
        self.assertAllEqual(concat_out.dense_shape, [11, 5])

      # Concatenate along columns: 3 + 5 + 2 + 3 = 13 columns, and the row
      # dimension becomes the tallest input's (3).
      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(
            concat_dim, sp_inputs, expand_nonconcat_dim=True)
        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(concat_out.indices,
                            [[0, 2], [0, 11], [1, 0], [1, 4], [1, 8], [1, 10],
                             [1, 12], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7],
                             [2, 8]])
        self.assertAllEqual(concat_out.values,
                            [1, 1, 2, 1, 1, 1, 2, 3, 4, 2, 1, 0, 2])
        self.assertAllEqual(concat_out.dense_shape, [3, 13])

  @test_util.run_deprecated_v1
  def testShapeInferenceUnknownShapes(self):
    """Static output shapes are inferred from partially-known inputs."""
    with self.session(use_gpu=False):
      sp_inputs = [
          self._SparseTensor_UnknownShape(),
          self._SparseTensor_UnknownShape(val_shape=[3]),
          self._SparseTensor_UnknownShape(ind_shape=[1, 3]),
          self._SparseTensor_UnknownShape(shape_shape=[3])
      ]

      for concat_dim in (-2, 0):
        sp_concat = sparse_ops.sparse_concat(concat_dim, sp_inputs)

        self.assertEqual(sp_concat.indices.get_shape().as_list(), [None, 3])
        self.assertEqual(sp_concat.values.get_shape().as_list(), [None])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [3])

  def testConcatShape(self):
    """The result's static shape is known when input dense shapes are."""
    # Test case for GitHub 21964.
    x = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
    y = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
    z = sparse_ops.sparse_concat(-1, [x, y])
    self.assertEqual(z.get_shape().as_list(), [2, 4])
403
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  test.main()
406