# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RaggedTensor operator dispatch."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_dispatch
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest
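
# For reference, a sketch of what elementwise dispatch does for the ops
# exercised below (illustrative only; the real dispatchers are registered
# in ragged_dispatch):
#
#   rt = ragged_factory_ops.constant([[1, -2], [3]])
#   abs_rt = math_ops.abs(rt)  # Dispatched to the ragged implementation.
#   # abs_rt is a RaggedTensor with the same nested_row_splits as `rt`, and
#   # flat_values equal to math_ops.abs(rt.flat_values).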

# Constants listing various op types to test. Each operation
# should be included in at least one list below, or tested separately if
# necessary (e.g., because it expects additional arguments).
UNARY_FLOAT_OPS = [
    math_ops.abs,
    math_ops.acos,
    math_ops.acosh,
    math_ops.angle,
    math_ops.asin,
    math_ops.asinh,
    math_ops.atan,
    math_ops.atanh,
    math_ops.ceil,
    math_ops.conj,
    math_ops.cos,
    math_ops.cosh,
    math_ops.digamma,
    math_ops.erf,
    math_ops.erfc,
    math_ops.erfinv,
    math_ops.exp,
    math_ops.expm1,
    math_ops.floor,
    math_ops.imag,
    math_ops.is_finite,
    math_ops.is_inf,
    math_ops.is_nan,
    math_ops.lgamma,
    math_ops.log,
    math_ops.log1p,
    math_ops.log_sigmoid,
    math_ops.ndtri,
    math_ops.negative,
    math_ops.real,
    math_ops.reciprocal,
    math_ops.rint,
    math_ops.round,
    math_ops.rsqrt,
    math_ops.sign,
    math_ops.sin,
    math_ops.sinh,
    math_ops.sqrt,
    math_ops.square,
    math_ops.tan,
    array_ops.identity,
    array_ops.ones_like,
    array_ops.zeros_like,
]
UNARY_BOOL_OPS = [
    math_ops.logical_not,
]
UNARY_STRING_OPS = [
    string_ops.decode_base64,
    string_ops.encode_base64,
    string_ops.string_strip,
    parsing_ops.decode_compressed,
]
BINARY_FLOAT_OPS = [
    math_ops.add,
    math_ops.atan2,
    math_ops.complex,
    math_ops.div_no_nan,
    math_ops.divide,
    math_ops.equal,
    math_ops.floordiv,
    math_ops.floormod,
    math_ops.greater,
    math_ops.greater_equal,
    math_ops.less,
    math_ops.less_equal,
    math_ops.maximum,
    math_ops.minimum,
    math_ops.multiply,
    math_ops.not_equal,
    math_ops.pow,
    math_ops.realdiv,
    math_ops.squared_difference,
    math_ops.subtract,
    math_ops.truediv,
]
BINARY_BOOL_OPS = [
    math_ops.logical_and,
    math_ops.logical_or,
    math_ops.logical_xor,
]
UNARY_INT_OPS = [
    gen_bitwise_ops.invert,
    string_ops.unicode_script,
]
BINARY_INT_OPS = [
    gen_bitwise_ops.bitwise_and,
    gen_bitwise_ops.bitwise_or,
    gen_bitwise_ops.bitwise_xor,
    gen_bitwise_ops.left_shift,
    gen_bitwise_ops.right_shift,
    math_ops.truncatediv,
    math_ops.truncatemod,
]


@test_util.run_all_in_graph_and_eager_modes
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):

  def assertSameShape(self, x, y):
    """Checks that x and y have the same shape (including ragged shapes)."""
    if isinstance(x, ragged_tensor.RaggedTensor):
      self.assertIsInstance(y, ragged_tensor.RaggedTensor)
      self.assertEqual(x.ragged_rank, y.ragged_rank)
      for (x_splits, y_splits) in zip(x.nested_row_splits, y.nested_row_splits):
        self.assertAllEqual(x_splits, y_splits)
      self.assertAllEqual(
          array_ops.shape(x.flat_values), array_ops.shape(y.flat_values))
    else:
      self.assertIsInstance(y, ops.Tensor)
      self.assertAllEqual(array_ops.shape(x), array_ops.shape(y))
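
  # For reference: a RaggedTensor's shape is encoded by its nested_row_splits
  # together with the shape of its flat_values, which is exactly what
  # assertSameShape compares. E.g. (illustrative):
  #
  #   rt = ragged_factory_ops.constant([[1, 2], [3]])
  #   rt.row_splits   # -> [0, 2, 3]
  #   rt.flat_values  # -> [1, 2, 3]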

  @parameterized.parameters(
      #=========================================================================
      # Test different input shapes.
      #=========================================================================
      [
          # 0-dimensional input
          {'x': 12},
          # 1-dimensional input
          {'x': [1, -2, 3]},
          # 2-dimensional input
          {'x': [[-2, 3], [-3, 4]]},
          {'x': ragged_factory_ops.constant_value(
              [[-2, 3], [-3]], ragged_rank=1)},
          # 3-dimensional inputs
          {'x': [[[-2, 3], [3, 4]], [[7, 6], [5, 4]]]},
          {'x': ragged_factory_ops.constant_value(
              [[[-2, 3], [3, 4]], [[7, 6]]],
              ragged_rank=1)},
          {'x': ragged_factory_ops.constant_value(
              [[[-2, 3, 4], []], [[7, 6]], []],
              ragged_rank=2)},
      ] +
      #=========================================================================
      # Test each unary op.
      #=========================================================================
      [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]), 'op': op}
       for op in UNARY_FLOAT_OPS] +
      [{'x': ragged_factory_ops.constant_value([[True, False], [True]]),
        'op': op}
       for op in UNARY_BOOL_OPS] +
      [{'x': ragged_factory_ops.constant_value([[18, 512], [12412]], np.int32),
        'op': op}
       for op in UNARY_INT_OPS] +
      [{'x': ragged_factory_ops.constant_value([['abcd', 'efgh'],
                                                ['aabbccdd']]),
        'op': op}
       for op in UNARY_STRING_OPS] +
      [
          {'op': clip_ops.clip_by_value,
           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
           'clip_value_min': 0.1, 'clip_value_max': 4.0},
          {'op': math_ops.cast,
           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
           'dtype': dtypes.int32},
          {'op': math_ops.saturate_cast,
           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
           'dtype': dtypes.int32},
          {'op': string_ops.string_to_hash_bucket,
           'x': ragged_factory_ops.constant_value(
               [['abcd', 'efgh'], ['aabbccdd']]),
           'num_buckets': 1000},
          {'op': string_ops.string_to_hash_bucket_fast,
           'x': ragged_factory_ops.constant_value(
               [['abcd', 'efgh'], ['aabbccdd']]),
           'num_buckets': 1000},
          {'op': string_ops.string_to_hash_bucket_strong,
           'x': ragged_factory_ops.constant_value(
               [['abcd', 'efgh'], ['aabbccdd']]),
           'num_buckets': 1000,
           'key': [1231, 12512]},
          {'op': string_ops.string_to_number,
           'x': ragged_factory_ops.constant_value([['-2.0', '3.0'], ['-3.0']])},
          {'op': string_ops.regex_full_match,
           'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
           'pattern': r'\w+'},
          {'op': string_ops.regex_replace,
           'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
           'pattern': r'\d',
           'rewrite': '#'},
          {'op': string_ops.substr,
           'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
           'pos': 2, 'len': 3},
          {'op': array_ops.check_numerics,
           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
           'message': 'check-numerics'},
      ]
  )  # pyformat: disable
  def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
    result = op(x, **extra_args)

    # Run the wrapped op on the dense values, for comparison.
    dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
    expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])

    # Check that the result has the expected shape.
    self.assertSameShape(x, result)

    # Check that the result has the expected (flattened) values.
    if isinstance(result, ragged_tensor.RaggedTensor):
      result_flat_values = array_ops.reshape(result.flat_values, [-1])
    else:
      result_flat_values = array_ops.reshape(result, [-1])
    self.assertAllEqual(expected_flat_values, result_flat_values)
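
  # For reference, a sketch of the broadcasting behavior exercised below
  # (illustrative): `x` is broadcast to the shape of `y` in each case, which
  # is why the test asserts assertSameShape(y, result).
  #
  #   x = ragged_factory_ops.constant([[1, 2], [3]])
  #   math_ops.add(x, 10)  # -> [[11, 12], [13]]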

  @parameterized.parameters(
      [
          #=====================================================================
          # Without broadcasting -- i.e., shapes match exactly.
          #=====================================================================
          # Shapes: x:(), y:()
          {'x': 12,
           'y': 8},
          # Shapes: x:(3,), y:(3,)
          {'x': [7, 8, 9],
           'y': [1, -2, 3]},
          # Shapes: x:(2, 2), y:(2, 2)
          {'x': [[-2, 3], [-3, -4]],
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(2, None), y:(2, None)
          {'x': ragged_factory_ops.constant_value([[-2, 3], [-3]]),
           'y': ragged_factory_ops.constant_value([[5, 6], [7]])},
          # Shapes: x:(2, 2, 2), y:(2, 2, 2)
          {'x': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
           'y': [[[9, 3], [3, 4]], [[5, 2], [7, 6]]]},
          # Shapes: x:(2, None, None), y: (2, None, None)
          {'x': ragged_factory_ops.constant_value(
              [[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
           'y': ragged_factory_ops.constant_value(
               [[[3, 8], [2], [5]], [[], [1, 9, 8]]])},
          # Shapes: x:(2, None, 2), y: (2, None, 2)
          {'x': ragged_factory_ops.constant_value(
              [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
              ragged_rank=1),
           'y': ragged_factory_ops.constant_value(
               [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
               ragged_rank=1)},

          #=====================================================================
          # With broadcasting
          #=====================================================================
          # Shapes: x:(), y:(3,)
          {'x': 12,  # Broadcast () -> (3,)
           'y': [1, -2, 3]},
          # Shapes: x:(1,), y:(3,)
          {'x': [12],  # Broadcast (1,) -> (3,)
           'y': [1, -2, 3]},
          # Shapes: x:(), y:(2, 2)
          {'x': 12,  # Broadcast () -> (2, 2)
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(1,), y:(2, 2)
          {'x': [12],  # Broadcast (1,) -> (2, 2)
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(2, 1), y:(2, 2)
          {'x': [[10], [20]],  # Broadcast (2, 1) -> (2, 2)
           'y': [[1, 2], [3, 4]]},
          # Shapes: x:(), y:(2, None)
          {'x': 10,  # Broadcast () -> (2, None)
           'y': ragged_factory_ops.constant_value(
               [[1, 2], [3]], dtype=np.int32)},
          # TODO(edloper): Add tests for more advanced broadcasting, once we
          # add support for it.

          #=====================================================================
          # Keyword Args
          #=====================================================================
          {'x': ragged_factory_ops.constant_value(
              [[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
           'y': ragged_factory_ops.constant_value(
               [[[3, 8], [2], [5]], [[], [1, 9, 8]]]),
           'use_kwargs': ('x', 'y')},
          {'x': ragged_factory_ops.constant_value(
              [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
              ragged_rank=1),
           'y': ragged_factory_ops.constant_value(
               [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
               ragged_rank=1),
           'use_kwargs': ('x', 'y')},
          {'x': ragged_factory_ops.constant_value(
              [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
              ragged_rank=1),
           'y': ragged_factory_ops.constant_value(
               [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
               ragged_rank=1),
           'use_kwargs': ('x',)},
      ] +
      #=========================================================================
      # Test each binary op.
      #=========================================================================
      [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
        'y': ragged_factory_ops.constant_value([[5.0, 1.0], [12.0]]),
        'op': op}
       for op in BINARY_FLOAT_OPS] +
      [{'x': ragged_factory_ops.constant_value([[-2, 3], [-3]]),
        'y': ragged_factory_ops.constant_value([[5, 1], [12]]),
        'op': op}
       for op in BINARY_INT_OPS] +
      [{'x': ragged_factory_ops.constant_value([[True, True], [False]]),
        'y': ragged_factory_ops.constant_value([[False, True], [False]]),
        'op': op}
       for op in BINARY_BOOL_OPS]
  )  # pyformat: disable
  def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
    use_kwargs = extra_args.pop('use_kwargs', ())
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y)
    if 'x' in use_kwargs and 'y' in use_kwargs:
      result = op(x=x, y=y, **extra_args)
    elif 'y' in use_kwargs:
      result = op(x, y=y, **extra_args)
    else:
      result = op(x, y, **extra_args)

    # Run the wrapped op on the dense values, for comparison.
    dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
    dense_y = y.flat_values if isinstance(y, ragged_tensor.RaggedTensor) else y
    expected_flat_values = array_ops.reshape(
        op(dense_x, dense_y, **extra_args), [-1])

    # Check that the result has the expected shape.
    self.assertSameShape(y, result)

    # Check that the result has the expected (flattened) values.
    if isinstance(result, ragged_tensor.RaggedTensor):
      result_flat_values = array_ops.reshape(result.flat_values, [-1])
    else:
      result_flat_values = array_ops.reshape(result, [-1])
    self.assertAllEqual(expected_flat_values, result_flat_values)
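
  # For reference (illustrative sketch): list-valued elementwise ops such as
  # math_ops.add_n combine ragged inputs with matching nested_row_splits:
  #
  #   a = ragged_factory_ops.constant([[1, 2], [3]])
  #   b = ragged_factory_ops.constant([[10, 20], [30]])
  #   math_ops.add_n([a, b])  # -> [[11, 22], [33]]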

  @parameterized.parameters(
      [
          {'inputs': (12, 8, 3)},
          {'inputs': ([1, 2, 3], [7, 8, 9], [3, 6, 9])},
          {'inputs': ([[1, 2]], [[3, 4]], [[5, 6]])},
          {'inputs': (ragged_factory_ops.constant_value([[1, 3], [-3]]),
                      ragged_factory_ops.constant_value([[4, 7], [88]]),
                      ragged_factory_ops.constant_value([[2, 9], [12]]))},
          {'inputs': (ragged_factory_ops.constant_value(
              [[[1, 3], [-3]], [[1]]]),
                      ragged_factory_ops.constant_value(
                          [[[4, 7], [88]], [[2]]]),
                      ragged_factory_ops.constant_value(
                          [[[2, 9], [12]], [[8]]]))},
          {'inputs': (
              ragged_factory_ops.constant_value([[[1, 3], [3, 4]], [[1, 5]]],
                                                ragged_rank=1),
              ragged_factory_ops.constant_value([[[4, 7], [1, 2]], [[2, 2]]],
                                                ragged_rank=1),
              ragged_factory_ops.constant_value([[[2, 9], [5, 2]], [[8, 0]]],
                                                ragged_rank=1))},
          {'inputs': (
              ragged_factory_ops.constant_value([[[1, 3], [-3]], [[1]]]),
              ragged_factory_ops.constant_value([[[4, 7], [88]], [[2]]]),
              ragged_factory_ops.constant_value([[[2, 9], [12]], [[8]]])),
           'use_kwargs': True},
      ] + [
          {'op': math_ops.add_n,
           'inputs': (ragged_factory_ops.constant_value([[1, 3], [-3]]),
                      ragged_factory_ops.constant_value([[4, 7], [88]]),
                      ragged_factory_ops.constant_value([[2, 9], [12]]))},
          {'op': string_ops.string_join,
           'inputs': (
               ragged_factory_ops.constant_value([['a', 'b'], ['c']]),
               ragged_factory_ops.constant_value([['foo', 'bar'], ['baz']]),
               ragged_factory_ops.constant_value([['2', '9'], ['12']]))},
      ])  # pyformat: disable
  def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
                                  **extra_args):
    use_kwargs = extra_args.pop('use_kwargs', False)
    inputs = [
        ragged_tensor.convert_to_tensor_or_ragged_tensor(x) for x in inputs
    ]
    if use_kwargs:
      result = op(inputs=inputs, **extra_args)
    else:
      result = op(inputs, **extra_args)

    # Run the wrapped op on the dense values, for comparison.
    dense_inputs = [
        x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
        for x in inputs
    ]
    expected_flat_values = array_ops.reshape(
        op(dense_inputs, **extra_args), [-1])

    # Check that the result has the expected shape.
    self.assertSameShape(inputs[0], result)

    # Check that the result has the expected (flattened) values.
    if isinstance(result, ragged_tensor.RaggedTensor):
      result_flat_values = array_ops.reshape(result.flat_values, [-1])
    else:
      result_flat_values = array_ops.reshape(result, [-1])
    self.assertAllEqual(expected_flat_values, result_flat_values)

  def testElementwiseOpUnknownRankError(self):
    if context.executing_eagerly():
      return
    x = ragged_factory_ops.constant([[1, 2], [3]])
    y = ragged_tensor.RaggedTensor.from_row_splits(
        array_ops.placeholder_with_default([1, 2, 3], shape=None), x.row_splits)
    with self.assertRaisesRegexp(ValueError,
                                 r'Unable to broadcast: unknown rank'):
      math_ops.add(x, y)

  @parameterized.parameters([
      dict(
          x=ragged_factory_ops.constant_value([[1, 2], [3]]),
          y=[[10]],
          expected=[[11, 12], [13]]),
      dict(
          x=ragged_factory_ops.constant_value([[[1, 2], [3, 4]], [[5]]],
                                              ragged_rank=2),
          y=ragged_factory_ops.constant_value([[[10], [20]], [[30]]],
                                              ragged_rank=1),
          expected=[[[11, 12], [23, 24]], [[35]]]),
      dict(
          x=ragged_factory_ops.constant_value([[[1]]]),
          y=ragged_factory_ops.constant_value([[1]]),
          expected=[[[2]]]),
  ])
  def testElementwiseOpBroadcast(self, x, y, expected):
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
    result = x + y
    self.assertAllEqual(result, expected)

  def testElementwiseOpShapeMismatch(self):
    x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
    y = ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(math_ops.add(x, y))

  def testBinaryOpSparseAndRagged(self):
    x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
    y = sparse_tensor.SparseTensor([[0, 0], [0, 1], [2, 0]], [1, 2, 3], [3, 2])
    with self.assertRaises((TypeError, ValueError)):
      self.evaluate(math_ops.add(x, y))

    with self.assertRaises((TypeError, ValueError)):
      self.evaluate(math_ops.add_n([x, y]))
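
  # For reference (illustrative): dispatch is not limited to elementwise ops;
  # e.g. array_ops.gather selects whole ragged rows:
  #
  #   params = ragged_factory_ops.constant([[1, 2], [3]])
  #   array_ops.gather(params, [1, 0, 1])  # -> [[3], [1, 2], [3]]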

  @parameterized.parameters([
      dict(
          op=array_ops.batch_gather,
          args=(ragged_factory_ops.constant_value([[5, 6, 7], [8, 9]]),
                ragged_factory_ops.constant_value([[2, 1, 0], [1]])),
          expected=ragged_factory_ops.constant_value([[7, 6, 5], [9]])),
      dict(
          op=array_ops.concat,
          args=([
              ragged_factory_ops.constant_value([[1, 2, 3], [4]],
                                                dtype=np.int32),
              np.array([[5, 6]], dtype=np.int32)
          ],),
          kwargs={'axis': 0},
          expected=ragged_factory_ops.constant_value([[1, 2, 3], [4], [5, 6]])),
      dict(
          op=array_ops.expand_dims,
          kwargs={
              'input': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'axis': 0
          },
          expected=ragged_factory_ops.constant_value([[[1, 2], [3]]])),
      dict(
          op=array_ops.expand_dims_v2,
          kwargs={
              'input': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'axis': -1
          },
          expected=ragged_factory_ops.constant_value([[[1], [2]], [[3]]],
                                                     ragged_rank=1),
      ),
      dict(
          op=array_ops.gather,
          kwargs={
              'params': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'indices': [1, 0, 1]
          },
          expected=ragged_factory_ops.constant_value([[3], [1, 2], [3]])),
      dict(
          op=array_ops.gather_v2,
          kwargs={
              'params': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'indices': ragged_factory_ops.constant_value([[1, 0], [1]])
          },
          expected=ragged_factory_ops.constant_value([[[3], [1, 2]], [[3]]])),
      dict(
          op=array_ops.gather_nd,
          kwargs={
              'params': ragged_factory_ops.constant_value([[7, 8], [9]]),
              'indices': [[0, 1], [1, 0], [0, 0]]
          },
          expected=ragged_factory_ops.constant_value([8, 9, 7])),
      dict(
          op=array_ops.one_hot,
          kwargs={
              'indices':
                  ragged_factory_ops.constant_value([[1, 2, 3], [0]],
                                                    dtype=np.int32),
              'depth':
                  4,
              'axis':
                  1
          },
          expected=ragged_factory_ops.constant_value(
              [[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], [[1, 0, 0, 0]]],
              ragged_rank=1)),
      dict(
          op=array_ops.stack,
          args=([
              ragged_factory_ops.constant_value([[1, 2, 3], [4]],
                                                dtype=np.int32),
              np.array([[5, 6]], dtype=np.int32)
          ],),
          expected=ragged_factory_ops.constant_value([[[1, 2, 3], [4]],
                                                      [[5, 6]]])),
      dict(
          op=array_ops.tile,
          args=([
              ragged_factory_ops.constant_value([[1, 2], [3]], dtype=np.int32),
              [2, 3]
          ]),
          expected=ragged_factory_ops.constant_value([[1, 2, 1, 2, 1, 2],
                                                      [3, 3, 3],
                                                      [1, 2, 1, 2, 1, 2],
                                                      [3, 3, 3]])),
      dict(
          op=array_ops.where,
          args=(ragged_factory_ops.constant_value([[True, False], [True]]),
                ragged_factory_ops.constant_value([[b'A', b'B'], [b'C']]),
                ragged_factory_ops.constant_value([[b'a', b'b'], [b'c']])),
          expected=ragged_factory_ops.constant_value([[b'A', b'b'], [b'C']])),
      dict(
          op=array_ops.where,
          args=(ragged_factory_ops.constant_value([[True, False], [True]]),),
          expected=[[0, 0], [1, 0]]),
      dict(
          op=math_ops.unsorted_segment_sum,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'segment_ids': ragged_factory_ops.constant_value([[0, 2], [0]]),
              'num_segments': 3
          },
          expected=[4, 0, 2]),
      dict(
          op=math_ops.unsorted_segment_prod,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'segment_ids': ragged_factory_ops.constant_value([[0, 2], [0]]),
              'num_segments': 3
          },
          expected=[3, 1, 2]),
      dict(
          op=math_ops.unsorted_segment_min,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
              'num_segments': 2
          },
          expected=[1, 2]),
      dict(
          op=math_ops.unsorted_segment_max,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
              'num_segments': 2
          },
          expected=[3, 2]),
      dict(
          op=math_ops.unsorted_segment_mean,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
              'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
              'num_segments': 2
          },
          expected=[2, 2]),
      dict(
          op=math_ops.unsorted_segment_sqrt_n,
          kwargs={
              'data':
                  ragged_factory_ops.constant_value([[1.0, 2.0],
                                                     [3.0, 4.0, 6.0]]),
              'segment_ids':
                  ragged_factory_ops.constant_value([[0, 1], [0, 0, 0]]),
              'num_segments':
                  2
          },
          expected=[7.0, 2.0]),
      dict(
          op=math_ops.reduce_sum,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
              'axis':
                  1
          },
          expected=[3, 12]),
      dict(
          op=math_ops.reduce_prod,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
              'axis':
                  1
          },
          expected=[2, 60]),
      dict(
          op=math_ops.reduce_min,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
              'axis':
                  1
          },
          expected=[1, 3]),
      dict(
          op=math_ops.reduce_max,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
              'axis':
                  1
          },
          expected=[2, 5]),
      dict(
          op=math_ops.reduce_mean,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[1, 3], [3, 4, 5]]),
              'axis':
                  1
          },
          expected=[2, 4]),
      dict(
          op=math_ops.reduce_any,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[True, False],
                                                     [True, True, True]]),
              'axis':
                  1
          },
          expected=[True, True]),
      dict(
          op=string_ops.reduce_join,
          kwargs={
              'inputs':
                  ragged_factory_ops.constant_value([[
                      b'this', b'is', b'a', b'test', b'for', b'ragged',
                      b'tensors'
                  ], [b'please', b'do', b'not', b'panic', b'!']]),
              'axis':
                  0,
              'keepdims':
                  False,
              'separator':
                  ''
          },
          expected=[
              b'thisplease', b'isdo', b'anot', b'testpanic', b'for!', b'ragged',
              b'tensors'
          ]),
      dict(
          op=math_ops.reduce_all,
          kwargs={
              'input_tensor':
                  ragged_factory_ops.constant_value([[True, False],
                                                     [True, True, True]]),
              'axis':
                  1
          },
          expected=[False, True]),
      dict(
          op=array_ops.rank,
          kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
          expected=2),
      dict(
          op=array_ops.size,
          kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
          expected=3),
      dict(
          op=array_ops.size_v2,
          kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
          expected=3),
      dict(
          op=array_ops.squeeze,
          kwargs={
              'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
              'axis': [0]
          },
          expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
      dict(
          op=array_ops.squeeze_v2,
          kwargs={
              'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
              'axis': [0]
          },
          expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
      dict(
          op=data_flow_ops.dynamic_partition,
          kwargs={
              'data': ragged_factory_ops.constant_value([[1], [2, 3, 4], [5]]),
              'partitions': [2, 1, 1],
              'num_partitions': 3
          },
          expected=[
              ragged_factory_ops.constant_value([], ragged_rank=1),
              ragged_factory_ops.constant_value([[2, 3, 4], [5]]),
              ragged_factory_ops.constant_value([[1]])
          ],
          result_is_list=True),
      dict(
          op=array_ops.reverse,
          kwargs={
              'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]),
              'axis': [0, -1]
          },
          expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]]))
  ])
  def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
                         kwargs=None):
    if kwargs is None:
      kwargs = {}
    result = op(*args, **kwargs)
    if result_is_list:
      self.assertLen(result, len(expected))
      for (r, e) in zip(result, expected):
        self.assertAllEqual(r, e)
    else:
      self.assertAllEqual(result, expected)
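
  # For reference (illustrative): ragged_dispatch.ragged_op_list produces a
  # listing of the ops that have ragged dispatchers registered, keyed by the
  # public TF API name (e.g. 'math.add'); the test below checks membership
  # for both the v1 and v2 APIs:
  #
  #   ragged_dispatch.ragged_op_list(tf_version=2)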

  def test_ragged_op_list(self):
    # Ops that should be listed as supported in both v1 and v2.
    supported_ops = [
        'bitwise.bitwise_and', 'bitwise.bitwise_or', 'bitwise.bitwise_xor',
        'bitwise.invert', 'bitwise.left_shift', 'bitwise.right_shift',
        'clip_by_value', 'concat', 'debugging.check_numerics', 'cast',
        'dtypes.complex', 'dtypes.saturate_cast', 'expand_dims', 'gather_nd',
        'gather', 'identity', 'io.decode_base64', 'io.decode_compressed',
        'io.encode_base64', 'math.abs', 'math.acos', 'math.acosh', 'math.add_n',
        'math.add', 'math.angle', 'math.asin', 'math.asinh', 'math.atan2',
        'math.atan', 'math.atanh', 'math.ceil', 'math.conj', 'math.cos',
        'math.cosh', 'math.digamma', 'math.divide_no_nan', 'math.divide',
        'math.equal', 'math.erf', 'math.erfc', 'math.exp', 'math.expm1',
        'math.floor', 'math.floordiv', 'math.floormod', 'math.greater_equal',
        'math.greater', 'math.imag', 'math.is_finite', 'math.is_inf',
        'math.is_nan', 'math.less_equal', 'math.less', 'math.lgamma',
        'math.log1p', 'math.log_sigmoid', 'math.log', 'math.logical_and',
        'math.logical_not', 'math.logical_or', 'math.logical_xor',
        'math.maximum', 'math.minimum', 'math.multiply', 'math.negative',
        'math.not_equal', 'math.pow', 'math.real', 'math.reciprocal',
        'math.reduce_any', 'math.reduce_max', 'math.reduce_mean',
        'math.reduce_min', 'math.reduce_prod', 'math.reduce_sum', 'math.rint',
        'math.round', 'math.rsqrt', 'math.sign', 'math.sin', 'math.sinh',
        'math.sqrt', 'math.square', 'math.squared_difference', 'math.subtract',
        'math.tan', 'math.truediv', 'math.unsorted_segment_max',
        'math.unsorted_segment_mean', 'math.unsorted_segment_min',
        'math.unsorted_segment_prod', 'math.unsorted_segment_sqrt_n',
        'math.unsorted_segment_sum', 'one_hot', 'ones_like', 'rank', 'realdiv',
        'reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string',
        'strings.join', 'strings.length', 'strings.reduce_join',
        'strings.regex_full_match', 'strings.regex_replace', 'strings.strip',
        'strings.substr', 'strings.to_hash_bucket_fast',
        'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
        'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
        'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse'
    ]

    # Ops that should be listed as supported in v1 only.
    # TODO(edloper): Add a dispatch for where_v2.
    supported_ops_v1 = ['batch_gather', 'where']

    # Ops that should be listed as supported in v2 only.
    supported_ops_v2 = []

    v1_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=1)
    for element in supported_ops + supported_ops_v1:
      self.assertIn(element, v1_ragged_ops)
    for element in supported_ops_v2:
      self.assertNotIn(element, v1_ragged_ops)

    v2_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=2)
    for element in supported_ops + supported_ops_v2:
      self.assertIn(element, v2_ragged_ops)
    for element in supported_ops_v1:
      self.assertNotIn(element, v2_ragged_ops)


if __name__ == '__main__':
  googletest.main()