# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metrics_utils.

Covers `ragged_assert_compatible_and_get_flat_values` (compatibility checks on
dense/ragged inputs and masks) and `_filter_top_k`.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.ops import script_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RaggedSizeOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `metrics_utils.ragged_assert_compatible_and_get_flat_values`."""

  @parameterized.parameters([
      {'x_list': [1], 'y_list': [2]},
      {'x_list': [1, 2], 'y_list': [2, 3]},
      {'x_list': [1, 2, 4], 'y_list': [2, 3, 5]},
      {'x_list': [[1, 2], [3, 4]], 'y_list': [[2, 3], [5, 6]]},
  ])
  def test_passing_dense_tensors(self, x_list, y_list):
    """Two same-shaped dense inputs pass and come back shape-compatible."""
    inputs = [constant_op.constant(x_list), constant_op.constant(y_list)]
    flat, _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        inputs)
    first, second = flat
    first.shape.assert_is_compatible_with(second.shape)

  @parameterized.parameters([
      {'x_list': [1]},
      {'x_list': [1, 2]},
      {'x_list': [1, 2, 4]},
      {'x_list': [[1, 2], [3, 4]]},
  ])
  def test_passing_one_dense_tensor(self, x_list):
    """A single dense input passes without raising."""
    dense = constant_op.constant(x_list)
    (_,), _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [dense])

  @parameterized.parameters([
      {'x_list': [1], 'y_list': [2]},
      {'x_list': [1, 2], 'y_list': [2, 3]},
      {'x_list': [1, 2, 4], 'y_list': [2, 3, 5]},
      {'x_list': [[1, 2], [3, 4]], 'y_list': [[2, 3], [5, 6]]},
      {'x_list': [[1, 2], [3, 4], [1]], 'y_list': [[2, 3], [5, 6], [3]]},
      {'x_list': [[1, 2], [], [1]], 'y_list': [[2, 3], [], [3]]},
  ])
  def test_passing_both_ragged(self, x_list, y_list):
    """Two inputs with matching row structure yield compatible flat values."""
    inputs = [
        ragged_factory_ops.constant(x_list),
        ragged_factory_ops.constant(y_list),
    ]
    flat, _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        inputs)
    first, second = flat
    first.shape.assert_is_compatible_with(second.shape)

  @parameterized.parameters([
      {'x_list': [1]},
      {'x_list': [1, 2]},
      {'x_list': [1, 2, 4]},
      {'x_list': [[1, 2], [3, 4]]},
      {'x_list': [[1, 2], [3, 4], [1]]},
      {'x_list': [[1, 2], [], [1]]},
  ])
  def test_passing_one_ragged(self, x_list):
    """A single input built by `ragged_factory_ops.constant` passes."""
    ragged_x = ragged_factory_ops.constant(x_list)
    (_,), _ = metrics_utils.ragged_assert_compatible_and_get_flat_values(
        [ragged_x])

  @parameterized.parameters([
      {'x_list': [1], 'y_list': [2], 'mask_list': [0]},
      {'x_list': [1, 2], 'y_list': [2, 3], 'mask_list': [0, 1]},
      {'x_list': [1, 2, 4], 'y_list': [2, 3, 5], 'mask_list': [1, 1, 1]},
      {
          'x_list': [[1, 2], [3, 4]],
          'y_list': [[2, 3], [5, 6]],
          'mask_list': [[1, 1], [0, 1]]
      },
      {
          'x_list': [[1, 2], [3, 4], [1]],
          'y_list': [[2, 3], [5, 6], [3]],
          'mask_list': [[1, 1], [0, 0], [1]]
      },
      {
          'x_list': [[1, 2], [], [1]],
          'y_list': [[2, 3], [], [3]],
          'mask_list': [[1, 1], [], [0]]
      },
  ])
  def test_passing_both_ragged_with_mask(self, x_list, y_list, mask_list):
    """Values and mask with matching structure stay mutually compatible."""
    inputs = [
        ragged_factory_ops.constant(x_list),
        ragged_factory_ops.constant(y_list),
    ]
    ragged_mask = ragged_factory_ops.constant(mask_list)
    flat, flat_mask = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            inputs, ragged_mask))
    first, second = flat
    first.shape.assert_is_compatible_with(second.shape)
    second.shape.assert_is_compatible_with(flat_mask.shape)

  @parameterized.parameters([
      {'x_list': [1], 'mask_list': [0]},
      {'x_list': [1, 2], 'mask_list': [0, 1]},
      {'x_list': [1, 2, 4], 'mask_list': [1, 1, 1]},
      {'x_list': [[1, 2], [3, 4]], 'mask_list': [[1, 1], [0, 1]]},
      {'x_list': [[1, 2], [3, 4], [1]], 'mask_list': [[1, 1], [0, 0], [1]]},
      {'x_list': [[1, 2], [], [1]], 'mask_list': [[1, 1], [], [0]]},
  ])
  def test_passing_one_ragged_with_mask(self, x_list, mask_list):
    """A single input and its mask come back shape-compatible."""
    ragged_x = ragged_factory_ops.constant(x_list)
    ragged_mask = ragged_factory_ops.constant(mask_list)
    (only,), flat_mask = (
        metrics_utils.ragged_assert_compatible_and_get_flat_values(
            [ragged_x], ragged_mask))
    only.shape.assert_is_compatible_with(flat_mask.shape)

  @parameterized.parameters([
      {'x_list': [[[1, 3]]], 'y_list': [[2, 3]]},
  ])
  def test_failing_different_ragged_and_dense_ranks(self, x_list, y_list):
    """Inputs whose nesting depths differ are rejected with ValueError."""
    inputs = [
        ragged_factory_ops.constant(x_list),
        ragged_factory_ops.constant(y_list),
    ]
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
      (_, _), _ = (
          metrics_utils.ragged_assert_compatible_and_get_flat_values(inputs))

  @parameterized.parameters([
      {'x_list': [[[1, 3]]], 'y_list': [[[2, 3]]], 'mask_list': [[0, 1]]},
  ])
  def test_failing_different_mask_ranks(self, x_list, y_list, mask_list):
    """A mask shallower than the values is rejected with ValueError."""
    inputs = [
        ragged_factory_ops.constant(x_list),
        ragged_factory_ops.constant(y_list),
    ]
    ragged_mask = ragged_factory_ops.constant(mask_list)
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
      (_, _), _ = (
          metrics_utils.ragged_assert_compatible_and_get_flat_values(
              inputs, ragged_mask))

  # Inputs whose ragged ranks differ are rejected even when their overall
  # shapes and sizes are identical: supporting that case would add too much
  # performance overhead to the common use cases.
  def test_failing_different_ragged_ranks(self):
    """Same overall shape but different ragged ranks raises ValueError."""
    base = constant_op.constant([[[1, 2]]])
    # Wrap `base` in one additional ragged dimension.
    ragged_x = ragged_tensor.RaggedTensor.from_row_splits(
        base, row_splits=[0, 1])
    ragged_y = ragged_factory_ops.constant([[[[1, 2]]]])
    with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
      (_, _), _ = (
          metrics_utils.ragged_assert_compatible_and_get_flat_values(
              [ragged_x, ragged_y]))


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class FilterTopKTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `metrics_utils._filter_top_k`."""

  def test_one_dimensional(self):
    """Entries outside the top k are replaced with NEG_INF, for k = 1..3."""
    neg_inf = metrics_utils.NEG_INF
    x = constant_op.constant([.3, .1, .2, -.5, 42.])
    # Evaluate every k first, then compare, matching the original ordering.
    results = [
        self.evaluate(metrics_utils._filter_top_k(x=x, k=k)) for k in (1, 2, 3)
    ]
    expectations = [
        [neg_inf, neg_inf, neg_inf, neg_inf, 42.],
        [.3, neg_inf, neg_inf, neg_inf, 42.],
        [.3, neg_inf, .2, neg_inf, 42.],
    ]
    for actual, expected in zip(results, expectations):
      self.assertAllClose(actual, expected)

  def test_three_dimensional(self):
    """Top-k filtering is applied independently to each innermost row."""
    neg_inf = metrics_utils.NEG_INF
    x = constant_op.constant([[[.3, .1, .2], [-.3, -.2, -.1]],
                              [[5., .2, 42.], [-.3, -.6, -.99]]])
    filtered = self.evaluate(metrics_utils._filter_top_k(x=x, k=2))
    self.assertAllClose(
        filtered,
        [[[.3, neg_inf, .2], [neg_inf, -.2, -.1]],
         [[5., neg_inf, 42.], [-.3, -.6, neg_inf]]])

  def test_handles_dynamic_shapes(self):
    """`_filter_top_k` still works when the input has no static shape."""
    # See b/150281686. # GOOGLE_INTERNAL

    def _passthrough(value):
      return value

    def _top_2_without_static_shape(value):
      # Routing the tensor through numpy_function discards its static shape.
      unshaped = script_ops.numpy_function(_passthrough, (value,),
                                           dtypes.float32)
      return metrics_utils._filter_top_k(x=unshaped, k=2)

    x = constant_op.constant([.3, .1, .2, -.5, 42.])
    filtered = self.evaluate(_top_2_without_static_shape(x))
    neg_inf = metrics_utils.NEG_INF
    self.assertAllClose(filtered, [.3, neg_inf, neg_inf, neg_inf, 42.])


if __name__ == '__main__':
  test.main()