# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compile utilities."""

from tensorflow.python.distribute import one_device_strategy
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import losses as losses_mod
from tensorflow.python.keras import metrics as metrics_mod
from tensorflow.python.keras.engine import compile_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test


class LossesContainerTest(keras_parameterized.TestCase):
  """Tests for `compile_utils.LossesContainer`."""

  def test_single_loss(self):
    """A single string loss yields one loss and one 'loss' metric."""
    loss_container = compile_utils.LossesContainer('mse')
    y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
    total_loss = loss_container(y_t, y_p)

    self.assertTrue(loss_container._built)
    self.assertLen(loss_container._losses, 1)
    self.assertEqual(total_loss.numpy(), 1.)
    self.assertLen(loss_container.metrics, 1)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 1.)

    loss_container.reset_state()
    self.assertEqual(loss_metric.result().numpy(), 0.)

  def test_loss_list(self):
    """A list of losses and weights maps positionally onto a list of outputs."""
    loss_container = compile_utils.LossesContainer(['mse', 'mae'], [1, 0.5])

    y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
    y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)

    # Unnamed list outputs receive generated names.
    self.assertEqual(loss_container._output_names, ['output_1', 'output_2'])

    self.assertLen(loss_container._losses, 2)
    # output_1 contributes 0. output_2 has MAE 1 on the 5 weighted samples
    # out of 10 (0.5), scaled by its loss weight 0.5, giving 0.25.
    self.assertEqual(total_loss.numpy(), 0.25)
    self.assertLen(loss_container.metrics, 3)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.25)

    output_1_metric = loss_container.metrics[1]
    self.assertEqual(output_1_metric.name, 'output_1_loss')
    self.assertEqual(output_1_metric.result().numpy(), 0)

    # Per-output metrics report the loss before loss_weights are applied.
    output_2_metric = loss_container.metrics[2]
    self.assertEqual(output_2_metric.name, 'output_2_loss')
    self.assertEqual(output_2_metric.result().numpy(), 0.5)

    loss_container.reset_state()
    self.assertEqual(loss_metric.result().numpy(), 0)
    self.assertEqual(output_1_metric.result().numpy(), 0)
    self.assertEqual(output_2_metric.result().numpy(), 0)

  def test_loss_dict(self):
    """Dict losses and weights are matched to dict outputs by key."""
    loss_container = compile_utils.LossesContainer(
        {
            'out1': 'mse',
            'out2': 'mae'
        }, {
            'out1': 1,
            'out2': 0.5
        })

    y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
    y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))}
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)

    self.assertLen(loss_container._losses, 2)
    self.assertEqual(total_loss.numpy(), 0.25)
    self.assertLen(loss_container.metrics, 3)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.25)

    out1_metric = loss_container.metrics[1]
    self.assertEqual(out1_metric.name, 'out1_loss')
    self.assertEqual(out1_metric.result().numpy(), 0)

    out2_metric = loss_container.metrics[2]
    self.assertEqual(out2_metric.name, 'out2_loss')
    self.assertEqual(out2_metric.result().numpy(), 0.5)

    loss_container.reset_state()
    self.assertEqual(loss_metric.result().numpy(), 0)
    self.assertEqual(out1_metric.result().numpy(), 0)
    self.assertEqual(out2_metric.result().numpy(), 0)

  def test_loss_partial_dict_with_output_names(self):
    """A dict covering only some outputs is resolved via `output_names`."""
    loss_container = compile_utils.LossesContainer(
        {'out2': 'mae'}, {'out2': 1.}, output_names=['out1', 'out2'])

    y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
    y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)

    self.assertEqual(total_loss.numpy(), 0.5)
    # 'out1' has no loss, so only the total and 'out2' metrics exist.
    self.assertLen(loss_container.metrics, 2)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.5)

    out2_metric = loss_container.metrics[1]
    self.assertEqual(out2_metric.name, 'out2_loss')
    self.assertEqual(out2_metric.result().numpy(), 0.5)

  def test_loss_dict_with_nones(self):
    """Outputs mapped to `None` get no loss and no per-output metric."""
    loss_container = compile_utils.LossesContainer({
        'out1': None,
        'out2': 'mae'
    })

    y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
    y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))}
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)

    self.assertEqual(total_loss.numpy(), 0.5)
    self.assertLen(loss_container.metrics, 2)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.5)

    out2_metric = loss_container.metrics[1]
    self.assertEqual(out2_metric.name, 'out2_loss')
    self.assertEqual(out2_metric.result().numpy(), 0.5)

  def test_nested_structure(self):
    """Losses and weights may be nested (dict of lists) like the outputs."""
    loss_container = compile_utils.LossesContainer(
        {
            'b': ['mse', None],
            'a': 'mae'
        }, loss_weights={
            'b': [0.5, 0],
            'a': 1
        })

    y_t = {
        'b': [array_ops.ones((10, 1)),
              array_ops.zeros((10, 1))],
        'a': array_ops.zeros((10, 1))
    }
    y_p = {
        'b': [array_ops.zeros((10, 1)),
              array_ops.zeros((10, 1))],
        'a': array_ops.ones((10, 1))
    }
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)
    # 'a': 0.5 * weight 1; 'b'[0]: 0.5 * weight 0.5; total 0.75.
    self.assertEqual(total_loss.numpy(), 0.75)
    self.assertLen(loss_container.metrics, 3)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.75)

    # 'a' is reported before 'b' even though 'b' is listed first above.
    a_metric = loss_container.metrics[1]
    self.assertEqual(a_metric.name, 'a_loss')
    self.assertEqual(a_metric.result().numpy(), 0.5)

    b_1_metric = loss_container.metrics[2]
    self.assertEqual(b_1_metric.name, 'b_1_loss')
    self.assertEqual(b_1_metric.result().numpy(), 0.5)

  def test_broadcast_single_loss(self):
    """A single loss is applied to every output of a multi-output model."""
    loss_container = compile_utils.LossesContainer('mse')

    y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
    y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)
    self.assertEqual(total_loss.numpy(), 0.5)
    self.assertLen(loss_container.metrics, 3)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.5)

    output_1_metric = loss_container.metrics[1]
    self.assertEqual(output_1_metric.name, 'output_1_loss')
    self.assertEqual(output_1_metric.result().numpy(), 0.)

    output_2_metric = loss_container.metrics[2]
    self.assertEqual(output_2_metric.name, 'output_2_loss')
    self.assertEqual(output_2_metric.result().numpy(), 0.5)

  def test_missing_label_with_no_loss(self):
    """Predictions without a label are allowed when they have no loss."""
    # It's ok to exclude a label if that label has no
    # losses or metrics associated with it.
    loss_container = compile_utils.LossesContainer({
        'output1': 'mse',
        'output3': 'mae'
    })

    y_p = {
        'output1': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]]),
        'output2': ops.convert_to_tensor_v2_with_dispatch([[3], [4], [5]]),
        'output3': ops.convert_to_tensor_v2_with_dispatch([[6], [7], [8]])
    }
    y_t = {
        'output1': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]]),
        'output3': ops.convert_to_tensor_v2_with_dispatch([[4], [5], [6]])
    }

    total_loss = loss_container(y_t, y_p)
    # output1: every squared error is 1; output3: every absolute error is 2.
    self.assertEqual(total_loss.numpy(), 3.)
    self.assertLen(loss_container.metrics, 3)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 3.)

    output_1_metric = loss_container.metrics[1]
    self.assertEqual(output_1_metric.name, 'output1_loss')
    self.assertEqual(output_1_metric.result().numpy(), 1.)

    output_3_metric = loss_container.metrics[2]
    self.assertEqual(output_3_metric.name, 'output3_loss')
    self.assertEqual(output_3_metric.result().numpy(), 2.)

  def test_mismatched_dtypes(self):
    """Integer labels are passed through unchanged to a float-output loss."""
    y_t = constant_op.constant([1, 9, 2, -5], shape=(2, 2))
    y_p = constant_op.constant([4, 8, 12, 8],
                               shape=(2, 2),
                               dtype=dtypes.float32)

    def my_mae(labels, preds):
      """MAE that asserts the dtypes the container hands to the loss."""
      self.assertEqual(labels.dtype, dtypes.int32)
      self.assertEqual(preds.dtype, dtypes.float32)
      labels = math_ops.cast(labels, preds.dtype)
      return backend.mean(math_ops.abs(preds - labels), axis=-1)

    loss_container = compile_utils.LossesContainer(my_mae)
    total_loss = loss_container(y_t, y_p)
    self.assertEqual(total_loss.dtype, dtypes.float32)

  def test_integer_dtypes(self):
    """int32 labels reach the loss as int64 to match int64 predictions."""
    y_t = constant_op.constant([1, 9, 2, -5], shape=(2, 2))
    y_p = constant_op.constant([4, 8, 12, 8], shape=(2, 2), dtype=dtypes.int64)

    def my_mae(labels, preds):
      """MAE that asserts both inputs arrive as int64."""
      self.assertEqual(labels.dtype, dtypes.int64)
      self.assertEqual(preds.dtype, dtypes.int64)
      return backend.mean(math_ops.abs(preds - labels), axis=-1)

    loss_container = compile_utils.LossesContainer(my_mae)
    total_loss = loss_container(y_t, y_p)
    self.assertEqual(total_loss.dtype, dtypes.int64)

  def test_float_dtypes(self):
    """float32 labels reach the loss as float64 to match float64 predictions."""
    y_t = constant_op.constant([1, 9, 2, -5],
                               shape=(2, 2),
                               dtype=dtypes.float32)
    y_p = constant_op.constant([4, 8, 12, 8],
                               shape=(2, 2),
                               dtype=dtypes.float64)

    def my_mae(labels, preds):
      """MAE that asserts both inputs arrive as float64."""
      self.assertEqual(labels.dtype, dtypes.float64)
      self.assertEqual(preds.dtype, dtypes.float64)
      return backend.mean(math_ops.abs(preds - labels), axis=-1)

    loss_container = compile_utils.LossesContainer(my_mae)
    total_loss = loss_container(y_t, y_p)
    self.assertEqual(total_loss.dtype, dtypes.float64)

  def test_loss_masking(self):
    """A `_keras_mask` on predictions zeroes out masked timesteps."""
    loss_container = compile_utils.LossesContainer('mae')
    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
                                           dtype=dtypes.float32)

    total_loss = loss_container(y_t, y_p)
    # Unmasked errors are 0 and 1; (0 + 1) / 4.
    self.assertAlmostEqual(total_loss.numpy(), .25)  # sum over batch size

    self.assertLen(loss_container.metrics, 1)
    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertAlmostEqual(loss_metric.result().numpy(), .25)

  def test_loss_sample_weight(self):
    """Per-timestep sample weights scale each error before averaging."""
    loss_container = compile_utils.LossesContainer('mae')
    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)

    total_loss = loss_container(y_t, y_p, sample_weight=sw)
    # (0 * .2 + 0 * .3 + 1 * .5 + 1 * 0) / 4
    self.assertAlmostEqual(total_loss.numpy(), .125)

    self.assertLen(loss_container.metrics, 1)
    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertAlmostEqual(loss_metric.result().numpy(), .125)

  def test_loss_masking_sample_weight(self):
    """Masking and sample weights combine multiplicatively."""
    loss_container = compile_utils.LossesContainer('mae')
    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)
    y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
                                           dtype=dtypes.float32)

    total_loss = loss_container(y_t, y_p, sample_weight=sw)
    # (0 * .2 + 1 * .5) / 4
    self.assertAlmostEqual(total_loss.numpy(), .125)  # sum over batch size

    self.assertLen(loss_container.metrics, 1)
    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertAlmostEqual(loss_metric.result().numpy(), .125)

  def test_custom_loss_callables(self):
    """Custom functions and callable objects get snake_case names."""

    def custom_loss_fn(y_true, y_pred):
      return math_ops.reduce_sum(y_true - y_pred)

    class CustomLossClass(object):

      def __call__(self, y_true, y_pred):
        return math_ops.reduce_sum(y_true - y_pred)

    loss_container = compile_utils.LossesContainer(
        [custom_loss_fn, CustomLossClass()])
    y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
    loss_container(y_t, y_p)

    self.assertEqual(loss_container._losses[0].name, 'custom_loss_fn')
    self.assertEqual(loss_container._losses[1].name, 'custom_loss_class')

  def test_ragged_tensor_output(self):
    """Ensure that ragged tensors can be passed as targets and predictions."""

    def custom_loss_fn(y_true, y_pred):
      """MSE supports RaggedTensors directly."""
      return losses_mod.mse(y_true, y_pred)

    class CustomLossClass(losses_mod.Loss):
      """User defined loss function must implement RaggedTensor support."""

      def call(self, y_true, y_pred):
        losses = ragged_functional_ops.map_flat_values(
            math_ops.squared_difference, y_true, y_pred)
        return math_ops.reduce_mean(losses)

    loss_container = compile_utils.LossesContainer(
        [custom_loss_fn, CustomLossClass()])

    v_t = constant_op.constant([[3., 4.], [1., 2.], [3., 5.]])
    v_p = constant_op.constant([[3.1, 4.], [1., 2.], [3., 5.]])

    y_t = array_ops.expand_dims(
        ragged_tensor.RaggedTensor.from_row_splits(v_t, [0, 2, 3]), 0)
    y_p = array_ops.expand_dims(
        ragged_tensor.RaggedTensor.from_row_splits(v_p, [0, 2, 3]), 0)
    loss_container(y_t, y_p)

    self.assertEqual(loss_container._losses[0].name, 'custom_loss_fn')


