• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for RaggedTensor operator dispatch."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import numpy as np
23
24from tensorflow.python.eager import context
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import errors
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import test_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import clip_ops
32from tensorflow.python.ops import data_flow_ops
33from tensorflow.python.ops import gen_bitwise_ops
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import parsing_ops
36from tensorflow.python.ops import string_ops
37from tensorflow.python.ops.ragged import ragged_dispatch
38from tensorflow.python.ops.ragged import ragged_factory_ops
39from tensorflow.python.ops.ragged import ragged_tensor
40from tensorflow.python.platform import googletest
41
# Constants listing various op types to test.  Each operation
# should be included in at least one list below, or tested separately if
# necessary (e.g., because it expects additional arguments).
#
# The order of entries matters: the parameterized tests below expand these
# lists in order, so reordering changes the generated test-case indices.

# Unary ops exercised on floating-point ragged inputs in
# testUnaryElementwiseOp.
UNARY_FLOAT_OPS = [
    math_ops.abs,
    math_ops.acos,
    math_ops.acosh,
    math_ops.angle,
    math_ops.asin,
    math_ops.asinh,
    math_ops.atan,
    math_ops.atanh,
    math_ops.ceil,
    math_ops.conj,
    math_ops.cos,
    math_ops.cosh,
    math_ops.digamma,
    math_ops.erf,
    math_ops.erfc,
    math_ops.erfinv,
    math_ops.exp,
    math_ops.expm1,
    math_ops.floor,
    math_ops.imag,
    math_ops.is_finite,
    math_ops.is_inf,
    math_ops.is_nan,
    math_ops.lgamma,
    math_ops.log,
    math_ops.log1p,
    math_ops.log_sigmoid,
    math_ops.ndtri,
    math_ops.negative,
    math_ops.real,
    math_ops.reciprocal,
    math_ops.rint,
    math_ops.round,
    math_ops.rsqrt,
    math_ops.sign,
    math_ops.sin,
    math_ops.sinh,
    math_ops.sqrt,
    math_ops.square,
    math_ops.tan,
    array_ops.identity,
    array_ops.ones_like,
    array_ops.zeros_like,
]
# Unary ops exercised on boolean ragged inputs in testUnaryElementwiseOp.
UNARY_BOOL_OPS = [
    math_ops.logical_not,
]
# Unary ops exercised on string ragged inputs in testUnaryElementwiseOp.
UNARY_STRING_OPS = [
    string_ops.decode_base64,
    string_ops.encode_base64,
    string_ops.string_strip,
    parsing_ops.decode_compressed,
]
# Binary ops exercised on pairs of floating-point ragged inputs in
# testBinaryElementwiseOp.
BINARY_FLOAT_OPS = [
    math_ops.add,
    math_ops.atan2,
    math_ops.complex,
    math_ops.div_no_nan,
    math_ops.divide,
    math_ops.equal,
    math_ops.floordiv,
    math_ops.floormod,
    math_ops.greater,
    math_ops.greater_equal,
    math_ops.less,
    math_ops.less_equal,
    math_ops.maximum,
    math_ops.minimum,
    math_ops.multiply,
    math_ops.not_equal,
    math_ops.pow,
    math_ops.realdiv,
    math_ops.squared_difference,
    math_ops.subtract,
    math_ops.truediv,
]
# Binary ops exercised on pairs of boolean ragged inputs in
# testBinaryElementwiseOp.
BINARY_BOOL_OPS = [
    math_ops.logical_and,
    math_ops.logical_or,
    math_ops.logical_xor,
]
# Unary ops exercised on np.int32 ragged inputs in testUnaryElementwiseOp.
UNARY_INT_OPS = [
    gen_bitwise_ops.invert,
    string_ops.unicode_script,
]
# Binary ops exercised on pairs of integer ragged inputs in
# testBinaryElementwiseOp.
BINARY_INT_OPS = [
    gen_bitwise_ops.bitwise_and,
    gen_bitwise_ops.bitwise_or,
    gen_bitwise_ops.bitwise_xor,
    gen_bitwise_ops.left_shift,
    gen_bitwise_ops.right_shift,
    math_ops.truncatediv,
    math_ops.truncatemod,
]
140
141
142@test_util.run_all_in_graph_and_eager_modes
143class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
144                               parameterized.TestCase):
145
146  def assertSameShape(self, x, y):
147    """Checks that x and y have the same shape (including ragged shapes)."""
148    if isinstance(x, ragged_tensor.RaggedTensor):
149      self.assertIsInstance(y, ragged_tensor.RaggedTensor)
150      self.assertEqual(x.ragged_rank, y.ragged_rank)
151      for (x_splits, y_splits) in zip(x.nested_row_splits, y.nested_row_splits):
152        self.assertAllEqual(x_splits, y_splits)
153      self.assertAllEqual(
154          array_ops.shape(x.flat_values), array_ops.shape(y.flat_values))
155    else:
156      self.assertIsInstance(y, ops.Tensor)
157      self.assertAllEqual(array_ops.shape(x), array_ops.shape(y))
158
159  @parameterized.parameters(
160      #=========================================================================
161      # Test different input shapes.
162      #=========================================================================
163      [
164          # 0-dimensional input
165          {'x': 12},
166          # 1-dimensional input
167          {'x': [1, -2, 3]},
168          # 2-dimensional input
169          {'x': [[-2, 3], [-3, 4]]},
170          {'x': ragged_factory_ops.constant_value(
171              [[-2, 3], [-3]], ragged_rank=1)},
172          # 3-dimensional inputs
173          {'x': [[[-2, 3], [3, 4]], [[7, 6], [5, 4]]]},
174          {'x': ragged_factory_ops.constant_value(
175              [[[-2, 3], [3, 4]], [[7, 6]]],
176              ragged_rank=1)},
177          {'x': ragged_factory_ops.constant_value(
178              [[[-2, 3, 4], []], [[7, 6]], []],
179              ragged_rank=2)},
180          ] +
181      #=========================================================================
182      # Test each unary op.
183      #=========================================================================
184      [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]), 'op': op}
185       for op in UNARY_FLOAT_OPS] +
186      [{'x': ragged_factory_ops.constant_value([[True, False], [True]]),
187        'op': op}
188       for op in UNARY_BOOL_OPS] +
189      [{'x': ragged_factory_ops.constant_value([[18, 512], [12412]], np.int32),
190        'op': op}
191       for op in UNARY_INT_OPS] +
192      [{'x': ragged_factory_ops.constant_value([['abcd', 'efgh'],
193                                                ['aabbccdd']]),
194        'op': op}
195       for op in UNARY_STRING_OPS] +
196      [
197          {'op': clip_ops.clip_by_value,
198           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
199           'clip_value_min': 0.1, 'clip_value_max': 4.0},
200          {'op': math_ops.cast,
201           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
202           'dtype': dtypes.int32},
203          {'op': math_ops.saturate_cast,
204           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
205           'dtype': dtypes.int32},
206          {'op': string_ops.string_to_hash_bucket,
207           'x': ragged_factory_ops.constant_value(
208               [['abcd', 'efgh'], ['aabbccdd']]),
209           'num_buckets': 1000},
210          {'op': string_ops.string_to_hash_bucket_fast,
211           'x': ragged_factory_ops.constant_value(
212               [['abcd', 'efgh'], ['aabbccdd']]),
213           'num_buckets': 1000},
214          {'op': string_ops.string_to_hash_bucket_strong,
215           'x': ragged_factory_ops.constant_value(
216               [['abcd', 'efgh'], ['aabbccdd']]),
217           'num_buckets': 1000,
218           'key': [1231, 12512]},
219          {'op': string_ops.string_to_number,
220           'x': ragged_factory_ops.constant_value([['-2.0', '3.0'], ['-3.0']])},
221          {'op': string_ops.regex_full_match,
222           'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
223           'pattern': r'\w+'},
224          {'op': string_ops.regex_replace,
225           'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
226           'pattern': r'\d',
227           'rewrite': '#'},
228          {'op': string_ops.substr,
229           'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
230           'pos': 2, 'len': 3},
231          {'op': array_ops.check_numerics,
232           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
233           'message': 'check-numerics'},
234      ]
235      )  # pyformat: disable
236  def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
237    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
238    result = op(x, **extra_args)
239
240    # Run the wrapped op on the dense values, for comparison.
241    dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
242    expected_flat_values = array_ops.reshape(op(dense_x, **extra_args), [-1])
243
244    # Check that the result has the expected shape.
245    self.assertSameShape(x, result)
246
247    # Check that the result has the expected (flattened) values.
248    if isinstance(result, ragged_tensor.RaggedTensor):
249      result_flat_values = array_ops.reshape(result.flat_values, [-1])
250    else:
251      result_flat_values = array_ops.reshape(result, [-1])
252    self.assertAllEqual(expected_flat_values, result_flat_values)
253
254  @parameterized.parameters(
255      [
256          #=====================================================================
257          # Without broadcasting -- i.e., shapes match exactly.
258          #=====================================================================
259          # Shapes: x:(), y:()
260          {'x': 12,
261           'y': 8},
262          # Shapes: x:(3,), y:(3,)
263          {'x': [7, 8, 9],
264           'y': [1, -2, 3]},
265          # Shapes: x:(2, 2), y:(2, 2)
266          {'x': [[-2, 3], [-3, -4]],
267           'y': [[1, 2], [3, 4]]},
268          # Shapes: x:(2, None), y:(2, None)
269          {'x': ragged_factory_ops.constant_value([[-2, 3], [-3]]),
270           'y': ragged_factory_ops.constant_value([[5, 6], [7]])},
271          # Shapes: x:(2, 2, 2), y:(2, 2, 2)
272          {'x': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
273           'y': [[[9, 3], [3, 4]], [[5, 2], [7, 6]]]},
274          # Shapes: x:(2, None, None), y: (2, None, None)
275          {'x': ragged_factory_ops.constant_value(
276              [[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
277           'y': ragged_factory_ops.constant_value(
278               [[[3, 8], [2], [5]], [[], [1, 9, 8]]])},
279          # Shapes: x:(2, None, 2), y: (2, None, 2)
280          {'x': ragged_factory_ops.constant_value(
281              [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
282              ragged_rank=1),
283           'y': ragged_factory_ops.constant_value(
284               [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
285               ragged_rank=1)},
286
287          #=====================================================================
288          # With broadcasting
289          #=====================================================================
290          # Shapes: x:(), y:(3,)
291          {'x': 12,                                 # Broadcast () -> (3,)
292           'y': [1, -2, 3]},
293          # Shapes: x:(1,), y:(3,)
294          {'x': [12],                               # Broadcast (1,) -> (3,)
295           'y': [1, -2, 3]},
296          # Shapes: x:(), y:(2, 2)
297          {'x': 12,                                 # Broadcast () -> (2, 2)
298           'y': [[1, 2], [3, 4]]},
299          # Shapes: x:(1,), y:(2, 2)
300          {'x': 12,                                 # Broadcast (1,) -> (2, 2)
301           'y': [[1, 2], [3, 4]]},
302          # Shapes: x:(2, 1), y:(2, 2)
303          {'x': [[10], [20]],                       # Broadcast (2, 1) -> (2, 2)
304           'y': [[1, 2], [3, 4]]},
305          # Shapes: x:(), y:(2, None)
306          {'x': 10,                                 # Broadcast () -> (2, None)
307           'y': ragged_factory_ops.constant_value(
308               [[1, 2], [3]], dtype=np.int32)},
309          # TODO(edloper): Add tests for more advanced broadcasting, once we add
310          # support for it.
311
312          #=====================================================================
313          # Keyword Args
314          #=====================================================================
315          {'x': ragged_factory_ops.constant_value(
316              [[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
317           'y': ragged_factory_ops.constant_value(
318               [[[3, 8], [2], [5]], [[], [1, 9, 8]]]),
319           'use_kwargs': ('x', 'y')},
320          {'x': ragged_factory_ops.constant_value(
321              [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
322              ragged_rank=1),
323           'y': ragged_factory_ops.constant_value(
324               [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
325               ragged_rank=1),
326           'use_kwargs': ('x', 'y')},
327          {'x': ragged_factory_ops.constant_value(
328              [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
329              ragged_rank=1),
330           'y': ragged_factory_ops.constant_value(
331               [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
332               ragged_rank=1),
333           'use_kwargs': ('x',)},
334      ] +
335      #=========================================================================
336      # Test each unary op.
337      #=========================================================================
338      [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
339        'y': ragged_factory_ops.constant_value([[5.0, 1.0], [12.0]]),
340        'op': op}
341       for op in BINARY_FLOAT_OPS] +
342      [{'x': ragged_factory_ops.constant_value([[-2, 3], [-3]]),
343        'y': ragged_factory_ops.constant_value([[5, 1], [12]]),
344        'op': op}
345       for op in BINARY_INT_OPS] +
346      [{'x': ragged_factory_ops.constant_value([[True, True], [False]]),
347        'y': ragged_factory_ops.constant_value([[False, True], [False]]),
348        'op': op}
349       for op in BINARY_BOOL_OPS]
350      )  # pyformat: disable
351  def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
352    use_kwargs = extra_args.pop('use_kwargs', ())
353    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
354    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y)
355    if 'x' in use_kwargs and 'y' in use_kwargs:
356      result = op(x=x, y=y, **extra_args)
357    elif 'y' in use_kwargs:
358      result = op(x, y=y, **extra_args)
359    else:
360      result = op(x, y, **extra_args)
361
362    # Run the wrapped op on the dense values, for comparison.
363    dense_x = x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
364    dense_y = y.flat_values if isinstance(y, ragged_tensor.RaggedTensor) else y
365    expected_flat_values = array_ops.reshape(
366        op(dense_x, dense_y, **extra_args), [-1])
367
368    # Check that the result has the expected shape.
369    self.assertSameShape(y, result)
370
371    # Check that the result has the expected (flattened) values.
372    if isinstance(result, ragged_tensor.RaggedTensor):
373      result_flat_values = array_ops.reshape(result.flat_values, [-1])
374    else:
375      result_flat_values = array_ops.reshape(result, [-1])
376    self.assertAllEqual(expected_flat_values, result_flat_values)
377
378  @parameterized.parameters(
379      [
380          {'inputs': (12, 8, 3)},
381          {'inputs': ([1, 2, 3], [7, 8, 9], [3, 6, 9])},
382          {'inputs': ([[1, 2]], [[3, 4]], [[5, 6]])},
383          {'inputs': (ragged_factory_ops.constant_value([[1, 3], [-3]]),
384                      ragged_factory_ops.constant_value([[4, 7], [88]]),
385                      ragged_factory_ops.constant_value([[2, 9], [12]]))},
386          {'inputs': (ragged_factory_ops.constant_value(
387              [[[1, 3], [-3]], [[1]]]),
388                      ragged_factory_ops.constant_value(
389                          [[[4, 7], [88]], [[2]]]),
390                      ragged_factory_ops.constant_value(
391                          [[[2, 9], [12]], [[8]]]))},
392          {'inputs': (
393              ragged_factory_ops.constant_value([[[1, 3], [3, 4]], [[1, 5]]],
394                                                ragged_rank=1),
395              ragged_factory_ops.constant_value([[[4, 7], [1, 2]], [[2, 2]]],
396                                                ragged_rank=1),
397              ragged_factory_ops.constant_value([[[2, 9], [5, 2]], [[8, 0]]],
398                                                ragged_rank=1))},
399          {'inputs': (
400              ragged_factory_ops.constant_value([[[1, 3], [-3]], [[1]]]),
401              ragged_factory_ops.constant_value([[[4, 7], [88]], [[2]]]),
402              ragged_factory_ops.constant_value([[[2, 9], [12]], [[8]]])),
403           'use_kwargs': True},
404      ] + [
405          {'op': math_ops.add_n,
406           'inputs': (ragged_factory_ops.constant_value([[1, 3], [-3]]),
407                      ragged_factory_ops.constant_value([[4, 7], [88]]),
408                      ragged_factory_ops.constant_value([[2, 9], [12]]))},
409          {'op': string_ops.string_join,
410           'inputs': (
411               ragged_factory_ops.constant_value([['a', 'b'], ['c']]),
412               ragged_factory_ops.constant_value([['foo', 'bar'], ['baz']]),
413               ragged_factory_ops.constant_value([['2', '9'], ['12']]))},
414      ])  # pyformat: disable
415  def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
416                                  **extra_args):
417    use_kwargs = extra_args.pop('use_kwargs', False)
418    inputs = [
419        ragged_tensor.convert_to_tensor_or_ragged_tensor(x) for x in inputs
420    ]
421    if use_kwargs:
422      result = op(inputs=inputs, **extra_args)
423    else:
424      result = op(inputs, **extra_args)
425
426    # Run the wrapped op on the dense values, for comparison.
427    dense_inputs = [
428        x.flat_values if isinstance(x, ragged_tensor.RaggedTensor) else x
429        for x in inputs
430    ]
431    expected_flat_values = array_ops.reshape(
432        op(dense_inputs, **extra_args), [-1])
433
434    # Check that the result has the expected shape.
435    self.assertSameShape(inputs[0], result)
436
437    # Check that the result has the expected (flattened) values.
438    if isinstance(result, ragged_tensor.RaggedTensor):
439      result_flat_values = array_ops.reshape(result.flat_values, [-1])
440    else:
441      result_flat_values = array_ops.reshape(result, [-1])
442    self.assertAllEqual(expected_flat_values, result_flat_values)
443
444  def testElementwiseOpUnknownRankError(self):
445    if context.executing_eagerly():
446      return
447    x = ragged_factory_ops.constant([[1, 2], [3]])
448    y = ragged_tensor.RaggedTensor.from_row_splits(
449        array_ops.placeholder_with_default([1, 2, 3], shape=None), x.row_splits)
450    with self.assertRaisesRegexp(ValueError,
451                                 r'Unable to broadcast: unknown rank'):
452      math_ops.add(x, y)
453
454  @parameterized.parameters([
455      dict(
456          x=ragged_factory_ops.constant_value([[1, 2], [3]]),
457          y=[[10]],
458          expected=[[11, 12], [13]]),
459      dict(
460          x=ragged_factory_ops.constant_value([[[1, 2], [3, 4]], [[5]]],
461                                              ragged_rank=2),
462          y=ragged_factory_ops.constant_value([[[10], [20]], [[30]]],
463                                              ragged_rank=1),
464          expected=[[[11, 12], [23, 24]], [[35]]]),
465      dict(
466          x=ragged_factory_ops.constant_value([[[1]]]),
467          y=ragged_factory_ops.constant_value([[1]]),
468          expected=[[[2]]]),
469  ])
470  def testElementwiseOpBroadcast(self, x, y, expected):
471    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
472    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
473    result = x + y
474    self.assertAllEqual(result, expected)
475
476  def testElementwiseOpShapeMismatch(self):
477    x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
478    y = ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])
479    with self.assertRaises(errors.InvalidArgumentError):
480      self.evaluate(math_ops.add(x, y))
481
482  def testBinaryOpSparseAndRagged(self):
483    x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
484    y = sparse_tensor.SparseTensor([[0, 0], [0, 1], [2, 0]], [1, 2, 3], [3, 2])
485    with self.assertRaises((TypeError, ValueError)):
486      self.evaluate(math_ops.add(x, y))
487
488    with self.assertRaises((TypeError, ValueError)):
489      self.evaluate(math_ops.add_n([x, y]))
490
491  @parameterized.parameters([
492      dict(
493          op=array_ops.batch_gather,
494          args=(ragged_factory_ops.constant_value([[5, 6, 7], [8, 9]]),
495                ragged_factory_ops.constant_value([[2, 1, 0], [1]])),
496          expected=ragged_factory_ops.constant_value([[7, 6, 5], [9]])),
497      dict(
498          op=array_ops.concat,
499          args=([
500              ragged_factory_ops.constant_value([[1, 2, 3], [4]],
501                                                dtype=np.int32),
502              np.array([[5, 6]], dtype=np.int32)
503          ],),
504          kwargs={'axis': 0},
505          expected=ragged_factory_ops.constant_value([[1, 2, 3], [4], [5, 6]])),
506      dict(
507          op=array_ops.expand_dims,
508          kwargs={
509              'input': ragged_factory_ops.constant_value([[1, 2], [3]]),
510              'axis': 0
511          },
512          expected=ragged_factory_ops.constant_value([[[1, 2], [3]]])),
513      dict(
514          op=array_ops.expand_dims_v2,
515          kwargs={
516              'input': ragged_factory_ops.constant_value([[1, 2], [3]]),
517              'axis': -1
518          },
519          expected=ragged_factory_ops.constant_value([[[1], [2]], [[3]]],
520                                                     ragged_rank=1),
521      ),
522      dict(
523          op=array_ops.gather,
524          kwargs={
525              'params': ragged_factory_ops.constant_value([[1, 2], [3]]),
526              'indices': [1, 0, 1]
527          },
528          expected=ragged_factory_ops.constant_value([[3], [1, 2], [3]])),
529      dict(
530          op=array_ops.gather_v2,
531          kwargs={
532              'params': ragged_factory_ops.constant_value([[1, 2], [3]]),
533              'indices': ragged_factory_ops.constant_value([[1, 0], [1]])
534          },
535          expected=ragged_factory_ops.constant_value([[[3], [1, 2]], [[3]]])),
536      dict(
537          op=array_ops.gather_nd,
538          kwargs={
539              'params': ragged_factory_ops.constant_value([[7, 8], [9]]),
540              'indices': [[0, 1], [1, 0], [0, 0]]
541          },
542          expected=ragged_factory_ops.constant_value([8, 9, 7])),
543      dict(
544          op=array_ops.one_hot,
545          kwargs={
546              'indices':
547                  ragged_factory_ops.constant_value([[1, 2, 3], [0]],
548                                                    dtype=np.int32),
549              'depth':
550                  4,
551              'axis':
552                  1
553          },
554          expected=ragged_factory_ops.constant_value(
555              [[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], [[1, 0, 0, 0]]],
556              ragged_rank=1)),
557      dict(
558          op=array_ops.stack,
559          args=([
560              ragged_factory_ops.constant_value([[1, 2, 3], [4]],
561                                                dtype=np.int32),
562              np.array([[5, 6]], dtype=np.int32)
563          ],),
564          expected=ragged_factory_ops.constant_value([[[1, 2, 3], [4]],
565                                                      [[5, 6]]])),
566      dict(
567          op=array_ops.tile,
568          args=([
569              ragged_factory_ops.constant_value([[1, 2], [3]], dtype=np.int32),
570              [2, 3]
571          ]),
572          expected=ragged_factory_ops.constant_value([[1, 2, 1, 2, 1, 2],
573                                                      [3, 3, 3],
574                                                      [1, 2, 1, 2, 1, 2],
575                                                      [3, 3, 3]])),
576      dict(
577          op=array_ops.where,
578          args=(ragged_factory_ops.constant_value([[True, False], [True]]),
579                ragged_factory_ops.constant_value([[b'A', b'B'], [b'C']]),
580                ragged_factory_ops.constant_value([[b'a', b'b'], [b'c']])),
581          expected=ragged_factory_ops.constant_value([[b'A', b'b'], [b'C']])),
582      dict(
583          op=array_ops.where,
584          args=(ragged_factory_ops.constant_value([[True, False], [True]]),),
585          expected=[[0, 0], [1, 0]]),
586      dict(
587          op=math_ops.unsorted_segment_sum,
588          kwargs={
589              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
590              'segment_ids': ragged_factory_ops.constant_value([[0, 2], [0]]),
591              'num_segments': 3
592          },
593          expected=[4, 0, 2]),
594      dict(
595          op=math_ops.unsorted_segment_prod,
596          kwargs={
597              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
598              'segment_ids': ragged_factory_ops.constant_value([[0, 2], [0]]),
599              'num_segments': 3
600          },
601          expected=[3, 1, 2]),
602      dict(
603          op=math_ops.unsorted_segment_min,
604          kwargs={
605              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
606              'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
607              'num_segments': 2
608          },
609          expected=[1, 2]),
610      dict(
611          op=math_ops.unsorted_segment_max,
612          kwargs={
613              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
614              'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
615              'num_segments': 2
616          },
617          expected=[3, 2]),
618      dict(
619          op=math_ops.unsorted_segment_mean,
620          kwargs={
621              'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
622              'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
623              'num_segments': 2
624          },
625          expected=[2, 2]),
626      dict(
627          op=math_ops.unsorted_segment_sqrt_n,
628          kwargs={
629              'data':
630                  ragged_factory_ops.constant_value([[1.0, 2.0],
631                                                     [3.0, 4.0, 6.0]]),
632              'segment_ids':
633                  ragged_factory_ops.constant_value([[0, 1], [0, 0, 0]]),
634              'num_segments':
635                  2
636          },
637          expected=[7.0, 2.0]),
638      dict(
639          op=math_ops.reduce_sum,
640          kwargs={
641              'input_tensor':
642                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
643              'axis':
644                  1
645          },
646          expected=[3, 12]),
647      dict(
648          op=math_ops.reduce_prod,
649          kwargs={
650              'input_tensor':
651                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
652              'axis':
653                  1
654          },
655          expected=[2, 60]),
656      dict(
657          op=math_ops.reduce_min,
658          kwargs={
659              'input_tensor':
660                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
661              'axis':
662                  1
663          },
664          expected=[1, 3]),
665      dict(
666          op=math_ops.reduce_max,
667          kwargs={
668              'input_tensor':
669                  ragged_factory_ops.constant_value([[1, 2], [3, 4, 5]]),
670              'axis':
671                  1
672          },
673          expected=[2, 5]),
674      dict(
675          op=math_ops.reduce_mean,
676          kwargs={
677              'input_tensor':
678                  ragged_factory_ops.constant_value([[1, 3], [3, 4, 5]]),
679              'axis':
680                  1
681          },
682          expected=[2, 4]),
683      dict(
684          op=math_ops.reduce_any,
685          kwargs={
686              'input_tensor':
687                  ragged_factory_ops.constant_value([[True, False],
688                                                     [True, True, True]]),
689              'axis':
690                  1
691          },
692          expected=[True, True]),
693      dict(
694          op=string_ops.reduce_join,
695          kwargs={
696              'inputs':
697                  ragged_factory_ops.constant_value([[
698                      b'this', b'is', b'a', b'test', b'for', b'ragged',
699                      b'tensors'
700                  ], [b'please', b'do', b'not', b'panic', b'!']]),
701              'axis':
702                  0,
703              'keepdims':
704                  False,
705              'separator':
706                  ''
707          },
708          expected=[
709              b'thisplease', b'isdo', b'anot', b'testpanic', b'for!', b'ragged',
710              b'tensors'
711          ]),
712      dict(
713          op=math_ops.reduce_all,
714          kwargs={
715              'input_tensor':
716                  ragged_factory_ops.constant_value([[True, False],
717                                                     [True, True, True]]),
718              'axis':
719                  1
720          },
721          expected=[False, True]),
722      dict(
723          op=array_ops.rank,
724          kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
725          expected=2),
726      dict(
727          op=array_ops.size,
728          kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
729          expected=3),
730      dict(
731          op=array_ops.size_v2,
732          kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
733          expected=3),
734      dict(
735          op=array_ops.squeeze,
736          kwargs={
737              'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
738              'axis': [0]
739          },
740          expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
741      dict(
742          op=array_ops.squeeze_v2,
743          kwargs={
744              'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
745              'axis': [0]
746          },
747          expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
748      dict(
749          op=data_flow_ops.dynamic_partition,
750          kwargs={
751              'data': ragged_factory_ops.constant_value([[1], [2, 3, 4], [5]]),
752              'partitions': [2, 1, 1],
753              'num_partitions': 3
754          },
755          expected=[
756              ragged_factory_ops.constant_value([], ragged_rank=1),
757              ragged_factory_ops.constant_value([[2, 3, 4], [5]]),
758              ragged_factory_ops.constant_value([[1]])
759          ],
760          result_is_list=True),
761      dict(
762          op=array_ops.reverse,
763          kwargs={
764              'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]),
765              'axis': [0, -1]
766          },
767          expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]]))
768  ])
769  def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
770                         kwargs=None):
771    if kwargs is None: kwargs = {}
772    result = op(*args, **kwargs)
773    if result_is_list:
774      self.assertLen(result, len(expected))
775      for (r, e) in zip(result, expected):
776        self.assertAllEqual(r, e)
777    else:
778      self.assertAllEqual(result, expected)
779
780  def test_ragged_op_list(self):
781    # Ops that should be listed as supported in both v1 and v2.
782    supported_ops = [
783        'bitwise.bitwise_and', 'bitwise.bitwise_or', 'bitwise.bitwise_xor',
784        'bitwise.invert', 'bitwise.left_shift', 'bitwise.right_shift',
785        'clip_by_value', 'concat', 'debugging.check_numerics', 'cast',
786        'dtypes.complex', 'dtypes.saturate_cast', 'expand_dims', 'gather_nd',
787        'gather', 'identity', 'io.decode_base64', 'io.decode_compressed',
788        'io.encode_base64', 'math.abs', 'math.acos', 'math.acosh', 'math.add_n',
789        'math.add', 'math.angle', 'math.asin', 'math.asinh', 'math.atan2',
790        'math.atan', 'math.atanh', 'math.ceil', 'math.conj', 'math.cos',
791        'math.cosh', 'math.digamma', 'math.divide_no_nan', 'math.divide',
792        'math.equal', 'math.erf', 'math.erfc', 'math.exp', 'math.expm1',
793        'math.floor', 'math.floordiv', 'math.floormod', 'math.greater_equal',
794        'math.greater', 'math.imag', 'math.is_finite', 'math.is_inf',
795        'math.is_nan', 'math.less_equal', 'math.less', 'math.lgamma',
796        'math.log1p', 'math.log_sigmoid', 'math.log', 'math.logical_and',
797        'math.logical_not', 'math.logical_or', 'math.logical_xor',
798        'math.maximum', 'math.minimum', 'math.multiply', 'math.negative',
799        'math.not_equal', 'math.pow', 'math.real', 'math.reciprocal',
800        'math.reduce_any', 'math.reduce_max', 'math.reduce_mean',
801        'math.reduce_min', 'math.reduce_prod', 'math.reduce_sum', 'math.rint',
802        'math.round', 'math.rsqrt', 'math.sign', 'math.sin', 'math.sinh',
803        'math.sqrt', 'math.square', 'math.squared_difference', 'math.subtract',
804        'math.tan', 'math.truediv', 'math.unsorted_segment_max',
805        'math.unsorted_segment_mean', 'math.unsorted_segment_min',
806        'math.unsorted_segment_prod', 'math.unsorted_segment_sqrt_n',
807        'math.unsorted_segment_sum', 'one_hot', 'ones_like', 'rank', 'realdiv',
808        'reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string',
809        'strings.join', 'strings.length', 'strings.reduce_join',
810        'strings.regex_full_match', 'strings.regex_replace', 'strings.strip',
811        'strings.substr', 'strings.to_hash_bucket_fast',
812        'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
813        'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
814        'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse'
815    ]
816
817    # Ops that should be listed as supported in v1 only.
818    # TODO(edloper): Add a dispatch for where_v2.
819    supported_ops_v1 = ['batch_gather', 'where']
820
821    # Ops that should be listed as supported in v2 only.
822    supported_ops_v2 = []
823
824    v1_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=1)
825    for element in supported_ops + supported_ops_v1:
826      self.assertIn(element, v1_ragged_ops)
827    for element in supported_ops_v2:
828      self.assertNotIn(element, v1_ragged_ops)
829
830    v2_ragged_ops = ragged_dispatch.ragged_op_list(tf_version=2)
831    for element in supported_ops + supported_ops_v2:
832      self.assertIn(element, v2_ragged_ops)
833    for element in supported_ops_v1:
834      self.assertNotIn(element, v2_ragged_ops)
835
836
# Run the test suite only when this file is executed as a script, not on import.
if __name__ == '__main__':
  googletest.main()
839