1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tf.ragged.cross and tf.ragged.cross_hashed.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from absl.testing import parameterized 22 23import numpy as np 24 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import errors 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import sparse_tensor 29from tensorflow.python.framework import test_util 30from tensorflow.python.ops import sparse_ops 31from tensorflow.python.ops.ragged import ragged_array_ops 32from tensorflow.python.ops.ragged import ragged_factory_ops 33from tensorflow.python.ops.ragged import ragged_tensor 34from tensorflow.python.platform import googletest 35 36ragged_const = ragged_factory_ops.constant_value 37dense_const = np.array 38 39 40def sparse_const(matrix): 41 indices = [] 42 values = [] 43 for i, row in enumerate(matrix): 44 for j, val in enumerate(row): 45 indices.append([i, j]) 46 values.append(val) 47 shape = [len(matrix), max(len(row) for row in matrix)] if matrix else [0, 0] 48 if not values: 49 indices = np.zeros([0, 2], dtype=np.int64) 50 values = np.zeros([0], dtype=np.int64) 51 return sparse_tensor.SparseTensorValue(indices, values, shape) 
@test_util.run_all_in_graph_and_eager_modes
class RaggedCrossOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `ragged_array_ops.cross` and `ragged_array_ops.cross_hashed`."""

  @parameterized.named_parameters([
      dict(
          testcase_name='NoInputs',
          inputs=[],
          expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
      dict(
          testcase_name='OneInput_RaggedStr',
          inputs=[ragged_const([['a', 'b'], [], ['c']])],
          expected=ragged_const([[b'a', b'b'], [], [b'c']])),
      dict(
          testcase_name='OneInput_RaggedInt',
          inputs=[ragged_const([[1, 2, 3], [4, 5]])],
          expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5']])),
      dict(
          testcase_name='OneInput_DenseInt',
          inputs=[dense_const([[1, 2, 3], [4, 5, 6]])],
          expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5', b'6']])),
      dict(
          testcase_name='OneInput_SparseStr',
          inputs=[sparse_const([['a', 'b'], [], ['c']])],
          expected=ragged_const([[b'a', b'b'], [], [b'c']])),
      dict(
          testcase_name='TwoInputs_RaggedStr_RaggedStr',
          inputs=[
              ragged_const([['a', 'b'], [], ['c']]),
              ragged_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
                                 [b'c_X_g']])),
      dict(
          testcase_name='TwoInputs_RaggedInt_RaggedInt',
          inputs=[
              ragged_const([[1, 2], [], [3]]),
              ragged_const([[4, 5, 6], [], [7]])
          ],
          expected=ragged_const(
              [[b'1_X_4', b'1_X_5', b'1_X_6', b'2_X_4', b'2_X_5', b'2_X_6'], [],
               [b'3_X_7']])),
      dict(
          testcase_name='TwoInputs_RaggedStr_RaggedInt',
          inputs=[
              ragged_const([['a', 'b'], [], ['c']]),
              # The second input holds real integers so that this case
              # crosses a string ragged input with an integer ragged input,
              # as its name says.
              ragged_const([[1, 2], [3], [4]])
          ],
          expected=ragged_const([[b'a_X_1', b'a_X_2', b'b_X_1', b'b_X_2'], [],
                                 [b'c_X_4']])),
      dict(
          testcase_name='TwoInputs_SparseStr_SparseStr',
          inputs=[
              sparse_const([['a', 'b'], [], ['c']]),
              sparse_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
                                 [b'c_X_g']])),
      dict(
          testcase_name='TwoInputs_DenseInt_DenseInt',
          inputs=[dense_const([[1, 2], [3, 4]]),
                  dense_const([[5, 6], [7, 8]])],
          expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
                                 [b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
      dict(
          testcase_name='TwoInputs_DenseInt_DenseStr',
          inputs=[
              dense_const([[1, 2], [3, 4]]),
              dense_const([[b'5', b'6'], [b'7', b'8']])
          ],
          expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
                                 [b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
      dict(
          testcase_name='TwoInputs_RaggedInt_DenseInt',
          inputs=[
              ragged_const([[], [], [1, 2], [3]]),
              dense_const([[1, 2], [3, 4], [5, 6], [7, 8]])
          ],
          expected=ragged_const([[], [],
                                 [b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
                                 [b'3_X_7', b'3_X_8']])),
      dict(
          # This test exercises `input_order`.
          testcase_name='TwoInputs_DenseInt_RaggedStr',
          inputs=[
              dense_const([[1, 2], [3, 4], [5, 6]]),
              ragged_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
                                 [b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
          matches_sparse_cross=False  # sparse doesn't preserve input order.
      ),
      dict(
          # This test exercises `input_order`.
          testcase_name='TwoInputs_SparseInt_RaggedStr',
          inputs=[
              sparse_const([[1, 2], [3, 4], [5, 6]]),
              ragged_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
                                 [b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
          matches_sparse_cross=False  # sparse doesn't preserve input order.
      ),
      dict(
          testcase_name='ThreeInputs_RaggedInt_RaggedInt_RaggedInt',
          inputs=[
              ragged_const([[11], [12, 13], [], [14, 15]]),
              ragged_const([[21, 22], [23], [24, 25], [26, 27]]),
              ragged_const([[31], [32, 33], [34, 35], [36, 37]])
          ],
          expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
                                 [
                                     b'12_X_23_X_32', b'12_X_23_X_33',
                                     b'13_X_23_X_32', b'13_X_23_X_33'
                                 ], [],
                                 [
                                     b'14_X_26_X_36', b'14_X_26_X_37',
                                     b'14_X_27_X_36', b'14_X_27_X_37',
                                     b'15_X_26_X_36', b'15_X_26_X_37',
                                     b'15_X_27_X_36', b'15_X_27_X_37'
                                 ]])),
      dict(
          testcase_name='ThreeInputs_RaggedInt_SparseInt_DenseInt',
          inputs=[
              ragged_const([[11], [12, 13], [], [14, 15]]),
              sparse_const([[21, 22], [23], [24, 25], [26, 27]]),
              dense_const([[31], [32], [33], [34]])
          ],
          expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
                                 [
                                     b'12_X_23_X_32',
                                     b'13_X_23_X_32',
                                 ], [],
                                 [
                                     b'14_X_26_X_34',
                                     b'14_X_27_X_34',
                                     b'15_X_26_X_34',
                                     b'15_X_27_X_34',
                                 ]])),
      dict(
          testcase_name='FiveInputs',
          inputs=[
              ragged_const([[1]]),
              dense_const([[2]]),
              ragged_const([[3]]),
              sparse_const([[4]]),
              ragged_const([[5]])
          ],
          expected=ragged_const([[b'1_X_2_X_3_X_4_X_5']]),
          matches_sparse_cross=False  # sparse doesn't preserve input order.
      ),
      dict(
          testcase_name='Permutation_3x3x3',
          inputs=[[['11', '12', '13']], [['21', '22', '23']],
                  [['31', '32', '33']]],
          expected=[[
              b'11_X_21_X_31', b'11_X_21_X_32', b'11_X_21_X_33',
              b'11_X_22_X_31', b'11_X_22_X_32', b'11_X_22_X_33',
              b'11_X_23_X_31', b'11_X_23_X_32', b'11_X_23_X_33',
              b'12_X_21_X_31', b'12_X_21_X_32', b'12_X_21_X_33',
              b'12_X_22_X_31', b'12_X_22_X_32', b'12_X_22_X_33',
              b'12_X_23_X_31', b'12_X_23_X_32', b'12_X_23_X_33',
              b'13_X_21_X_31', b'13_X_21_X_32', b'13_X_21_X_33',
              b'13_X_22_X_31', b'13_X_22_X_32', b'13_X_22_X_33',
              b'13_X_23_X_31', b'13_X_23_X_32', b'13_X_23_X_33'
          ]]),
      dict(
          testcase_name='BatchSizeZero',
          inputs=[
              ragged_const([], ragged_rank=1, dtype=dtypes.int32),
              sparse_const([]),
              np.zeros([0, 3], dtype=np.int32),
          ],
          expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
      dict(
          testcase_name='ThreeInputs_OneEmpty',
          inputs=[
              ragged_const([[1, 2]]),
              ragged_const([[]], dtype=dtypes.int32),
              ragged_const([[3, 4]])
          ],
          expected=ragged_const([[]], dtype=dtypes.string)),
      dict(
          testcase_name='ThreeInputs_AllEmpty',
          inputs=[
              ragged_const([[]], dtype=dtypes.int64),
              ragged_const([[]], dtype=dtypes.string),
              ragged_const([[]], dtype=dtypes.int32)
          ],
          expected=ragged_const([[]], ragged_rank=1, dtype=dtypes.string)),
      dict(
          testcase_name='HashedZeroBucketsDefaultKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          expected_hashed=ragged_const([[1971693436396284976]])),
      dict(
          testcase_name='Hashed100BucketsDefaultKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          num_buckets=100,
          expected_hashed=ragged_const([[83]])),
      dict(
          testcase_name='HashedZeroBucketsCustomKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
          expected_hashed=ragged_const([[4847552627144134031]])),
      dict(
          testcase_name='Hashed100BucketsCustomKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          num_buckets=100,
          hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
          expected_hashed=ragged_const([[31]])),
      dict(
          testcase_name='HashedZeroKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          hash_key=0,
          expected_hashed=ragged_const([[9077905385164735582]]),
          matches_sparse_cross=False  # sparse treats hash_key=0 as None.
      ),
      dict(
          testcase_name='UInt64',
          inputs=[ragged_const([[2**64 - 1]], dtype=dtypes.uint64)],
          expected=ragged_const([[b'-1']])),
  ])
  def testRaggedCross(self,
                      inputs,
                      num_buckets=0,
                      hash_key=None,
                      expected=None,
                      expected_hashed=None,
                      matches_sparse_cross=True):
    """Checks `cross` / `cross_hashed` outputs for one set of inputs.

    Args:
      inputs: List of ragged, sparse, dense or plain-list inputs to cross.
      num_buckets: `num_buckets` argument forwarded to the hashed cross ops.
      hash_key: `hash_key` argument forwarded to the hashed cross ops.
      expected: Expected result of `cross`, or None to skip that check.
      expected_hashed: Expected result of `cross_hashed`, or None to skip
        that check.
      matches_sparse_cross: If true, additionally check that the ragged
        results equal the results of `sparse_ops.sparse_cross` and
        `sparse_ops.sparse_cross_hashed` on the same inputs.
    """
    ragged_cross = ragged_array_ops.cross(inputs)
    ragged_cross_hashed = ragged_array_ops.cross_hashed(inputs, num_buckets,
                                                        hash_key)

    if expected is not None:
      self.assertAllEqual(ragged_cross, expected)
    if expected_hashed is not None:
      self.assertAllEqual(ragged_cross_hashed, expected_hashed)

    if matches_sparse_cross:
      # The converted inputs are shared by both comparisons below.
      sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]

      # Check that ragged.cross & sparse.cross match.
      sparse_cross = sparse_ops.sparse_cross(sparse_inputs)
      self.assertAllEqual(ragged_cross,
                          ragged_tensor.RaggedTensor.from_sparse(sparse_cross))

      # Check that ragged.cross_hashed & sparse.cross_hashed match.
      sparse_cross_hashed = sparse_ops.sparse_cross_hashed(
          sparse_inputs, num_buckets, hash_key)
      self.assertAllEqual(
          ragged_cross_hashed,
          ragged_tensor.RaggedTensor.from_sparse(sparse_cross_hashed))

  def testRaggedCrossLargeBatch(self):
    """Crosses ragged, dense and sparse inputs with a batch size of 5000."""
    batch_size = 5000
    inputs = [
        ragged_const([[1, 2, 3]] * batch_size),
        ragged_const([[b'4']] * batch_size),
        dense_const([[5]] * batch_size),
        sparse_const([[6, 7]] * batch_size)
    ]

    expected = [[
        b'1_X_4_X_5_X_6', b'1_X_4_X_5_X_7', b'2_X_4_X_5_X_6', b'2_X_4_X_5_X_7',
        b'3_X_4_X_5_X_6', b'3_X_4_X_5_X_7'
    ]] * batch_size

    ragged_cross = ragged_array_ops.cross(inputs)

    # Note: we don't use assertAllEqual here because if they don't match,
    # then the code in assertAllEqual that tries to build the error message
    # is very slow, causing the test to timeout.
    # pylint: disable=g-generic-assert
    self.assertTrue(self.evaluate(ragged_cross).to_list() == expected)

  @parameterized.named_parameters([
      dict(
          testcase_name='BadDType',
          inputs=[ragged_const([[1.1], [2.2, 3.3]])],
          message=r'Unexpected dtype for inputs\[0\]'),
      dict(
          testcase_name='StaticBatchSizeMismatch1',
          inputs=[ragged_const([[1]]),
                  ragged_const([[2], [3]])],
          exception=(ValueError, errors.InvalidArgumentError),
          message='inputs must all have the same batch dimension size'),
      dict(
          testcase_name='StaticBatchSizeMismatch2',
          inputs=[ragged_const([[1]]),
                  dense_const([[2], [3]])],
          exception=(ValueError, errors.InvalidArgumentError),
          message='inputs must all have the same batch dimension size'),
  ])
  def testStaticError(self, inputs, exception=ValueError, message=None):
    """Checks that calling `cross` on `inputs` raises the given error.

    Args:
      inputs: Inputs passed to `ragged_array_ops.cross`.
      exception: Exception type (or tuple of types) expected to be raised.
      message: Regular expression the error message must match.
    """
    with self.assertRaisesRegex(exception, message):
      ragged_array_ops.cross(inputs)

  @parameterized.named_parameters([
      dict(
          testcase_name='3DRaggedTensor',
          inputs=[ragged_const([[[1]]], ragged_rank=1)],
          message='tf.ragged.cross only supports inputs with rank=2'),
      dict(
          testcase_name='3DDenseTensor',
          inputs=[dense_const([[[1]]])],
          message='tf.ragged.cross only supports inputs with rank=2'),
  ])
  def testRuntimeError(self,
                       inputs,
                       exception=errors.InvalidArgumentError,
                       message=None):
    """Checks that evaluating `cross` on `inputs` raises the given error.

    Args:
      inputs: Inputs passed to `ragged_array_ops.cross`.
      exception: Exception type expected to be raised.
      message: Regular expression the error message must match.
    """
    with self.assertRaisesRegex(exception, message):
      self.evaluate(ragged_array_ops.cross(inputs))

  def _ragged_to_sparse(self, t):
    """Converts a test input into a form accepted by the sparse cross ops.

    Args:
      t: A ragged value, a sparse value, or anything accepted by
        `ops.convert_to_tensor`.

    Returns:
      A `SparseTensor` if `t` is ragged or sparse; otherwise the result of
      `ops.convert_to_tensor(t)`.
    """
    if ragged_tensor.is_ragged(t):
      return ragged_tensor.convert_to_tensor_or_ragged_tensor(t).to_sparse()
    elif sparse_tensor.is_sparse(t):
      return sparse_tensor.SparseTensor.from_value(t)
    else:
      return ops.convert_to_tensor(t)


if __name__ == '__main__':
  googletest.main()