• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf.ragged.cross and tf.ragged.cross_hashed."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22
23import numpy as np
24
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import errors
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import test_util
30from tensorflow.python.ops import sparse_ops
31from tensorflow.python.ops.ragged import ragged_array_ops
32from tensorflow.python.ops.ragged import ragged_factory_ops
33from tensorflow.python.ops.ragged import ragged_tensor
34from tensorflow.python.platform import googletest
35
# Shorthand constructors for test inputs: `ragged_const` is
# `ragged_factory_ops.constant_value` and `dense_const` is `np.array`.
ragged_const, dense_const = ragged_factory_ops.constant_value, np.array
38
39
def sparse_const(matrix):
  """Builds a `SparseTensorValue` whose element `[i, j]` is `matrix[i][j]`.

  Args:
    matrix: A list of rows; rows may have different lengths (including zero).

  Returns:
    A `sparse_tensor.SparseTensorValue` with dense shape
    `[len(matrix), longest row length]`, or `[0, 0]` when `matrix` is empty.
    When `matrix` holds no elements at all, `indices` and `values` are empty
    int64 numpy arrays of shapes `[0, 2]` and `[0]` respectively.
  """
  entries = [([row_id, col_id], item)
             for row_id, row in enumerate(matrix)
             for col_id, item in enumerate(row)]
  if entries:
    indices = [position for position, _ in entries]
    values = [item for _, item in entries]
  else:
    # Give empty inputs a concrete shape/dtype instead of bare empty lists.
    indices = np.zeros([0, 2], dtype=np.int64)
    values = np.zeros([0], dtype=np.int64)

  if matrix:
    dense_shape = [len(matrix), max(len(row) for row in matrix)]
  else:
    dense_shape = [0, 0]
  return sparse_tensor.SparseTensorValue(indices, values, dense_shape)