class MetricsContainerTest(keras_parameterized.TestCase):
  """Tests for `compile_utils.MetricsContainer`."""

  def test_single_metric(self):
    """A single string metric builds one metric that can be reset."""
    metric_container = compile_utils.MetricsContainer('mse')
    y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
    metric_container.update_state(y_t, y_p)

    self.assertLen(metric_container.metrics, 1)
    metric = metric_container.metrics[0]
    self.assertEqual(metric.name, 'mse')
    self.assertEqual(metric.result().numpy(), 1.)

    metric_container.reset_state()
    self.assertEqual(metric.result().numpy(), 0.)

  def test_list_of_metrics_one_output(self):
    """A list of metrics for a single output keeps the bare metric names."""
    metric_container = compile_utils.MetricsContainer(['mse', 'mae'])
    y_t, y_p = 2 * array_ops.ones((10, 5)), array_ops.zeros((10, 5))
    metric_container.update_state(y_t, y_p)
    self.assertLen(metric_container.metrics, 2)

    mse_metric = metric_container.metrics[0]
    self.assertEqual(mse_metric.name, 'mse')
    self.assertEqual(mse_metric.result().numpy(), 4.)

    mae_metric = metric_container.metrics[1]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertEqual(mae_metric.result().numpy(), 2.)

    metric_container.reset_state()
    self.assertEqual(mse_metric.result().numpy(), 0.)
    self.assertEqual(mae_metric.result().numpy(), 0.)

  def test_list_of_metrics_list_of_outputs(self):
    """A metric list is broadcast to every output, grouped per output."""
    metric_container = compile_utils.MetricsContainer(
        metrics=['mse', 'mae'],  # Should broadcast to both outputs.
        weighted_metrics=['accuracy'])  # Should broadcast to both outputs.

    y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
    y_p = [array_ops.ones((10, 1)), 2 * array_ops.ones((10, 1))]
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    metric_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metric_container.metrics, 6)

    out1_mse_metric = metric_container.metrics[0]
    self.assertEqual(out1_mse_metric.name, 'output_1_mse')
    self.assertEqual(out1_mse_metric.result().numpy(), 0.)

    out1_mae_metric = metric_container.metrics[1]
    self.assertEqual(out1_mae_metric.name, 'output_1_mae')
    self.assertEqual(out1_mae_metric.result().numpy(), 0.)

    # 'accuracy' resolves to binary accuracy for single-unit outputs.
    acc_metric_1 = metric_container.metrics[2]
    self.assertEqual(acc_metric_1.name, 'output_1_accuracy')
    self.assertEqual(acc_metric_1.result().numpy(), 1.)
    self.assertEqual(acc_metric_1._fn, metrics_mod.binary_accuracy)

    out2_mse_metric = metric_container.metrics[3]
    self.assertEqual(out2_mse_metric.name, 'output_2_mse')
    self.assertEqual(out2_mse_metric.result().numpy(), 4.)

    out2_mae_metric = metric_container.metrics[4]
    self.assertEqual(out2_mae_metric.name, 'output_2_mae')
    self.assertEqual(out2_mae_metric.result().numpy(), 2.)

    acc_metric_2 = metric_container.metrics[5]
    self.assertEqual(acc_metric_2.name, 'output_2_accuracy')
    self.assertEqual(acc_metric_2.result().numpy(), 0.)
    self.assertEqual(acc_metric_2._fn, metrics_mod.binary_accuracy)

    weighted_metrics = metric_container.weighted_metrics
    self.assertLen(weighted_metrics, 2)
    self.assertEqual(weighted_metrics[0].name, 'output_1_accuracy')
    self.assertEqual(weighted_metrics[1].name, 'output_2_accuracy')

    unweighted_metrics = metric_container.unweighted_metrics
    self.assertLen(unweighted_metrics, 4)
    self.assertEqual(unweighted_metrics[0].name, 'output_1_mse')
    self.assertEqual(unweighted_metrics[1].name, 'output_1_mae')
    self.assertEqual(unweighted_metrics[2].name, 'output_2_mse')
    self.assertEqual(unweighted_metrics[3].name, 'output_2_mae')

  def test_metric_dict(self):
    """A weighted duplicate of a metric gets a 'weighted_' name prefix."""
    metric_container = compile_utils.MetricsContainer(
        metrics={
            'out1': 'mse',
            'out2': 'mae'
        },
        weighted_metrics={
            'out1': 'mse',
            'out2': 'mae'
        })

    y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
    y_p = {'out1': array_ops.ones((10, 1)), 'out2': 2 * array_ops.ones((10, 1))}
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    metric_container.update_state(y_t, y_p, sample_weight=sw)

    mse_metric = metric_container.metrics[0]
    self.assertEqual(mse_metric.name, 'out1_mse')
    self.assertEqual(mse_metric.result().numpy(), 0.)

    weighted_mse_metric = metric_container.metrics[1]
    self.assertEqual(weighted_mse_metric.name, 'out1_weighted_mse')
    self.assertEqual(weighted_mse_metric.result().numpy(), 0.)

    mae_metric = metric_container.metrics[2]
    self.assertEqual(mae_metric.name, 'out2_mae')
    self.assertEqual(mae_metric.result().numpy(), 2.)

    weighted_mae_metric = metric_container.metrics[3]
    self.assertEqual(weighted_mae_metric.name, 'out2_weighted_mae')
    self.assertEqual(weighted_mae_metric.result().numpy(), 2.)

    metric_container.reset_state()
    self.assertEqual(mse_metric.result().numpy(), 0.)
    self.assertEqual(weighted_mse_metric.result().numpy(), 0.)
    self.assertEqual(mae_metric.result().numpy(), 0.)
    self.assertEqual(weighted_mae_metric.result().numpy(), 0.)

  def test_metric_partial_dict_with_output_names(self):
    """A dict covering only some outputs is resolved via `output_names`."""
    metric_container = compile_utils.MetricsContainer(
        {'out2': 'mae'}, output_names=['out1', 'out2'])

    y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
    y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    metric_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metric_container.metrics, 1)

    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'out2_mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

  def test_metric_partial_dict_with_nones(self):
    """Outputs mapped to `None` get no metric."""
    metric_container = compile_utils.MetricsContainer({
        'out1': None,
        'out2': 'mae'
    })

    y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
    y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))}
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    metric_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metric_container.metrics, 1)

    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'out2_mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

  def test_nested_structure(self):
    """Metrics may be nested (dict of lists) like the outputs."""
    metric_container = compile_utils.MetricsContainer(
        metrics={
            'b': ['mse', None],
            'a': 'mae'
        },
        weighted_metrics={
            'b': [None, None],
            'a': 'mse'
        })

    y_t = {
        'b': [2 * array_ops.ones((10, 1)),
              array_ops.zeros((10, 1))],
        'a': array_ops.zeros((10, 1))
    }
    y_p = {
        'b': [array_ops.zeros((10, 1)),
              array_ops.zeros((10, 1))],
        'a': array_ops.ones((10, 1))
    }
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    metric_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metric_container.metrics, 3)

    a_mae_metric = metric_container.metrics[0]
    self.assertEqual(a_mae_metric.name, 'a_mae')
    self.assertEqual(a_mae_metric.result().numpy(), 1.)

    weighted_a_mse_metric = metric_container.metrics[1]
    self.assertEqual(weighted_a_mse_metric.name, 'a_mse')
    self.assertEqual(weighted_a_mse_metric.result().numpy(), 1.)

    b_1_mse_metric = metric_container.metrics[2]
    self.assertEqual(b_1_mse_metric.name, 'b_1_mse')
    self.assertEqual(b_1_mse_metric.result().numpy(), 4.)

  def test_crossentropy(self):
    """'crossentropy' resolves by label/prediction shapes."""
    # Single-unit predictions: binary crossentropy.
    metric_container = compile_utils.MetricsContainer('crossentropy')
    y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 1))
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.binary_crossentropy)

    # Single-column labels with multi-unit predictions: sparse categorical.
    metric_container = compile_utils.MetricsContainer('crossentropy')
    y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 20))
    self.assertEqual(y_p.shape.as_list()[-1], 20)
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.sparse_categorical_crossentropy)

    # Labels and predictions with matching widths: categorical.
    metric_container = compile_utils.MetricsContainer('crossentropy')
    y_t, y_p = array_ops.ones((10, 20)), array_ops.ones((10, 20))
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.categorical_crossentropy)

  def test_accuracy(self):
    """'accuracy' (any case) resolves by label/prediction shapes."""
    metric_container = compile_utils.MetricsContainer('accuracy')
    y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 1))
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.binary_accuracy)

    metric_container = compile_utils.MetricsContainer('Accuracy')
    y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 1))
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.binary_accuracy)

    metric_container = compile_utils.MetricsContainer('accuracy')
    y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 20))
    self.assertEqual(y_p.shape.as_list()[-1], 20)
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.sparse_categorical_accuracy)

    metric_container = compile_utils.MetricsContainer('accuracy')
    y_t, y_p = array_ops.ones((10, 20)), array_ops.ones((10, 20))
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.categorical_accuracy)

  def test_metric_weighting(self):
    """Only `weighted_metrics` honour sample_weight."""
    metric_container = compile_utils.MetricsContainer(
        metrics=['mae'], weighted_metrics=['mae'])

    y_t = ops.convert_to_tensor_v2_with_dispatch([[0], [3], [0]])
    y_p = ops.convert_to_tensor_v2_with_dispatch([[0], [0], [0]])
    sw = ops.convert_to_tensor_v2_with_dispatch([[1], [0], [1]])

    metric_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metric_container.metrics, 2)

    # Unweighted: (0 + 3 + 0) / 3.
    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

    # Weighted: the only nonzero error has weight 0.
    weighted_mae_metric = metric_container.metrics[1]
    self.assertEqual(weighted_mae_metric.name, 'weighted_mae')
    self.assertEqual(weighted_mae_metric.result().numpy(), 0.)

  def test_broadcast_metrics_to_dict(self):
    """Metrics for a single-key dict output keep the bare metric name."""
    metric_container = compile_utils.MetricsContainer(metrics=['mae'])

    y_p = {'output': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]])}
    y_t = {'output': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]])}
    metric_container.update_state(y_t, y_p)

    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

  def test_broadcast_metrics_to_dict_with_output_names(self):
    """Dict labels are matched to a bare prediction via `output_names`."""
    metric_container = compile_utils.MetricsContainer(
        metrics=['mae'], output_names=['output'])

    y_p = ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]])
    y_t = {'output': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]])}
    metric_container.update_state(y_t, y_p)

    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

  def test_missing_label_with_no_metrics(self):
    """Predictions without a label are allowed when they have no metric."""
    # It's ok to exclude a label if that label has no
    # losses or metrics associated with it.
    metric_container = compile_utils.MetricsContainer(metrics={
        'output1': 'mae',
        'output3': 'mse'
    })

    y_p = {
        'output1': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]]),
        'output2': ops.convert_to_tensor_v2_with_dispatch([[3], [4], [5]]),
        'output3': ops.convert_to_tensor_v2_with_dispatch([[6], [7], [8]])
    }
    y_t = {
        'output1': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]]),
        'output3': ops.convert_to_tensor_v2_with_dispatch([[4], [5], [6]])
    }

    metric_container.update_state(y_t, y_p)
    self.assertLen(metric_container.metrics, 2)

    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'output1_mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

    mse_metric = metric_container.metrics[1]
    self.assertEqual(mse_metric.name, 'output3_mse')
    self.assertEqual(mse_metric.result().numpy(), 4.)

  def test_metrics_masking(self):
    """A `_keras_mask` applies to both unweighted and weighted metrics."""
    metrics_container = compile_utils.MetricsContainer(
        metrics=['mae'], weighted_metrics=['mse'])
    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    # Mask out the second sample, which holds the only nonzero errors.
    y_p._keras_mask = constant_op.constant([[1, 1], [0, 0]],
                                           dtype=dtypes.float32)

    metrics_container.update_state(y_t, y_p)
    self.assertLen(metrics_container.metrics, 2)

    mae_metric = metrics_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertAlmostEqual(mae_metric.result().numpy(), 0)

    weighted_mse_metric = metrics_container.metrics[1]
    self.assertEqual(weighted_mse_metric.name, 'mse')
    self.assertAlmostEqual(weighted_mse_metric.result().numpy(), 0)

  def test_metrics_sample_weight(self):
    """Sample weights affect `weighted_metrics` only."""
    metrics_container = compile_utils.MetricsContainer(
        metrics=['mae'], weighted_metrics=['mse'])
    y_p = constant_op.constant([[[1], [1]], [[0], [1]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)

    metrics_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metrics_container.metrics, 2)

    mae_metric = metrics_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertAlmostEqual(mae_metric.result().numpy(), .25)  # 1 / 4

    weighted_mse_metric = metrics_container.metrics[1]
    self.assertEqual(weighted_mse_metric.name, 'mse')
    self.assertAlmostEqual(weighted_mse_metric.result().numpy(), .5)  # .5 / 1

  def test_metrics_masking_sample_weight(self):
    """Masking and sample weights combine for `weighted_metrics`."""
    metrics_container = compile_utils.MetricsContainer(
        metrics=['mae'], weighted_metrics=['mse'])
    y_p = constant_op.constant([[[1], [1]], [[0], [1]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    sw = constant_op.constant([[.3, .2], [.2, .3]], dtype=dtypes.float32)
    y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
                                           dtype=dtypes.float32)

    metrics_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metrics_container.metrics, 2)

    # Unmasked errors are 0 and 1, averaged over the 2 unmasked entries.
    mae_metric = metrics_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertAlmostEqual(mae_metric.result().numpy(), .5)  # 1 / 2

    # Effective weights are .3 and .2; weighted sum .2 over total weight .5.
    weighted_mse_metric = metrics_container.metrics[1]
    self.assertEqual(weighted_mse_metric.name, 'mse')
    self.assertAlmostEqual(weighted_mse_metric.result().numpy(), .2 / .5)

  def test_loss_class_as_metric_with_distribution(self):
    """A Loss instance can be used as a metric under a distribution scope."""
    distribution = one_device_strategy.OneDeviceStrategy('/device:CPU:0')
    with distribution.scope():
      metric_container = compile_utils.MetricsContainer(
          losses_mod.MeanSquaredError())
      y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
      metric_container.update_state(y_t, y_p)

      self.assertLen(metric_container.metrics, 1)
      metric = metric_container.metrics[0]
      self.assertEqual(metric.name, 'mean_squared_error')
      self.assertEqual(metric.result().numpy(), 1.)

  def test_custom_metric_callables(self):
    """Custom functions and callable objects get snake_case names."""

    def custom_metric_fn(y_true, y_pred):
      return math_ops.reduce_sum(y_true - y_pred)

    class CustomMetricClass(object):

      def __call__(self, y_true, y_pred):
        return math_ops.reduce_sum(y_true - y_pred)

    metric_container = compile_utils.MetricsContainer(
        [custom_metric_fn, CustomMetricClass()])
    y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
    metric_container.update_state(y_t, y_p)

    self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn')
    self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class')

  def test_reset_state_existing_metric_before_built(self):
    """reset_state() works on a pre-populated metric before the first update."""
    metric = metrics_mod.Mean()
    metric.update_state([2.0, 4.0])
    self.assertEqual(metric.result().numpy(), 3.0)

    metric_container = compile_utils.MetricsContainer(metric)
    metric_container.reset_state()
    self.assertEqual(metric.result().numpy(), 0.0)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()