• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for the ragged segment ops (ragged_math_ops.segment_*)."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22from absl.testing import parameterized
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import errors
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops.ragged import ragged_factory_ops
29from tensorflow.python.ops.ragged import ragged_math_ops
30from tensorflow.python.ops.ragged import ragged_tensor
31from tensorflow.python.ops.ragged import ragged_test_util
32from tensorflow.python.platform import googletest
33
34
def prod(values):
  """Multiplies together every element of `values`.

  Used below as the pure-Python reference combiner for segment_prod.

  Args:
    values: An iterable of numbers.

  Returns:
    The product of all elements; 1 when `values` is empty.
  """
  product = 1
  for factor in values:
    product *= factor
  return product
41
42
def mean(values):
  """Returns the arithmetic mean of `values` as a float.

  Used below as the pure-Python reference combiner for segment_mean.
  `values` must be a non-empty sized collection; an empty one raises
  ZeroDivisionError.
  """
  total = sum(values)
  count = len(values)
  return 1.0 * total / count
45
46
def sqrt_n(values):
  """Returns sum(values) / sqrt(len(values)) as a float.

  Used below as the pure-Python reference combiner for segment_sqrt_n.
  `values` must be a non-empty sized collection; an empty one raises
  ZeroDivisionError.
  """
  total = sum(values)
  scale = math.sqrt(len(values))
  return 1.0 * total / scale