52
53
@test_util.run_all_in_graph_and_eager_modes
class RaggedCrossOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `ragged_array_ops.cross` and `ragged_array_ops.cross_hashed`.

  Inputs may be ragged values, `SparseTensorValue`s (built with
  `sparse_const`), numpy arrays, or nested Python lists.  Unless a test case
  sets `matches_sparse_cross=False`, results are also compared against
  `sparse_ops.sparse_cross` / `sparse_ops.sparse_cross_hashed` on the same
  inputs.
  """

  @parameterized.named_parameters([
      dict(
          testcase_name='NoInputs',
          inputs=[],
          expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
      dict(
          testcase_name='OneInput_RaggedStr',
          inputs=[ragged_const([['a', 'b'], [], ['c']])],
          expected=ragged_const([[b'a', b'b'], [], [b'c']])),
      dict(
          testcase_name='OneInput_RaggedInt',
          inputs=[ragged_const([[1, 2, 3], [4, 5]])],
          expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5']])),
      dict(
          testcase_name='OneInput_DenseInt',
          inputs=[dense_const([[1, 2, 3], [4, 5, 6]])],
          expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5', b'6']])),
      dict(
          testcase_name='OneInput_SparseStr',
          inputs=[sparse_const([['a', 'b'], [], ['c']])],
          expected=ragged_const([[b'a', b'b'], [], [b'c']])),
      dict(
          testcase_name='TwoInputs_RaggedStr_RaggedStr',
          inputs=[
              ragged_const([['a', 'b'], [], ['c']]),
              ragged_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
                                 [b'c_X_g']])),
      dict(
          testcase_name='TwoInputs_RaggedInt_RaggedInt',
          inputs=[
              ragged_const([[1, 2], [], [3]]),
              ragged_const([[4, 5, 6], [], [7]])
          ],
          expected=ragged_const(
              [[b'1_X_4', b'1_X_5', b'1_X_6', b'2_X_4', b'2_X_5', b'2_X_6'], [],
               [b'3_X_7']])),
      dict(
          testcase_name='TwoInputs_RaggedStr_RaggedInt',
          inputs=[
              ragged_const([['a', 'b'], [], ['c']]),
              # Integer values, so that this case really mixes str and int.
              ragged_const([[1, 2], [3], [4]])
          ],
          expected=ragged_const([[b'a_X_1', b'a_X_2', b'b_X_1', b'b_X_2'], [],
                                 [b'c_X_4']])),
      dict(
          testcase_name='TwoInputs_SparseStr_SparseStr',
          inputs=[
              sparse_const([['a', 'b'], [], ['c']]),
              sparse_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
                                 [b'c_X_g']])),
      dict(
          testcase_name='TwoInputs_DenseInt_DenseInt',
          inputs=[dense_const([[1, 2], [3, 4]]),
                  dense_const([[5, 6], [7, 8]])],
          expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
                                 [b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
      dict(
          testcase_name='TwoInputs_DenseInt_DenseStr',
          inputs=[
              dense_const([[1, 2], [3, 4]]),
              dense_const([[b'5', b'6'], [b'7', b'8']])
          ],
          expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
                                 [b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
      dict(
          testcase_name='TwoInputs_RaggedInt_DenseInt',
          inputs=[
              ragged_const([[], [], [1, 2], [3]]),
              dense_const([[1, 2], [3, 4], [5, 6], [7, 8]])
          ],
          expected=ragged_const([[], [],
                                 [b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
                                 [b'3_X_7', b'3_X_8']])),
      dict(
          # This test exercises `input_order`.
          testcase_name='TwoInputs_DenseInt_RaggedStr',
          inputs=[
              dense_const([[1, 2], [3, 4], [5, 6]]),
              ragged_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
                                 [b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
          matches_sparse_cross=False  # sparse doesn't preserve input order.
      ),
      dict(
          # This test exercises `input_order`.
          testcase_name='TwoInputs_SparseInt_RaggedStr',
          inputs=[
              sparse_const([[1, 2], [3, 4], [5, 6]]),
              ragged_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
                                 [b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
          matches_sparse_cross=False  # sparse doesn't preserve input order.
      ),
      dict(
          testcase_name='ThreeInputs_RaggedInt_RaggedInt_RaggedInt',
          inputs=[
              ragged_const([[11], [12, 13], [], [14, 15]]),
              ragged_const([[21, 22], [23], [24, 25], [26, 27]]),
              ragged_const([[31], [32, 33], [34, 35], [36, 37]])
          ],
          expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
                                 [
                                     b'12_X_23_X_32', b'12_X_23_X_33',
                                     b'13_X_23_X_32', b'13_X_23_X_33'
                                 ], [],
                                 [
                                     b'14_X_26_X_36', b'14_X_26_X_37',
                                     b'14_X_27_X_36', b'14_X_27_X_37',
                                     b'15_X_26_X_36', b'15_X_26_X_37',
                                     b'15_X_27_X_36', b'15_X_27_X_37'
                                 ]])),
      dict(
          testcase_name='ThreeInputs_RaggedInt_SparseInt_DenseInt',
          inputs=[
              ragged_const([[11], [12, 13], [], [14, 15]]),
              sparse_const([[21, 22], [23], [24, 25], [26, 27]]),
              dense_const([[31], [32], [33], [34]])
          ],
          expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
                                 [
                                     b'12_X_23_X_32',
                                     b'13_X_23_X_32',
                                 ], [],
                                 [
                                     b'14_X_26_X_34',
                                     b'14_X_27_X_34',
                                     b'15_X_26_X_34',
                                     b'15_X_27_X_34',
                                 ]])),
      dict(
          testcase_name='FiveInputs',
          inputs=[
              ragged_const([[1]]),
              dense_const([[2]]),
              ragged_const([[3]]),
              sparse_const([[4]]),
              ragged_const([[5]])
          ],
          expected=ragged_const([[b'1_X_2_X_3_X_4_X_5']]),
          matches_sparse_cross=False  # sparse doesn't preserve input order.
      ),
      dict(
          testcase_name='Permutation_3x3x3',
          inputs=[[['11', '12', '13']], [['21', '22', '23']],
                  [['31', '32', '33']]],
          expected=[[
              b'11_X_21_X_31', b'11_X_21_X_32', b'11_X_21_X_33',
              b'11_X_22_X_31', b'11_X_22_X_32', b'11_X_22_X_33',
              b'11_X_23_X_31', b'11_X_23_X_32', b'11_X_23_X_33',
              b'12_X_21_X_31', b'12_X_21_X_32', b'12_X_21_X_33',
              b'12_X_22_X_31', b'12_X_22_X_32', b'12_X_22_X_33',
              b'12_X_23_X_31', b'12_X_23_X_32', b'12_X_23_X_33',
              b'13_X_21_X_31', b'13_X_21_X_32', b'13_X_21_X_33',
              b'13_X_22_X_31', b'13_X_22_X_32', b'13_X_22_X_33',
              b'13_X_23_X_31', b'13_X_23_X_32', b'13_X_23_X_33'
          ]]),
      dict(
          testcase_name='BatchSizeZero',
          inputs=[
              ragged_const([], ragged_rank=1, dtype=dtypes.int32),
              sparse_const([]),
              np.zeros([0, 3], dtype=np.int32),
          ],
          expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
      dict(
          testcase_name='ThreeInputs_OneEmpty',
          inputs=[
              ragged_const([[1, 2]]),
              ragged_const([[]], dtype=dtypes.int32),
              ragged_const([[3, 4]])
          ],
          expected=ragged_const([[]], dtype=dtypes.string)),
      dict(
          testcase_name='ThreeInputs_AllEmpty',
          inputs=[
              ragged_const([[]], dtype=dtypes.int64),
              ragged_const([[]], dtype=dtypes.string),
              ragged_const([[]], dtype=dtypes.int32)
          ],
          expected=ragged_const([[]], ragged_rank=1, dtype=dtypes.string)),
      dict(
          testcase_name='HashedZeroBucketsDefaultKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          expected_hashed=ragged_const([[1971693436396284976]])),
      dict(
          testcase_name='Hashed100BucketsDefaultKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          num_buckets=100,
          expected_hashed=ragged_const([[83]])),
      dict(
          testcase_name='HashedZeroBucketsCustomKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
          expected_hashed=ragged_const([[4847552627144134031]])),
      dict(
          testcase_name='Hashed100BucketsCustomKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          num_buckets=100,
          hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
          expected_hashed=ragged_const([[31]])),
      dict(
          testcase_name='HashedZeroKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          hash_key=0,
          expected_hashed=ragged_const([[9077905385164735582]]),
          matches_sparse_cross=False  # sparse treats hash_key=0 as None.
      ),
      dict(
          testcase_name='UInt64',
          inputs=[ragged_const([[2**64 - 1]], dtype=dtypes.uint64)],
          expected=ragged_const([[b'-1']])),
  ])
  def testRaggedCross(self,
                      inputs,
                      num_buckets=0,
                      hash_key=None,
                      expected=None,
                      expected_hashed=None,
                      matches_sparse_cross=True):
    """Checks `cross` / `cross_hashed` against expectations and sparse ops.

    Args:
      inputs: List of ragged values, `SparseTensorValue`s, numpy arrays, or
        nested Python lists to cross.
      num_buckets: Forwarded to `cross_hashed` (and `sparse_cross_hashed`).
      hash_key: Forwarded to `cross_hashed` (and `sparse_cross_hashed`).
      expected: If not None, the expected value of `cross(inputs)`.
      expected_hashed: If not None, the expected value of
        `cross_hashed(inputs, num_buckets, hash_key)`.
      matches_sparse_cross: If true, also assert that both ragged results equal
        the corresponding sparse-op results converted with
        `RaggedTensor.from_sparse`.
    """
    ragged_cross = ragged_array_ops.cross(inputs)
    ragged_cross_hashed = ragged_array_ops.cross_hashed(inputs, num_buckets,
                                                        hash_key)

    if expected is not None:
      self.assertAllEqual(ragged_cross, expected)
    if expected_hashed is not None:
      self.assertAllEqual(ragged_cross_hashed, expected_hashed)

    if matches_sparse_cross:
      # Both sparse ops consume the same converted inputs; build them once.
      sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]

      # Check that ragged.cross & sparse.cross match.
      sparse_cross = sparse_ops.sparse_cross(sparse_inputs)
      self.assertAllEqual(ragged_cross,
                          ragged_tensor.RaggedTensor.from_sparse(sparse_cross))

      # Check that ragged.cross_hashed & sparse.cross_hashed match.
      sparse_cross_hashed = sparse_ops.sparse_cross_hashed(
          sparse_inputs, num_buckets, hash_key)
      self.assertAllEqual(
          ragged_cross_hashed,
          ragged_tensor.RaggedTensor.from_sparse(sparse_cross_hashed))

  def testRaggedCrossLargeBatch(self):
    """Crosses ragged, dense and sparse inputs over a 5000-row batch."""
    batch_size = 5000
    inputs = [
        ragged_const([[1, 2, 3]] * batch_size),
        ragged_const([[b'4']] * batch_size),
        dense_const([[5]] * batch_size),
        sparse_const([[6, 7]] * batch_size)
    ]

    expected = [[
        b'1_X_4_X_5_X_6', b'1_X_4_X_5_X_7', b'2_X_4_X_5_X_6', b'2_X_4_X_5_X_7',
        b'3_X_4_X_5_X_6', b'3_X_4_X_5_X_7'
    ]] * batch_size

    ragged_cross = ragged_array_ops.cross(inputs)

    # Note: we don't use assertAllEqual here because if they don't match,
    # then the code in assertAllEqual that tries to build the error message
    # is very slow, causing the test to timeout.
    # pylint: disable=g-generic-assert
    self.assertTrue(self.evaluate(ragged_cross).to_list() == expected)

  @parameterized.named_parameters([
      dict(
          testcase_name='BadDType',
          inputs=[ragged_const([[1.1], [2.2, 3.3]])],
          message=r'Unexpected dtype for inputs\[0\]'),
      dict(
          testcase_name='StaticBatchSizeMismatch1',
          inputs=[ragged_const([[1]]),
                  ragged_const([[2], [3]])],
          exception=(ValueError, errors.InvalidArgumentError),
          message='inputs must all have the same batch dimension size'),
      dict(
          testcase_name='StaticBatchSizeMismatch2',
          inputs=[ragged_const([[1]]),
                  dense_const([[2], [3]])],
          exception=(ValueError, errors.InvalidArgumentError),
          message='inputs must all have the same batch dimension size'),
  ])
  def testStaticError(self, inputs, exception=ValueError, message=None):
    """Asserts that building `cross(inputs)` raises `exception`.

    The error is expected while constructing the op, without evaluating it.
    `message` is a regex the error text must match.
    """
    with self.assertRaisesRegex(exception, message):
      ragged_array_ops.cross(inputs)

  @parameterized.named_parameters([
      dict(
          testcase_name='3DRaggedTensor',
          inputs=[ragged_const([[[1]]], ragged_rank=1)],
          message='tf.ragged.cross only supports inputs with rank=2'),
      dict(
          testcase_name='3DDenseTensor',
          inputs=[dense_const([[[1]]])],
          message='tf.ragged.cross only supports inputs with rank=2'),
  ])
  def testRuntimeError(self,
                       inputs,
                       exception=errors.InvalidArgumentError,
                       message=None):
    """Asserts that evaluating `cross(inputs)` raises `exception`.

    The op is built and then run with `self.evaluate`. `message` is a regex
    the error text must match.
    """
    with self.assertRaisesRegex(exception, message):
      self.evaluate(ragged_array_ops.cross(inputs))

  def _ragged_to_sparse(self, t):
    """Converts a test input into a form accepted by `sparse_ops.sparse_cross`.

    Despite the name, only ragged and sparse inputs become `SparseTensor`s:
    ragged inputs are converted via `.to_sparse()`, `SparseTensorValue`s via
    `SparseTensor.from_value`, and everything else (numpy arrays, Python
    lists) is returned as a dense `Tensor` from `ops.convert_to_tensor`.
    """
    if ragged_tensor.is_ragged(t):
      return ragged_tensor.convert_to_tensor_or_ragged_tensor(t).to_sparse()
    elif sparse_tensor.is_sparse(t):
      return sparse_tensor.SparseTensor.from_value(t)
    else:
      return ops.convert_to_tensor(t)
394
395
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's googletest entry point.
  googletest.main()
398