• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for compile utilities."""
16
17from tensorflow.python.distribute import one_device_strategy
18from tensorflow.python.framework import constant_op
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.keras import backend
22from tensorflow.python.keras import keras_parameterized
23from tensorflow.python.keras import losses as losses_mod
24from tensorflow.python.keras import metrics as metrics_mod
25from tensorflow.python.keras.engine import compile_utils
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops.ragged import ragged_functional_ops
29from tensorflow.python.ops.ragged import ragged_tensor
30from tensorflow.python.platform import test
31
32
class LossesContainerTest(keras_parameterized.TestCase):
  """Tests for `compile_utils.LossesContainer`.

  Many tests use the sample weight `[0] * 5 + [1] * 5` over a batch of 10, so a
  per-sample loss of 1 on every row reduces to 5 / 10 = 0.5.
  """

  def test_single_loss(self):
    """A single string loss builds one loss and one aggregate 'loss' metric."""
    loss_container = compile_utils.LossesContainer('mse')
    y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
    total_loss = loss_container(y_t, y_p)

    self.assertTrue(loss_container._built)
    self.assertLen(loss_container._losses, 1)
    self.assertEqual(total_loss.numpy(), 1.)
    self.assertLen(loss_container.metrics, 1)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 1.)

    loss_container.reset_state()
    self.assertEqual(loss_metric.result().numpy(), 0.)

  def test_loss_list(self):
    """A list of losses and weights is matched positionally to list outputs."""
    loss_container = compile_utils.LossesContainer(['mse', 'mae'], [1, 0.5])

    y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
    y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)

    # Unnamed list outputs receive generated names.
    self.assertEqual(loss_container._output_names, ['output_1', 'output_2'])

    self.assertLen(loss_container._losses, 2)
    # 1 * mse(output_1) + 0.5 * mae(output_2) = 1 * 0 + 0.5 * 0.5
    self.assertEqual(total_loss.numpy(), 0.25)
    # One aggregate 'loss' metric plus one metric per output, as in
    # test_loss_dict.
    self.assertLen(loss_container.metrics, 3)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.25)

    output_1_metric = loss_container.metrics[1]
    self.assertEqual(output_1_metric.name, 'output_1_loss')
    self.assertEqual(output_1_metric.result().numpy(), 0)

    output_2_metric = loss_container.metrics[2]
    self.assertEqual(output_2_metric.name, 'output_2_loss')
    self.assertEqual(output_2_metric.result().numpy(), 0.5)

    loss_container.reset_state()
    self.assertEqual(loss_metric.result().numpy(), 0)
    self.assertEqual(output_1_metric.result().numpy(), 0)
    self.assertEqual(output_2_metric.result().numpy(), 0)

  def test_loss_dict(self):
    """A dict of losses and weights is matched to dict outputs by key."""
    loss_container = compile_utils.LossesContainer(
        {
            'out1': 'mse',
            'out2': 'mae'
        }, {
            'out1': 1,
            'out2': 0.5
        })

    y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
    y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))}
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)

    self.assertLen(loss_container._losses, 2)
    # 1 * mse(out1) + 0.5 * mae(out2) = 1 * 0 + 0.5 * 0.5
    self.assertEqual(total_loss.numpy(), 0.25)
    self.assertLen(loss_container.metrics, 3)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.25)

    out1_metric = loss_container.metrics[1]
    self.assertEqual(out1_metric.name, 'out1_loss')
    self.assertEqual(out1_metric.result().numpy(), 0)

    out2_metric = loss_container.metrics[2]
    self.assertEqual(out2_metric.name, 'out2_loss')
    self.assertEqual(out2_metric.result().numpy(), 0.5)

    loss_container.reset_state()
    self.assertEqual(loss_metric.result().numpy(), 0)
    self.assertEqual(out1_metric.result().numpy(), 0)
    self.assertEqual(out2_metric.result().numpy(), 0)

  def test_loss_partial_dict_with_output_names(self):
    """Outputs absent from a partial loss dict get neither loss nor metric."""
    loss_container = compile_utils.LossesContainer(
        {'out2': 'mae'}, {'out2': 1.}, output_names=['out1', 'out2'])

    y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
    y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)

    self.assertEqual(total_loss.numpy(), 0.5)
    self.assertLen(loss_container.metrics, 2)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.5)

    out2_metric = loss_container.metrics[1]
    self.assertEqual(out2_metric.name, 'out2_loss')
    self.assertEqual(out2_metric.result().numpy(), 0.5)

  def test_loss_dict_with_nones(self):
    """A `None` entry in a loss dict contributes no loss and no metric."""
    loss_container = compile_utils.LossesContainer({
        'out1': None,
        'out2': 'mae'
    })

    y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
    y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))}
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)

    self.assertEqual(total_loss.numpy(), 0.5)
    self.assertLen(loss_container.metrics, 2)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.5)

    out2_metric = loss_container.metrics[1]
    self.assertEqual(out2_metric.name, 'out2_loss')
    self.assertEqual(out2_metric.result().numpy(), 0.5)

  def test_nested_structure(self):
    """Nested dict/list loss structures are flattened to per-output losses."""
    loss_container = compile_utils.LossesContainer(
        {
            'b': ['mse', None],
            'a': 'mae'
        }, loss_weights={
            'b': [0.5, 0],
            'a': 1
        })

    y_t = {
        'b': [array_ops.ones((10, 1)),
              array_ops.zeros((10, 1))],
        'a': array_ops.zeros((10, 1))
    }
    y_p = {
        'b': [array_ops.zeros((10, 1)),
              array_ops.zeros((10, 1))],
        'a': array_ops.ones((10, 1))
    }
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)
    # 1 * mae(a) + 0.5 * mse(b_1) = 1 * 0.5 + 0.5 * 0.5
    self.assertEqual(total_loss.numpy(), 0.75)
    # The `None` loss for the second element of 'b' produces no metric.
    self.assertLen(loss_container.metrics, 3)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.75)

    # 'a' comes before 'b' in the metric order even though 'b' is listed
    # first in the dict.
    a_metric = loss_container.metrics[1]
    self.assertEqual(a_metric.name, 'a_loss')
    self.assertEqual(a_metric.result().numpy(), 0.5)

    b_1_metric = loss_container.metrics[2]
    self.assertEqual(b_1_metric.name, 'b_1_loss')
    self.assertEqual(b_1_metric.result().numpy(), 0.5)

  def test_broadcast_single_loss(self):
    """A single loss is applied to every output of a multi-output model."""
    loss_container = compile_utils.LossesContainer('mse')

    y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
    y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    total_loss = loss_container(y_t, y_p, sample_weight=sw)
    # mse(output_1) + mse(output_2) = 0 + 0.5
    self.assertEqual(total_loss.numpy(), 0.5)
    self.assertLen(loss_container.metrics, 3)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 0.5)

    output_1_metric = loss_container.metrics[1]
    self.assertEqual(output_1_metric.name, 'output_1_loss')
    self.assertEqual(output_1_metric.result().numpy(), 0.)

    output_2_metric = loss_container.metrics[2]
    self.assertEqual(output_2_metric.name, 'output_2_loss')
    self.assertEqual(output_2_metric.result().numpy(), 0.5)

  def test_missing_label_with_no_loss(self):
    """A label may be omitted when its output has no loss configured."""
    # It's ok to exclude a label if that label has no
    # losses or metrics associated with it.
    loss_container = compile_utils.LossesContainer({
        'output1': 'mse',
        'output3': 'mae'
    })

    y_p = {
        'output1': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]]),
        'output2': ops.convert_to_tensor_v2_with_dispatch([[3], [4], [5]]),
        'output3': ops.convert_to_tensor_v2_with_dispatch([[6], [7], [8]])
    }
    y_t = {
        'output1': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]]),
        'output3': ops.convert_to_tensor_v2_with_dispatch([[4], [5], [6]])
    }

    total_loss = loss_container(y_t, y_p)
    # mse(output1) = 1 (each error is 1), mae(output3) = 2 (each error is 2).
    self.assertEqual(total_loss.numpy(), 3.)
    self.assertLen(loss_container.metrics, 3)

    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertEqual(loss_metric.result().numpy(), 3.)

    output_1_metric = loss_container.metrics[1]
    self.assertEqual(output_1_metric.name, 'output1_loss')
    self.assertEqual(output_1_metric.result().numpy(), 1.)

    output_3_metric = loss_container.metrics[2]
    self.assertEqual(output_3_metric.name, 'output3_loss')
    self.assertEqual(output_3_metric.result().numpy(), 2.)

  def test_mismatched_dtypes(self):
    """Integer labels are passed through unchanged alongside float preds."""
    y_t = constant_op.constant([1, 9, 2, -5], shape=(2, 2))
    y_p = constant_op.constant([4, 8, 12, 8],
                               shape=(2, 2),
                               dtype=dtypes.float32)

    def my_mae(labels, preds):
      self.assertEqual(labels.dtype, dtypes.int32)
      self.assertEqual(preds.dtype, dtypes.float32)
      labels = math_ops.cast(labels, preds.dtype)
      return backend.mean(math_ops.abs(preds - labels), axis=-1)

    loss_container = compile_utils.LossesContainer(my_mae)
    total_loss = loss_container(y_t, y_p)
    self.assertEqual(total_loss.dtype, dtypes.float32)

  def test_integer_dtypes(self):
    """int32 labels reach the loss as int64 when predictions are int64."""
    y_t = constant_op.constant([1, 9, 2, -5], shape=(2, 2))
    y_p = constant_op.constant([4, 8, 12, 8], shape=(2, 2), dtype=dtypes.int64)

    def my_mae(labels, preds):
      self.assertEqual(labels.dtype, dtypes.int64)
      self.assertEqual(preds.dtype, dtypes.int64)
      return backend.mean(math_ops.abs(preds - labels), axis=-1)

    loss_container = compile_utils.LossesContainer(my_mae)
    total_loss = loss_container(y_t, y_p)
    self.assertEqual(total_loss.dtype, dtypes.int64)

  def test_float_dtypes(self):
    """float32 labels reach the loss as float64 when predictions are float64."""
    y_t = constant_op.constant([1, 9, 2, -5],
                               shape=(2, 2),
                               dtype=dtypes.float32)
    y_p = constant_op.constant([4, 8, 12, 8],
                               shape=(2, 2),
                               dtype=dtypes.float64)

    def my_mae(labels, preds):
      self.assertEqual(labels.dtype, dtypes.float64)
      self.assertEqual(preds.dtype, dtypes.float64)
      return backend.mean(math_ops.abs(preds - labels), axis=-1)

    loss_container = compile_utils.LossesContainer(my_mae)
    total_loss = loss_container(y_t, y_p)
    self.assertEqual(total_loss.dtype, dtypes.float64)

  def test_loss_masking(self):
    """A `_keras_mask` on predictions zeroes out masked timesteps."""
    loss_container = compile_utils.LossesContainer('mae')
    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
                                           dtype=dtypes.float32)

    total_loss = loss_container(y_t, y_p)
    # Only the unmasked error at [1, 0] is 1: 1 / 4 elements.
    self.assertAlmostEqual(total_loss.numpy(), .25)  # sum over batch size

    self.assertLen(loss_container.metrics, 1)
    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertAlmostEqual(loss_metric.result().numpy(), .25)

  def test_loss_sample_weight(self):
    """Per-timestep sample weights scale each element's loss."""
    loss_container = compile_utils.LossesContainer('mae')
    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)

    total_loss = loss_container(y_t, y_p, sample_weight=sw)
    # (0 * .2 + 0 * .3 + 1 * .5 + 1 * 0) / 4
    self.assertAlmostEqual(total_loss.numpy(), .125)

    self.assertLen(loss_container.metrics, 1)
    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertAlmostEqual(loss_metric.result().numpy(), .125)

  def test_loss_masking_sample_weight(self):
    """Masks and sample weights combine; masked elements drop out."""
    loss_container = compile_utils.LossesContainer('mae')
    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)
    y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
                                           dtype=dtypes.float32)

    total_loss = loss_container(y_t, y_p, sample_weight=sw)
    # (0 * .2 + 1 * .5) / 4
    self.assertAlmostEqual(total_loss.numpy(), .125)  # sum over batch size

    self.assertLen(loss_container.metrics, 1)
    loss_metric = loss_container.metrics[0]
    self.assertEqual(loss_metric.name, 'loss')
    self.assertAlmostEqual(loss_metric.result().numpy(), .125)

  def test_custom_loss_callables(self):
    """Custom losses are named after the function or snake_cased class."""

    def custom_loss_fn(y_true, y_pred):
      return math_ops.reduce_sum(y_true - y_pred)

    class CustomLossClass(object):

      def __call__(self, y_true, y_pred):
        return math_ops.reduce_sum(y_true - y_pred)

    loss_container = compile_utils.LossesContainer(
        [custom_loss_fn, CustomLossClass()])
    y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
    loss_container(y_t, y_p)

    self.assertEqual(loss_container._losses[0].name, 'custom_loss_fn')
    self.assertEqual(loss_container._losses[1].name, 'custom_loss_class')

  def test_ragged_tensor_output(self):
    """Ensure that ragged tensors can be passed as targets and predictions."""

    def custom_loss_fn(y_true, y_pred):
      """MSE supports RaggedTensors directly."""
      return losses_mod.mse(y_true, y_pred)

    class CustomLossClass(losses_mod.Loss):
      """User defined loss function must implement RaggedTensor support."""

      def call(self, y_true, y_pred):
        losses = ragged_functional_ops.map_flat_values(
            math_ops.squared_difference, y_true, y_pred)
        return math_ops.reduce_mean(losses)

    loss_container = compile_utils.LossesContainer(
        [custom_loss_fn, CustomLossClass()])

    v_t = constant_op.constant([[3., 4.], [1., 2.], [3., 5.]])
    v_p = constant_op.constant([[3.1, 4.], [1., 2.], [3., 5.]])

    y_t = array_ops.expand_dims(
        ragged_tensor.RaggedTensor.from_row_splits(v_t, [0, 2, 3]), 0)
    y_p = array_ops.expand_dims(
        ragged_tensor.RaggedTensor.from_row_splits(v_p, [0, 2, 3]), 0)
    loss_container(y_t, y_p)

    self.assertEqual(loss_container._losses[0].name, 'custom_loss_fn')