49
50
@test_util.run_all_in_graph_and_eager_modes
class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
                           parameterized.TestCase):
  """Tests for the ragged_math_ops.segment_* ops.

  The parameterized tests compare each op against a pure-Python reference
  computed by `expected_value`; the remaining tests cover doubly-nested
  inputs, ragged segment_ids, and shape-mismatch errors.
  """

  # (segment_op, python_combiner, segment_ids) cases for integer inputs.
  # The segment_ids lists are only read by the tests, never mutated, so they
  # can safely be shared with _FLOAT_CASES.
  _INT_CASES = (
      (ragged_math_ops.segment_sum, sum, [0, 0, 1, 1, 2, 2]),
      (ragged_math_ops.segment_sum, sum, [0, 0, 0, 1, 1, 1]),
      (ragged_math_ops.segment_sum, sum, [5, 4, 3, 2, 1, 0]),
      (ragged_math_ops.segment_sum, sum, [0, 0, 0, 10, 10, 10]),
      (ragged_math_ops.segment_prod, prod, [0, 0, 1, 1, 2, 2]),
      (ragged_math_ops.segment_prod, prod, [0, 0, 0, 1, 1, 1]),
      (ragged_math_ops.segment_prod, prod, [5, 4, 3, 2, 1, 0]),
      (ragged_math_ops.segment_prod, prod, [0, 0, 0, 10, 10, 10]),
      (ragged_math_ops.segment_min, min, [0, 0, 1, 1, 2, 2]),
      (ragged_math_ops.segment_min, min, [0, 0, 0, 1, 1, 1]),
      (ragged_math_ops.segment_min, min, [5, 4, 3, 2, 1, 0]),
      (ragged_math_ops.segment_min, min, [0, 0, 0, 10, 10, 10]),
      (ragged_math_ops.segment_max, max, [0, 0, 1, 1, 2, 2]),
      (ragged_math_ops.segment_max, max, [0, 0, 0, 1, 1, 1]),
      (ragged_math_ops.segment_max, max, [5, 4, 3, 2, 1, 0]),
      (ragged_math_ops.segment_max, max, [0, 0, 0, 10, 10, 10]),
      (ragged_math_ops.segment_mean, mean, [0, 0, 1, 1, 2, 2]),
      (ragged_math_ops.segment_mean, mean, [0, 0, 0, 1, 1, 1]),
      (ragged_math_ops.segment_mean, mean, [5, 4, 3, 2, 1, 0]),
      (ragged_math_ops.segment_mean, mean, [0, 0, 0, 10, 10, 10]),
  )

  # Float inputs run every integer case plus segment_sqrt_n, which is only
  # exercised on float data here.
  _FLOAT_CASES = _INT_CASES + (
      (ragged_math_ops.segment_sqrt_n, sqrt_n, [0, 0, 1, 1, 2, 2]),
      (ragged_math_ops.segment_sqrt_n, sqrt_n, [0, 0, 0, 1, 1, 1]),
      (ragged_math_ops.segment_sqrt_n, sqrt_n, [5, 4, 3, 2, 1, 0]),
      (ragged_math_ops.segment_sqrt_n, sqrt_n, [0, 0, 0, 10, 10, 10]),
  )

  def expected_value(self, data, segment_ids, num_segments, combiner):
    """Find the expected value for a call to ragged_segment_<aggregate>.

    Args:
      data: The input RaggedTensor, expressed as a nested python list.
      segment_ids: The segment ids, as a python list of ints.
      num_segments: The number of segments, as a python int.
      combiner: The Python function used to combine values.
    Returns:
      The expected value, as a nested Python list.  Each segment's row is as
      long as the longest data row assigned to it (empty if none).
    """
    self.assertLen(data, len(segment_ids))

    # Build an empty (num_segments x ncols) "grouped" matrix of value lists.
    ncols = max(len(row) for row in data)
    grouped = [[[] for _ in range(ncols)] for _ in range(num_segments)]

    # Append each value of data[i] to grouped[segment_ids[i]], column-wise.
    # The assertLen above guarantees zip() pairs every row with an id.
    for row, segment_id in zip(data, segment_ids):
      for col, value in enumerate(row):
        grouped[segment_id][col].append(value)

    # Combine each non-empty column.  Rows are left-aligned, so empty
    # columns only occur at the end of a segment's row and are dropped.
    return [[combiner(values) for values in grouped_row if values]
            for grouped_row in grouped]

  @parameterized.parameters(*_INT_CASES)
  def testRaggedSegment_Int(self, segment_op, combiner, segment_ids):
    """Compares segment_op on integer data with the Python reference."""
    rt_as_list = [[0, 1, 2, 3], [4], [], [5, 6], [7], [8, 9]]
    rt = ragged_factory_ops.constant(rt_as_list)
    num_segments = max(segment_ids) + 1
    expected = self.expected_value(rt_as_list, segment_ids, num_segments,
                                   combiner)

    segmented = segment_op(rt, segment_ids, num_segments)
    self.assertRaggedEqual(segmented, expected)

  @parameterized.parameters(*_FLOAT_CASES)
  def testRaggedSegment_Float(self, segment_op, combiner, segment_ids):
    """Compares segment_op on float data with the Python reference."""
    rt_as_list = [[0., 1., 2., 3.], [4.], [], [5., 6.], [7.], [8., 9.]]
    rt = ragged_factory_ops.constant(rt_as_list)
    num_segments = max(segment_ids) + 1
    expected = self.expected_value(rt_as_list, segment_ids, num_segments,
                                   combiner)

    segmented = segment_op(rt, segment_ids, num_segments)
    self.assertRaggedAlmostEqual(segmented, expected, places=5)

  def testRaggedRankTwo(self):
    """segment_sum over the outer dimension of a doubly-nested ragged input."""
    rt = ragged_factory_ops.constant([
        [[111, 112, 113, 114], [121],],  # row 0
        [],                              # row 1
        [[], [321, 322], [331]],         # row 2
        [[411, 412]]                     # row 3
    ])  # pyformat: disable
    segment_ids1 = [0, 2, 2, 2]
    segmented1 = ragged_math_ops.segment_sum(rt, segment_ids1, 3)
    expected1 = [[[111, 112, 113, 114], [121]],     # row 0
                 [],                                # row 1
                 [[411, 412], [321, 322], [331]]    # row 2
                ]  # pyformat: disable
    self.assertRaggedEqual(segmented1, expected1)

    segment_ids2 = [1, 2, 1, 1]
    segmented2 = ragged_math_ops.segment_sum(rt, segment_ids2, 3)
    expected2 = [[],
                 [[111+411, 112+412, 113, 114], [121+321, 322], [331]],
                 []]  # pyformat: disable
    self.assertRaggedEqual(segmented2, expected2)

  def testRaggedSegmentIds(self):
    """segment_ids may itself be ragged, matching a prefix of data's shape."""
    rt = ragged_factory_ops.constant([
        [[111, 112, 113, 114], [121],],  # row 0
        [],                              # row 1
        [[], [321, 322], [331]],         # row 2
        [[411, 412]]                     # row 3
    ])  # pyformat: disable
    segment_ids = ragged_factory_ops.constant([[1, 2], [], [1, 1, 2], [2]])
    segmented = ragged_math_ops.segment_sum(rt, segment_ids, 3)
    expected = [[],
                [111+321, 112+322, 113, 114],
                [121+331+411, 412]]  # pyformat: disable
    self.assertRaggedEqual(segmented, expected)

  def testShapeMismatchError1(self):
    """A ragged segment_ids with dense data is rejected with ValueError."""
    dt = constant_op.constant([1, 2, 3, 4, 5, 6])
    segment_ids = ragged_factory_ops.constant([[1, 2], []])
    self.assertRaisesRegexp(
        ValueError, 'segment_ids.shape must be a prefix of data.shape, '
        'but segment_ids is ragged and data is not.',
        ragged_math_ops.segment_sum, dt, segment_ids, 3)

  def testShapeMismatchError2(self):
    """Ragged segment_ids whose row lengths differ from data's are rejected."""
    rt = ragged_factory_ops.constant([
        [[111, 112, 113, 114], [121]],  # row 0
        [],                             # row 1
        [[], [321, 322], [331]],        # row 2
        [[411, 412]]                    # row 3
    ])  # pyformat: disable
    # Row 1 of segment_ids has one id, but row 1 of rt is empty.
    segment_ids = ragged_factory_ops.constant([[1, 2], [1], [1, 1, 2], [2]])

    # Error is raised at graph-building time if we can detect it then.
    self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        'segment_ids.shape must be a prefix of data.shape.*',
        ragged_math_ops.segment_sum, rt, segment_ids, 3)

    # Otherwise, error is raised when we run the graph.  Hiding the static
    # shapes behind placeholder_with_default(..., None) defers the check.
    segment_ids2 = ragged_tensor.RaggedTensor.from_row_splits(
        array_ops.placeholder_with_default(segment_ids.values, None),
        array_ops.placeholder_with_default(segment_ids.row_splits, None))
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        'segment_ids.shape must be a prefix of data.shape.*'):
      self.evaluate(ragged_math_ops.segment_sum(rt, segment_ids2, 3))
219
if __name__ == '__main__':
  # Run every test case defined in this module when executed as a script.
  googletest.main()
222