• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for metrics_utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import test_util
26from tensorflow.python.keras import combinations
27from tensorflow.python.keras.utils import metrics_utils
28from tensorflow.python.ops import script_ops
29from tensorflow.python.ops.ragged import ragged_factory_ops
30from tensorflow.python.ops.ragged import ragged_tensor
31from tensorflow.python.platform import test
32
33
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RaggedSizeOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `metrics_utils.ragged_assert_compatible_and_get_flat_values`.

  The class decorator generates each test for the 'graph' and 'eager' modes.
  Passing cases check that the helper accepts its inputs and that the values
  (and mask) it returns have mutually compatible static shapes. Failing cases
  check that the helper call itself raises `ValueError`.
  """

  @parameterized.parameters([
      {
          'x_list': [1],
          'y_list': [2]
      },
      {
          'x_list': [1, 2],
          'y_list': [2, 3]
      },
      {
          'x_list': [1, 2, 4],
          'y_list': [2, 3, 5]
      },
      {
          'x_list': [[1, 2], [3, 4]],
          'y_list': [[2, 3], [5, 6]]
      },
  ])
  def test_passing_dense_tensors(self, x_list, y_list):
    """Two dense tensors of equal shape are accepted and stay compatible."""
    x = constant_op.constant(x_list)
    y = constant_op.constant(y_list)
    [x, y], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [x, y])
    x.shape.assert_is_compatible_with(y.shape)

  @parameterized.parameters([
      {
          'x_list': [1],
      },
      {
          'x_list': [1, 2],
      },
      {
          'x_list': [1, 2, 4],
      },
      {
          'x_list': [[1, 2], [3, 4]],
      },
  ])
  def test_passing_one_dense_tensor(self, x_list):
    """Smoke test: a single dense tensor is accepted without raising."""
    x = constant_op.constant(x_list)
    # The `[x]` target also checks that a one-element sequence comes back.
    [x], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values([x])

  @parameterized.parameters([
      {
          'x_list': [1],
          'y_list': [2]
      },
      {
          'x_list': [1, 2],
          'y_list': [2, 3]
      },
      {
          'x_list': [1, 2, 4],
          'y_list': [2, 3, 5]
      },
      {
          'x_list': [[1, 2], [3, 4]],
          'y_list': [[2, 3], [5, 6]]
      },
      {
          'x_list': [[1, 2], [3, 4], [1]],
          'y_list': [[2, 3], [5, 6], [3]]
      },
      {
          'x_list': [[1, 2], [], [1]],
          'y_list': [[2, 3], [], [3]]
      },
  ])
  def test_passing_both_ragged(self, x_list, y_list):
    """Two inputs from `ragged_factory_ops.constant` with matching rows pass."""
    x = ragged_factory_ops.constant(x_list)
    y = ragged_factory_ops.constant(y_list)
    [x, y], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [x, y])
    x.shape.assert_is_compatible_with(y.shape)

  @parameterized.parameters([
      {
          'x_list': [1],
      },
      {
          'x_list': [1, 2],
      },
      {
          'x_list': [1, 2, 4],
      },
      {
          'x_list': [[1, 2], [3, 4]],
      },
      {
          'x_list': [[1, 2], [3, 4], [1]],
      },
      {
          'x_list': [[1, 2], [], [1]],
      },
  ])
  def test_passing_one_ragged(self, x_list):
    """Smoke test: one `ragged_factory_ops.constant` input is accepted."""
    x = ragged_factory_ops.constant(x_list)
    # The `[x]` target also checks that a one-element sequence comes back.
    [x], _ = metrics_utils.ragged_assert_compatible_and_get_flat_values([x])

  @parameterized.parameters([
      {
          'x_list': [1],
          'y_list': [2],
          'mask_list': [0]
      },
      {
          'x_list': [1, 2],
          'y_list': [2, 3],
          'mask_list': [0, 1]
      },
      {
          'x_list': [1, 2, 4],
          'y_list': [2, 3, 5],
          'mask_list': [1, 1, 1]
      },
      {
          'x_list': [[1, 2], [3, 4]],
          'y_list': [[2, 3], [5, 6]],
          'mask_list': [[1, 1], [0, 1]]
      },
      {
          'x_list': [[1, 2], [3, 4], [1]],
          'y_list': [[2, 3], [5, 6], [3]],
          'mask_list': [[1, 1], [0, 0], [1]]
      },
      {
          'x_list': [[1, 2], [], [1]],
          'y_list': [[2, 3], [], [3]],
          'mask_list': [[1, 1], [], [0]]
      },
  ])
  def test_passing_both_ragged_with_mask(self, x_list, y_list, mask_list):
    """Two inputs plus a mask with the same row structure are accepted."""
    x = ragged_factory_ops.constant(x_list)
    y = ragged_factory_ops.constant(y_list)
    mask = ragged_factory_ops.constant(mask_list)
    [x, y], mask = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [x, y], mask)
    x.shape.assert_is_compatible_with(y.shape)
    y.shape.assert_is_compatible_with(mask.shape)

  @parameterized.parameters([
      {
          'x_list': [1],
          'mask_list': [0]
      },
      {
          'x_list': [1, 2],
          'mask_list': [0, 1]
      },
      {
          'x_list': [1, 2, 4],
          'mask_list': [1, 1, 1]
      },
      {
          'x_list': [[1, 2], [3, 4]],
          'mask_list': [[1, 1], [0, 1]]
      },
      {
          'x_list': [[1, 2], [3, 4], [1]],
          'mask_list': [[1, 1], [0, 0], [1]]
      },
      {
          'x_list': [[1, 2], [], [1]],
          'mask_list': [[1, 1], [], [0]]
      },
  ])
  def test_passing_one_ragged_with_mask(self, x_list, mask_list):
    """One input plus a mask with the same row structure is accepted."""
    x = ragged_factory_ops.constant(x_list)
    mask = ragged_factory_ops.constant(mask_list)
    [x], mask = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [x], mask)
    x.shape.assert_is_compatible_with(mask.shape)

  @parameterized.parameters([
      {
          'x_list': [[[1, 3]]],
          'y_list': [[2, 3]]
      },
  ])
  def test_failing_different_ragged_and_dense_ranks(self, x_list, y_list):
    """Inputs with different nesting depths are rejected with ValueError.

    NOTE(review): despite the test name, both inputs are built with
    `ragged_factory_ops.constant` (x from a 3-level nested list, y from a
    2-level one), so this appears to exercise a rank mismatch between two
    ragged-factory inputs rather than a ragged-vs-dense pair.
    """
    x = ragged_factory_ops.constant(x_list)
    y = ragged_factory_ops.constant(y_list)
    # Call the helper without unpacking its result. Otherwise a ValueError
    # raised by the unpacking itself would also satisfy assertRaises.
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
      metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y])

  @parameterized.parameters([
      {
          'x_list': [[[1, 3]]],
          'y_list': [[[2, 3]]],
          'mask_list': [[0, 1]]
      },
  ])
  def test_failing_different_mask_ranks(self, x_list, y_list, mask_list):
    """A mask nested less deeply than the values is rejected with ValueError."""
    x = ragged_factory_ops.constant(x_list)
    y = ragged_factory_ops.constant(y_list)
    mask = ragged_factory_ops.constant(mask_list)
    # Call the helper without unpacking its result. Otherwise a ValueError
    # raised by the unpacking itself would also satisfy assertRaises.
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
      metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y], mask)

  def test_failing_different_ragged_ranks(self):
    """Equal overall shapes with different ragged_ranks are rejected.

    Such inputs are deliberately unsupported. Detecting that their dimension
    shapes and sizes match anyway would add too much performance overhead to
    the common use cases.
    """
    dt = constant_op.constant([[[1, 2]]])
    # Add one ragged dimension on top of the dense tensor.
    x = ragged_tensor.RaggedTensor.from_row_splits(dt, row_splits=[0, 1])
    y = ragged_factory_ops.constant([[[[1, 2]]]])
    # Call the helper without unpacking its result. Otherwise a ValueError
    # raised by the unpacking itself would also satisfy assertRaises.
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
      metrics_utils.ragged_assert_compatible_and_get_flat_values([x, y])
251
252
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class FilterTopKTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `metrics_utils._filter_top_k`.

  The expected outputs keep the `k` largest entries along the last axis and
  replace every other entry with `metrics_utils.NEG_INF`.
  """

  def test_one_dimensional(self):
    """Checks k = 1, 2 and 3 on a 1-D input."""
    x = constant_op.constant([.3, .1, .2, -.5, 42.])
    top_1 = self.evaluate(metrics_utils._filter_top_k(x=x, k=1))
    top_2 = self.evaluate(metrics_utils._filter_top_k(x=x, k=2))
    top_3 = self.evaluate(metrics_utils._filter_top_k(x=x, k=3))

    neg_inf = metrics_utils.NEG_INF
    self.assertAllClose(top_1, [neg_inf, neg_inf, neg_inf, neg_inf, 42.])
    self.assertAllClose(top_2, [.3, neg_inf, neg_inf, neg_inf, 42.])
    self.assertAllClose(top_3, [.3, neg_inf, .2, neg_inf, 42.])

  def test_three_dimensional(self):
    """Checks that k = 2 is applied independently to each innermost row."""
    x = constant_op.constant([[[.3, .1, .2], [-.3, -.2, -.1]],
                              [[5., .2, 42.], [-.3, -.6, -.99]]])
    top_2 = self.evaluate(metrics_utils._filter_top_k(x=x, k=2))

    neg_inf = metrics_utils.NEG_INF
    self.assertAllClose(
        top_2,
        [[[.3, neg_inf, .2], [neg_inf, -.2, -.1]],
         [[5., neg_inf, 42.], [-.3, -.6, neg_inf]]])

  def test_handles_dynamic_shapes(self):
    """Checks _filter_top_k on an input whose static shape is unknown."""
    # See b/150281686.  # GOOGLE_INTERNAL

    def _identity(x):
      return x

    def _filter_top_k(x):
      # Round-tripping through numpy_function loses the static shape.
      x = script_ops.numpy_function(_identity, (x,), dtypes.float32)

      return metrics_utils._filter_top_k(x=x, k=2)

    x = constant_op.constant([.3, .1, .2, -.5, 42.])
    top_2 = self.evaluate(_filter_top_k(x))

    neg_inf = metrics_utils.NEG_INF
    self.assertAllClose(top_2, [.3, neg_inf, neg_inf, neg_inf, 42.])
301
302
# Run the test suite when the module is executed directly.
if __name__ == '__main__':
  test.main()
305