398
399
class MetricsContainerTest(keras_parameterized.TestCase):
  """Tests for `compile_utils.MetricsContainer`.

  Many tests use the sample weight `[0] * 5 + [1] * 5` over a batch of 10.
  Weighted metrics divide by the total weight, while unweighted metrics average
  over every sample.
  """

  def test_single_metric(self):
    """A single string metric yields one metric named after the string."""
    metric_container = compile_utils.MetricsContainer('mse')
    y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
    metric_container.update_state(y_t, y_p)

    self.assertLen(metric_container.metrics, 1)
    metric = metric_container.metrics[0]
    self.assertEqual(metric.name, 'mse')
    self.assertEqual(metric.result().numpy(), 1.)

    metric_container.reset_state()
    self.assertEqual(metric.result().numpy(), 0.)

  def test_list_of_metrics_one_output(self):
    """A list of metrics on a single output keeps the bare metric names."""
    metric_container = compile_utils.MetricsContainer(['mse', 'mae'])
    y_t, y_p = 2 * array_ops.ones((10, 5)), array_ops.zeros((10, 5))
    metric_container.update_state(y_t, y_p)
    self.assertLen(metric_container.metrics, 2)

    mse_metric = metric_container.metrics[0]
    self.assertEqual(mse_metric.name, 'mse')
    self.assertEqual(mse_metric.result().numpy(), 4.)

    mae_metric = metric_container.metrics[1]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertEqual(mae_metric.result().numpy(), 2.)

    metric_container.reset_state()
    self.assertEqual(mse_metric.result().numpy(), 0.)
    self.assertEqual(mae_metric.result().numpy(), 0.)

  def test_list_of_metrics_list_of_outputs(self):
    """Flat metric lists are broadcast to every output, prefixed by name."""
    metric_container = compile_utils.MetricsContainer(
        metrics=['mse', 'mae'],  # Should broadcast to both outputs.
        weighted_metrics=['accuracy'])  # Should broadcast to both outputs.

    y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
    y_p = [array_ops.ones((10, 1)), 2 * array_ops.ones((10, 1))]
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    metric_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metric_container.metrics, 6)

    out1_mse_metric = metric_container.metrics[0]
    self.assertEqual(out1_mse_metric.name, 'output_1_mse')
    self.assertEqual(out1_mse_metric.result().numpy(), 0.)

    out1_mae_metric = metric_container.metrics[1]
    self.assertEqual(out1_mae_metric.name, 'output_1_mae')
    self.assertEqual(out1_mae_metric.result().numpy(), 0.)

    # 'accuracy' resolves to binary accuracy for single-unit outputs.
    acc_metric_1 = metric_container.metrics[2]
    self.assertEqual(acc_metric_1.name, 'output_1_accuracy')
    self.assertEqual(acc_metric_1.result().numpy(), 1.)
    self.assertEqual(acc_metric_1._fn, metrics_mod.binary_accuracy)

    out2_mse_metric = metric_container.metrics[3]
    self.assertEqual(out2_mse_metric.name, 'output_2_mse')
    self.assertEqual(out2_mse_metric.result().numpy(), 4.)

    out2_mae_metric = metric_container.metrics[4]
    self.assertEqual(out2_mae_metric.name, 'output_2_mae')
    self.assertEqual(out2_mae_metric.result().numpy(), 2.)

    acc_metric_2 = metric_container.metrics[5]
    self.assertEqual(acc_metric_2.name, 'output_2_accuracy')
    self.assertEqual(acc_metric_2.result().numpy(), 0.)
    self.assertEqual(acc_metric_2._fn, metrics_mod.binary_accuracy)

    weighted_metrics = metric_container.weighted_metrics
    self.assertLen(weighted_metrics, 2)
    self.assertEqual(weighted_metrics[0].name, 'output_1_accuracy')
    self.assertEqual(weighted_metrics[1].name, 'output_2_accuracy')

    unweighted_metrics = metric_container.unweighted_metrics
    self.assertLen(unweighted_metrics, 4)
    self.assertEqual(unweighted_metrics[0].name, 'output_1_mse')
    self.assertEqual(unweighted_metrics[1].name, 'output_1_mae')
    self.assertEqual(unweighted_metrics[2].name, 'output_2_mse')
    self.assertEqual(unweighted_metrics[3].name, 'output_2_mae')

  def test_metric_dict(self):
    """Dict metrics/weighted_metrics are matched to dict outputs by key."""
    metric_container = compile_utils.MetricsContainer(
        metrics={
            'out1': 'mse',
            'out2': 'mae'
        },
        weighted_metrics={
            'out1': 'mse',
            'out2': 'mae'
        })

    y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
    y_p = {'out1': array_ops.ones((10, 1)), 'out2': 2 * array_ops.ones((10, 1))}
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    metric_container.update_state(y_t, y_p, sample_weight=sw)

    mse_metric = metric_container.metrics[0]
    self.assertEqual(mse_metric.name, 'out1_mse')
    self.assertEqual(mse_metric.result().numpy(), 0.)

    # The same metric used as both weighted and unweighted gets a
    # 'weighted_' name prefix to keep the names distinct.
    weighted_mse_metric = metric_container.metrics[1]
    self.assertEqual(weighted_mse_metric.name, 'out1_weighted_mse')
    self.assertEqual(weighted_mse_metric.result().numpy(), 0.)

    mae_metric = metric_container.metrics[2]
    self.assertEqual(mae_metric.name, 'out2_mae')
    self.assertEqual(mae_metric.result().numpy(), 2.)

    weighted_mae_metric = metric_container.metrics[3]
    self.assertEqual(weighted_mae_metric.name, 'out2_weighted_mae')
    self.assertEqual(weighted_mae_metric.result().numpy(), 2.)

    metric_container.reset_state()
    self.assertEqual(mse_metric.result().numpy(), 0.)
    self.assertEqual(weighted_mse_metric.result().numpy(), 0.)
    self.assertEqual(mae_metric.result().numpy(), 0.)
    self.assertEqual(weighted_mae_metric.result().numpy(), 0.)

  def test_metric_partial_dict_with_output_names(self):
    """Outputs absent from a partial metric dict get no metrics."""
    metric_container = compile_utils.MetricsContainer(
        {'out2': 'mae'}, output_names=['out1', 'out2'])

    y_t = [array_ops.ones((10, 1)), array_ops.zeros((10, 1))]
    y_p = [array_ops.ones((10, 1)), array_ops.ones((10, 1))]
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    metric_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metric_container.metrics, 1)

    # Unweighted: the sample weight is ignored, so mae is 1 on every row.
    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'out2_mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

  def test_metric_partial_dict_with_nones(self):
    """A `None` entry in a metric dict produces no metric for that output."""
    metric_container = compile_utils.MetricsContainer({
        'out1': None,
        'out2': 'mae'
    })

    y_t = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.zeros((10, 1))}
    y_p = {'out1': array_ops.ones((10, 1)), 'out2': array_ops.ones((10, 1))}
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    metric_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metric_container.metrics, 1)

    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'out2_mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

  def test_nested_structure(self):
    """Nested metric structures are flattened to per-output metrics."""
    metric_container = compile_utils.MetricsContainer(
        metrics={
            'b': ['mse', None],
            'a': 'mae'
        },
        weighted_metrics={
            'b': [None, None],
            'a': 'mse'
        })

    y_t = {
        'b': [2 * array_ops.ones((10, 1)),
              array_ops.zeros((10, 1))],
        'a': array_ops.zeros((10, 1))
    }
    y_p = {
        'b': [array_ops.zeros((10, 1)),
              array_ops.zeros((10, 1))],
        'a': array_ops.ones((10, 1))
    }
    sw = ops.convert_to_tensor_v2_with_dispatch([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    metric_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metric_container.metrics, 3)

    a_mae_metric = metric_container.metrics[0]
    self.assertEqual(a_mae_metric.name, 'a_mae')
    self.assertEqual(a_mae_metric.result().numpy(), 1.)

    weighted_a_mse_metric = metric_container.metrics[1]
    self.assertEqual(weighted_a_mse_metric.name, 'a_mse')
    self.assertEqual(weighted_a_mse_metric.result().numpy(), 1.)

    # (2 - 0) ** 2 on every row of the first element of 'b'.
    b_1_mse_metric = metric_container.metrics[2]
    self.assertEqual(b_1_mse_metric.name, 'b_1_mse')
    self.assertEqual(b_1_mse_metric.result().numpy(), 4.)

  def test_crossentropy(self):
    """'crossentropy' resolves by label/prediction shapes."""
    # Single-unit predictions -> binary crossentropy.
    metric_container = compile_utils.MetricsContainer('crossentropy')
    y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 1))
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.binary_crossentropy)

    # Integer-style labels with multi-unit predictions -> sparse categorical.
    metric_container = compile_utils.MetricsContainer('crossentropy')
    y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 20))
    self.assertEqual(y_p.shape.as_list()[-1], 20)
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.sparse_categorical_crossentropy)

    # Labels shaped like the predictions -> categorical.
    metric_container = compile_utils.MetricsContainer('crossentropy')
    y_t, y_p = array_ops.ones((10, 20)), array_ops.ones((10, 20))
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.categorical_crossentropy)

  def test_accuracy(self):
    """'accuracy' / 'Accuracy' resolve by label/prediction shapes."""
    # Single-unit predictions -> binary accuracy.
    metric_container = compile_utils.MetricsContainer('accuracy')
    y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 1))
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.binary_accuracy)

    # The capitalized spelling resolves the same way.
    metric_container = compile_utils.MetricsContainer('Accuracy')
    y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 1))
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.binary_accuracy)

    # Integer-style labels with multi-unit predictions -> sparse categorical.
    metric_container = compile_utils.MetricsContainer('accuracy')
    y_t, y_p = array_ops.ones((10, 1)), array_ops.ones((10, 20))
    self.assertEqual(y_p.shape.as_list()[-1], 20)
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.sparse_categorical_accuracy)

    # Labels shaped like the predictions -> categorical.
    metric_container = compile_utils.MetricsContainer('accuracy')
    y_t, y_p = array_ops.ones((10, 20)), array_ops.ones((10, 20))
    metric_container.update_state(y_t, y_p)
    self.assertEqual(metric_container.metrics[0]._fn,
                     metrics_mod.categorical_accuracy)

  def test_metric_weighting(self):
    """Only `weighted_metrics` honor `sample_weight`."""
    metric_container = compile_utils.MetricsContainer(
        metrics=['mae'], weighted_metrics=['mae'])

    y_t = ops.convert_to_tensor_v2_with_dispatch([[0], [3], [0]])
    y_p = ops.convert_to_tensor_v2_with_dispatch([[0], [0], [0]])
    sw = ops.convert_to_tensor_v2_with_dispatch([[1], [0], [1]])

    metric_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metric_container.metrics, 2)

    # Unweighted: (0 + 3 + 0) / 3.
    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

    # Weighted: the only nonzero error has weight 0.
    weighted_mae_metric = metric_container.metrics[1]
    self.assertEqual(weighted_mae_metric.name, 'weighted_mae')
    self.assertEqual(weighted_mae_metric.result().numpy(), 0.)

  def test_broadcast_metrics_to_dict(self):
    """A metric list is applied to a single-entry dict output."""
    metric_container = compile_utils.MetricsContainer(metrics=['mae'])

    y_p = {'output': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]])}
    y_t = {'output': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]])}
    metric_container.update_state(y_t, y_p)

    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

  def test_broadcast_metrics_to_dict_with_output_names(self):
    """A dict label is matched to a plain prediction via `output_names`."""
    metric_container = compile_utils.MetricsContainer(
        metrics=['mae'], output_names=['output'])

    y_p = ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]])
    y_t = {'output': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]])}
    metric_container.update_state(y_t, y_p)

    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

  def test_missing_label_with_no_metrics(self):
    """A label may be omitted when its output has no metrics configured."""
    # It's ok to exclude a label if that label has no
    # losses or metrics associated with it.
    metric_container = compile_utils.MetricsContainer(metrics={
        'output1': 'mae',
        'output3': 'mse'
    })

    y_p = {
        'output1': ops.convert_to_tensor_v2_with_dispatch([[0], [1], [2]]),
        'output2': ops.convert_to_tensor_v2_with_dispatch([[3], [4], [5]]),
        'output3': ops.convert_to_tensor_v2_with_dispatch([[6], [7], [8]])
    }
    y_t = {
        'output1': ops.convert_to_tensor_v2_with_dispatch([[1], [2], [3]]),
        'output3': ops.convert_to_tensor_v2_with_dispatch([[4], [5], [6]])
    }

    metric_container.update_state(y_t, y_p)
    self.assertLen(metric_container.metrics, 2)

    mae_metric = metric_container.metrics[0]
    self.assertEqual(mae_metric.name, 'output1_mae')
    self.assertEqual(mae_metric.result().numpy(), 1.)

    # Every error for output3 is 2, so mse is 2 ** 2.
    mse_metric = metric_container.metrics[1]
    self.assertEqual(mse_metric.name, 'output3_mse')
    self.assertEqual(mse_metric.result().numpy(), 4.)

  def test_metrics_masking(self):
    """A `_keras_mask` excludes masked timesteps from every metric."""
    metrics_container = compile_utils.MetricsContainer(
        metrics=['mae'], weighted_metrics=['mse'])
    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    # Masks out exactly the timesteps that have nonzero error.
    y_p._keras_mask = constant_op.constant([[1, 1], [0, 0]],
                                           dtype=dtypes.float32)

    metrics_container.update_state(y_t, y_p)
    self.assertLen(metrics_container.metrics, 2)

    mae_metric = metrics_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertAlmostEqual(mae_metric.result().numpy(), 0)

    weighted_mse_metric = metrics_container.metrics[1]
    self.assertEqual(weighted_mse_metric.name, 'mse')
    self.assertAlmostEqual(weighted_mse_metric.result().numpy(), 0)

  def test_metrics_sample_weight(self):
    """Sample weights apply to weighted metrics only."""
    metrics_container = compile_utils.MetricsContainer(
        metrics=['mae'], weighted_metrics=['mse'])
    y_p = constant_op.constant([[[1], [1]], [[0], [1]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)

    metrics_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metrics_container.metrics, 2)

    mae_metric = metrics_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertAlmostEqual(mae_metric.result().numpy(), .25)  # 1 / 4

    weighted_mse_metric = metrics_container.metrics[1]
    self.assertEqual(weighted_mse_metric.name, 'mse')
    self.assertAlmostEqual(weighted_mse_metric.result().numpy(), .5)  # .5 / 1

  def test_metrics_masking_sample_weight(self):
    """Masks restrict both metrics; sample weights affect only weighted ones."""
    metrics_container = compile_utils.MetricsContainer(
        metrics=['mae'], weighted_metrics=['mse'])
    y_p = constant_op.constant([[[1], [1]], [[0], [1]]], dtype=dtypes.float32)
    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
    sw = constant_op.constant([[.3, .2], [.2, .3]], dtype=dtypes.float32)
    y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
                                           dtype=dtypes.float32)

    metrics_container.update_state(y_t, y_p, sample_weight=sw)
    self.assertLen(metrics_container.metrics, 2)

    # Unmasked errors are 0 and 1 over 2 unmasked elements.
    mae_metric = metrics_container.metrics[0]
    self.assertEqual(mae_metric.name, 'mae')
    self.assertAlmostEqual(mae_metric.result().numpy(), .5)  # 1 / 2

    # (0 * .3 + 1 * .2) / (.3 + .2)
    weighted_mse_metric = metrics_container.metrics[1]
    self.assertEqual(weighted_mse_metric.name, 'mse')
    self.assertAlmostEqual(weighted_mse_metric.result().numpy(), .2 / .5)

  def test_loss_class_as_metric_with_distribution(self):
    """A `Loss` instance works as a metric under a distribution strategy."""
    distribution = one_device_strategy.OneDeviceStrategy('/device:CPU:0')
    with distribution.scope():
      metric_container = compile_utils.MetricsContainer(
          losses_mod.MeanSquaredError())
      y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
      metric_container.update_state(y_t, y_p)

      self.assertLen(metric_container.metrics, 1)
      metric = metric_container.metrics[0]
      self.assertEqual(metric.name, 'mean_squared_error')
      self.assertEqual(metric.result().numpy(), 1.)

  def test_custom_metric_callables(self):
    """Custom metrics are named after the function or snake_cased class."""

    def custom_metric_fn(y_true, y_pred):
      return math_ops.reduce_sum(y_true - y_pred)

    class CustomMetricClass(object):

      def __call__(self, y_true, y_pred):
        return math_ops.reduce_sum(y_true - y_pred)

    metric_container = compile_utils.MetricsContainer(
        [custom_metric_fn, CustomMetricClass()])
    y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
    metric_container.update_state(y_t, y_p)

    self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn')
    self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class')

  def test_reset_state_existing_metric_before_built(self):
    """`reset_state` resets a passed-in metric even before the container builds."""
    metric = metrics_mod.Mean()
    metric.update_state([2.0, 4.0])
    self.assertEqual(metric.result().numpy(), 3.0)

    metric_container = compile_utils.MetricsContainer(metric)
    metric_container.reset_state()
    self.assertEqual(metric.result().numpy(), 0.0)
803
804
if __name__ == '__main__':
  # Enable eager execution before handing control to the test runner. The
  # tests above read results via `.numpy()`, which needs eager tensors.
  ops.enable_eager_execution()
  test.main()
808