# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests metrics correctness using Keras model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python import tf2
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics
from tensorflow.python.keras import testing_utils
from tensorflow.python.ops.losses import loss_reduction
from tensorflow.python.platform import test


def get_multi_io_model():
  inp_1 = layers.Input(shape=(1,), name='input_1')
  inp_2 = layers.Input(shape=(1,), name='input_2')
  x = layers.Dense(3, kernel_initializer='ones', trainable=False)
  out_1 = layers.Dense(
      1, kernel_initializer='ones', name='output_1', trainable=False)
  out_2 = layers.Dense(
      1, kernel_initializer='ones', name='output_2', trainable=False)

  branch_a = [inp_1, x, out_1]
  branch_b = [inp_2, x, out_2]
  return testing_utils.get_multi_io_model(branch_a, branch_b)
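

# A minimal illustrative sketch (hypothetical helper, not used by the tests):
# with all-ones kernel initializers and default zero biases, each branch of
# the model above computes y_pred = 3 * x, since Dense(3, ones) followed by
# Dense(1, ones) sums three copies of the scalar input. This is the
# prediction assumed in the expected-value comments in the test classes.
def _predicted_output_sketch(x):
  return 3. * np.asarray(x)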


def custom_generator_multi_io(sample_weights=None):
  batch_size = 2
  num_samples = 4
  inputs = np.asarray([[1.], [2.], [3.], [4.]])
  targets_1 = np.asarray([[2.], [4.], [6.], [8.]])
  targets_2 = np.asarray([[1.], [2.], [3.], [4.]])
  if sample_weights:
    assert len(sample_weights) == 2
    w1 = sample_weights[0]
    w2 = sample_weights[1]
  else:
    w1 = None
    w2 = None
  i = 0
  while True:
    batch_index = i * batch_size % num_samples
    i += 1
    start = batch_index
    end = start + batch_size
    x = [inputs[start:end], inputs[start:end]]
    y = [targets_1[start:end], targets_2[start:end]]
    if sample_weights:
      w = [
          None if w1 is None else w1[start:end],
          None if w2 is None else w2[start:end]
      ]
    else:
      w = None
    yield x, y, w


@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
@keras_parameterized.run_all_keras_modes
class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):

  def _get_compiled_multi_io_model(self):
    model = get_multi_io_model()
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        metrics=[metrics.MeanSquaredError(name='mean_squared_error')],
        weighted_metrics=[
            metrics.MeanSquaredError(name='mean_squared_error_2')
        ],
        run_eagerly=testing_utils.should_run_eagerly(),
        experimental_run_tf_function=testing_utils.should_run_tf_function())
    return model

  def setUp(self):
    super(TestMetricsCorrectnessMultiIO, self).setUp()
    self.x = np.asarray([[1.], [2.], [3.], [4.]])
    self.y1 = np.asarray([[2.], [4.], [6.], [8.]])
    self.y2 = np.asarray([[1.], [2.], [3.], [4.]])
    self.sample_weight_1 = np.asarray([2., 3., 4., 5.])
    self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5])
    self.class_weight_1 = {2: 2, 4: 3, 6: 4, 8: 5}
    self.class_weight_2 = {1: 3.5, 2: 2.5, 3: 1.5, 4: 0.5}

    # y_true_1 = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]
    # y_true_2 = [[1.], [2.], [3.], [4.]], y_pred = [[3.], [6.], [9.], [12.]]

    # Weighted metric `output_1`:
    #   Total = ((3 - 2)^2 * 2  + (6 - 4)^2 * 3) +
    #           ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
    #         = 130
    #   Count = (2 + 3) + (4 + 5)
    #   Result = 9.2857141

    # Weighted metric `output_2`:
    #   Total = ((3 - 1)^2 * 3.5 + (6 - 2)^2 * 2.5) +
    #           ((9 - 3)^2 * 1.5 + (12 - 4)^2 * 0.5)
    #         = 140
    #   Count = (3.5 + 2.5) + (1.5 + 0.5)
    #   Result = 17.5

    # Loss `output_1` with weights:
    #   Total = ((3 - 2)^2 * 2  + (6 - 4)^2 * 3) +
    #           ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
    #         = 130
    #   Count = 2 + 2
    #   Result = 32.5

    # Loss `output_1` without weights/Metric `output_1`:
    #   Total = ((3 - 2)^2 + (6 - 4)^2) + ((9 - 6)^2 + (12 - 8)^2) = 30
    #   Count = 2 + 2
    #   Result = 7.5

    # Loss `output_2` with weights:
    #   Total = ((3 - 1)^2 * 3.5 + (6 - 2)^2 * 2.5) +
    #           ((9 - 3)^2 * 1.5 + (12 - 4)^2 * 0.5)
    #         = 140
    #   Count = 2 + 2
    #   Result = 35

    # Loss `output_2` without weights/Metric `output_2`:
    #   Total = ((3 - 1)^2 + (6 - 2)^2) + ((9 - 3)^2 + (12 - 4)^2) = 120
    #   Count = 2 + 2
    #   Result = 30

    # Total loss with weights = 32.5 + 35 = 67.5
    # Total loss without weights = 7.5 + 30 = 37.5

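    # A minimal numpy sketch (illustrative only; nothing below asserts on it):
    # assuming y_pred = 3 * x as documented above, the weighted metric for
    # `output_1` can be reproduced directly. `output_2` follows the same
    # pattern with `self.y2` and `self.sample_weight_2`.
    sq_err_1 = np.squeeze((3. * self.x - self.y1) ** 2)  # [1., 4., 9., 16.]
    self.weighted_mse_1_sketch = (
        np.dot(sq_err_1, self.sample_weight_1) /
        np.sum(self.sample_weight_1))  # 130 / 14 ~= 9.2857141
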
    self.wmse = 'mean_squared_error_2'
    if not tf2.enabled():
      self.wmse = 'weighted_' + self.wmse
    self.expected_fit_result_with_weights = {
        'output_1_mean_squared_error': [7.5, 7.5],
        'output_2_mean_squared_error': [30, 30],
        'output_1_' + self.wmse: [9.286, 9.286],
        'output_2_' + self.wmse: [17.5, 17.5],
        'loss': [67.5, 67.5],
        'output_1_loss': [32.5, 32.5],
        'output_2_loss': [35, 35],
    }

    self.expected_fit_result_with_weights_output_2 = {
        'output_1_mean_squared_error': [7.5, 7.5],
        'output_2_mean_squared_error': [30, 30],
        'output_1_' + self.wmse: [7.5, 7.5],
        'output_2_' + self.wmse: [17.5, 17.5],
        'loss': [42.5, 42.5],
        'output_1_loss': [7.5, 7.5],
        'output_2_loss': [35, 35],
    }

    self.expected_fit_result = {
        'output_1_mean_squared_error': [7.5, 7.5],
        'output_2_mean_squared_error': [30, 30],
        'output_1_' + self.wmse: [7.5, 7.5],
        'output_2_' + self.wmse: [30, 30],
        'loss': [37.5, 37.5],
        'output_1_loss': [7.5, 7.5],
        'output_2_loss': [30, 30],
    }

    # In the order: 'loss', 'output_1_loss', 'output_2_loss',
    # 'output_1_mean_squared_error', 'output_1_mean_squared_error_2',
    # 'output_2_mean_squared_error', 'output_2_mean_squared_error_2'
    self.expected_batch_result_with_weights = [
        67.5, 32.5, 35, 7.5, 9.286, 30, 17.5
    ]
    self.expected_batch_result_with_weights_output_2 = [
        42.5, 7.5, 35, 7.5, 7.5, 30, 17.5
    ]
    self.expected_batch_result = [37.5, 7.5, 30, 7.5, 7.5, 30, 30]

  def test_fit(self):
    model = self._get_compiled_multi_io_model()
    history = model.fit([self.x, self.x], [self.y1, self.y2],
                        batch_size=2,
                        epochs=2,
                        shuffle=False)
    for key, value in self.expected_fit_result.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_fit_with_sample_weight(self):
    model = self._get_compiled_multi_io_model()
    history = model.fit([self.x, self.x], [self.y1, self.y2],
                        sample_weight={
                            'output_1': self.sample_weight_1,
                            'output_2': self.sample_weight_2,
                        },
                        batch_size=2,
                        epochs=2,
                        shuffle=False)
    for key, value in self.expected_fit_result_with_weights.items():
      self.assertAllClose(history.history[key], value, 1e-3)

    # Set weights for one output (use batch size).
    history = model.fit([self.x, self.x], [self.y1, self.y2],
                        sample_weight={'output_2': self.sample_weight_2},
                        batch_size=2,
                        epochs=2,
                        shuffle=False)

    for key, value in self.expected_fit_result_with_weights_output_2.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_fit_with_class_weight(self):
    model = self._get_compiled_multi_io_model()
    history = model.fit([self.x, self.x], [self.y1, self.y2],
                        class_weight={
                            'output_1': self.class_weight_1,
                            'output_2': self.class_weight_2,
                        },
                        batch_size=2,
                        epochs=2,
                        shuffle=False)
    for key, value in self.expected_fit_result_with_weights.items():
      self.assertAllClose(history.history[key], value, 1e-3)

    # Set weights for one output.
    history = model.fit([self.x, self.x], [self.y1, self.y2],
                        class_weight={'output_2': self.class_weight_2},
                        batch_size=2,
                        epochs=2,
                        shuffle=False)

    for key, value in self.expected_fit_result_with_weights_output_2.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_eval(self):
    model = self._get_compiled_multi_io_model()
    eval_result = model.evaluate([self.x, self.x], [self.y1, self.y2],
                                 batch_size=2)
    self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)

  def test_eval_with_sample_weight(self):
    model = self._get_compiled_multi_io_model()
    eval_result = model.evaluate([self.x, self.x], [self.y1, self.y2],
                                 batch_size=2,
                                 sample_weight={
                                     'output_1': self.sample_weight_1,
                                     'output_2': self.sample_weight_2,
                                 })
    self.assertAllClose(eval_result, self.expected_batch_result_with_weights,
                        1e-3)

    # Set weights for one output.
    model = self._get_compiled_multi_io_model()
    eval_result = model.evaluate([self.x, self.x], [self.y1, self.y2],
                                 batch_size=2,
                                 sample_weight={
                                     'output_2': self.sample_weight_2,
                                 })
    self.assertAllClose(eval_result,
                        self.expected_batch_result_with_weights_output_2, 1e-3)

    # Verify that the metric value is the same with arbitrary weights and
    # batch sizes.
    x = np.random.random((50, 1))
    y = np.random.random((50, 1))
    w = np.random.random((50,))
    mse1 = model.evaluate([x, x], [y, y], sample_weight=[w, w], batch_size=5)[3]
    mse2 = model.evaluate([x, x], [y, y], sample_weight=[w, w],
                          batch_size=10)[3]
    self.assertAllClose(mse1, mse2, 1e-3)

  def test_train_on_batch(self):
    model = self._get_compiled_multi_io_model()
    result = model.train_on_batch([self.x, self.x], [self.y1, self.y2])
    self.assertAllClose(result, self.expected_batch_result, 1e-3)

  def test_train_on_batch_with_sample_weight(self):
    model = self._get_compiled_multi_io_model()
    result = model.train_on_batch([self.x, self.x], [self.y1, self.y2],
                                  sample_weight={
                                      'output_1': self.sample_weight_1,
                                      'output_2': self.sample_weight_2,
                                  })
    self.assertAllClose(result, self.expected_batch_result_with_weights, 1e-3)

    # Set weights for one output.
    result = model.train_on_batch([self.x, self.x], [self.y1, self.y2],
                                  sample_weight={
                                      'output_2': self.sample_weight_2,
                                  })
    self.assertAllClose(result,
                        self.expected_batch_result_with_weights_output_2, 1e-3)

  def test_train_on_batch_with_class_weight(self):
    model = self._get_compiled_multi_io_model()
    result = model.train_on_batch([self.x, self.x], [self.y1, self.y2],
                                  class_weight={
                                      'output_1': self.class_weight_1,
                                      'output_2': self.class_weight_2,
                                  })
    self.assertAllClose(result, self.expected_batch_result_with_weights, 1e-3)

    # Set weights for one output.
    result = model.train_on_batch([self.x, self.x], [self.y1, self.y2],
                                  class_weight={
                                      'output_2': self.class_weight_2,
                                  })
    self.assertAllClose(result,
                        self.expected_batch_result_with_weights_output_2, 1e-3)

  def test_test_on_batch(self):
    model = self._get_compiled_multi_io_model()
    result = model.test_on_batch([self.x, self.x], [self.y1, self.y2])
    self.assertAllClose(result, self.expected_batch_result, 1e-3)

  def test_test_on_batch_with_sample_weight(self):
    model = self._get_compiled_multi_io_model()
    result = model.test_on_batch([self.x, self.x], [self.y1, self.y2],
                                 sample_weight={
                                     'output_1': self.sample_weight_1,
                                     'output_2': self.sample_weight_2,
                                 })
    self.assertAllClose(result, self.expected_batch_result_with_weights, 1e-3)

    # Set weights for one output.
    result = model.test_on_batch([self.x, self.x], [self.y1, self.y2],
                                 sample_weight={
                                     'output_2': self.sample_weight_2,
                                 })
    self.assertAllClose(result,
                        self.expected_batch_result_with_weights_output_2, 1e-3)

  def test_fit_generator(self):
    model = self._get_compiled_multi_io_model()
    history = model.fit_generator(
        custom_generator_multi_io(), steps_per_epoch=2, epochs=2)
    for key, value in self.expected_fit_result.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_fit_generator_with_sample_weight(self):
    model = self._get_compiled_multi_io_model()
    history = model.fit_generator(
        custom_generator_multi_io(
            sample_weights=[self.sample_weight_1, self.sample_weight_2]),
        steps_per_epoch=2,
        epochs=2)
    for key, value in self.expected_fit_result_with_weights.items():
      self.assertAllClose(history.history[key], value, 1e-3)

    # Set weights for one output.
    history = model.fit_generator(
        custom_generator_multi_io(sample_weights=[None, self.sample_weight_2]),
        steps_per_epoch=2,
        epochs=2)
    for key, value in self.expected_fit_result_with_weights_output_2.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_fit_generator_with_class_weight(self):
    model = self._get_compiled_multi_io_model()
    history = model.fit_generator(
        custom_generator_multi_io(),
        class_weight={
            'output_1': self.class_weight_1,
            'output_2': self.class_weight_2,
        },
        steps_per_epoch=2,
        epochs=2)
    for key, value in self.expected_fit_result_with_weights.items():
      self.assertAllClose(history.history[key], value, 1e-3)

    # Set weights for one output.
    history = model.fit_generator(
        custom_generator_multi_io(),
        class_weight={'output_2': self.class_weight_2},
        steps_per_epoch=2,
        epochs=2)
    for key, value in self.expected_fit_result_with_weights_output_2.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_eval_generator(self):
    model = self._get_compiled_multi_io_model()
    eval_result = model.evaluate_generator(custom_generator_multi_io(), steps=2)
    self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)

  def test_eval_generator_with_sample_weight(self):
    model = self._get_compiled_multi_io_model()
    eval_result = model.evaluate_generator(
        custom_generator_multi_io(
            sample_weights=[self.sample_weight_1, self.sample_weight_2]),
        steps=2)
    self.assertAllClose(eval_result, self.expected_batch_result_with_weights,
                        1e-3)

    # Set weights for one output.
    eval_result = model.evaluate_generator(
        custom_generator_multi_io(sample_weights=[None, self.sample_weight_2]),
        steps=2)
    self.assertAllClose(eval_result,
                        self.expected_batch_result_with_weights_output_2, 1e-3)


@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class TestMetricsCorrectnessSingleIO(keras_parameterized.TestCase):

  def _get_model(self):
    x = layers.Dense(3, kernel_initializer='ones', trainable=False)
    out = layers.Dense(
        1, kernel_initializer='ones', name='output', trainable=False)
    model = testing_utils.get_model_from_layers([x, out], input_shape=(1,))
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        metrics=[metrics.MeanSquaredError(name='mean_squared_error')],
        weighted_metrics=[
            metrics.MeanSquaredError(name='mean_squared_error_2')
        ],
        run_eagerly=testing_utils.should_run_eagerly(),
        experimental_run_tf_function=testing_utils.should_run_tf_function())
    return model

  def _custom_generator(self, sample_weight=None):
    batch_size = 2
    num_samples = 4
    x = np.asarray([[1.], [2.], [3.], [4.]])
    y = np.asarray([[2.], [4.], [6.], [8.]])
    w = sample_weight
    i = 0

    while True:
      batch_index = i * batch_size % num_samples
      i += 1
      start = batch_index
      end = start + batch_size
      yield x[start:end], y[start:end], None if w is None else w[start:end]

  def setUp(self):
    super(TestMetricsCorrectnessSingleIO, self).setUp()
    self.x = np.asarray([[1.], [2.], [3.], [4.]])
    self.y = np.asarray([[2.], [4.], [6.], [8.]])
    self.sample_weight = np.asarray([2., 3., 4., 5.])
    self.class_weight = {2: 2, 4: 3, 6: 4, 8: 5}

    # y_true = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]

    # Metric:
    #   Total = ((3 - 2)^2 + (6 - 4)^2) + ((9 - 6)^2 + (12 - 8)^2) = 30,
    #   Count = 2 + 2
    #   Result = 7.5

    # Weighted metric:
    #   Total = ((3 - 2)^2 * 2  + (6 - 4)^2 * 3) +
    #           ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
    #         = 130
    #   Count = (2 + 3) + (4 + 5)
    #   Result = 9.2857141

    # Total loss with weights:
    #   Total = ((3 - 2)^2 * 2  + (6 - 4)^2 * 3) +
    #           ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
    #         = 130,
    #   Count = 2 + 2
    #   Result = 32.5

    # Total loss without weights:
    #   Total = ((3 - 2)^2 + (6 - 4)^2) +
    #           ((9 - 6)^2 + (12 - 8)^2)
    #         = 30,
    #   Count = 2 + 2
    #   Result = 7.5

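    # A minimal numpy sketch (illustrative only; not used by the assertions):
    # with y_pred = 3 * x, the weighted metric above is just the weighted
    # average of the per-sample squared errors.
    sq_err = np.squeeze((3. * self.x - self.y) ** 2)  # [1., 4., 9., 16.]
    self.weighted_mse_sketch = (
        np.dot(sq_err, self.sample_weight) /
        np.sum(self.sample_weight))  # 130 / 14 ~= 9.2857141
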
    wmse = 'mean_squared_error_2'
    if not tf2.enabled():
      wmse = 'weighted_' + wmse

    self.expected_fit_result_with_weights = {
        'mean_squared_error': [7.5, 7.5],
        wmse: [9.286, 9.286],
        'loss': [32.5, 32.5]
    }

    self.expected_fit_result = {
        'mean_squared_error': [7.5, 7.5],
        wmse: [7.5, 7.5],
        'loss': [7.5, 7.5]
    }

    # In the order: 'loss', 'mean_squared_error', 'mean_squared_error_2'
    self.expected_batch_result_with_weights = [32.5, 7.5, 9.286]
    self.expected_batch_result = [7.5, 7.5, 7.5]

  def test_fit(self):
    model = self._get_model()

    history = model.fit(
        self.x,
        self.y,
        batch_size=2,
        epochs=2,
        shuffle=False)
    for key, value in self.expected_fit_result.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_fit_with_sample_weight(self):
    model = self._get_model()
    history = model.fit(
        self.x,
        self.y,
        sample_weight=self.sample_weight,
        batch_size=2,
        epochs=2,
        shuffle=False)
    for key, value in self.expected_fit_result_with_weights.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_fit_with_class_weight(self):
    model = self._get_model()
    history = model.fit(
        self.x,
        self.y,
        class_weight=self.class_weight,
        batch_size=2,
        epochs=2,
        shuffle=False)
    for key, value in self.expected_fit_result_with_weights.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_eval(self):
    model = self._get_model()
    eval_result = model.evaluate(self.x, self.y, batch_size=2)
    self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)

  def test_eval_with_sample_weight(self):
    model = self._get_model()
    eval_result = model.evaluate(
        self.x, self.y, batch_size=2, sample_weight=self.sample_weight)
    self.assertAllClose(eval_result, self.expected_batch_result_with_weights,
                        1e-3)

    # Verify that the metric value is the same with arbitrary weights and
    # batch sizes.
    x = np.random.random((50, 1))
    y = np.random.random((50, 1))
    w = np.random.random((50,))
    mse1 = model.evaluate(x, y, sample_weight=w, batch_size=5)[1]
    mse2 = model.evaluate(x, y, sample_weight=w, batch_size=10)[1]
    self.assertAllClose(mse1, mse2, 1e-3)

  def test_train_on_batch(self):
    model = self._get_model()
    result = model.train_on_batch(self.x, self.y)
    self.assertAllClose(result, self.expected_batch_result, 1e-3)

  def test_train_on_batch_with_sample_weight(self):
    model = self._get_model()
    result = model.train_on_batch(
        self.x, self.y, sample_weight=self.sample_weight)
    self.assertAllClose(result, self.expected_batch_result_with_weights, 1e-3)

  def test_train_on_batch_with_class_weight(self):
    model = self._get_model()
    result = model.train_on_batch(
        self.x, self.y, class_weight=self.class_weight)
    self.assertAllClose(result, self.expected_batch_result_with_weights, 1e-3)

  def test_test_on_batch(self):
    model = self._get_model()
    result = model.test_on_batch(self.x, self.y)
    self.assertAllClose(result, self.expected_batch_result, 1e-3)

  def test_test_on_batch_with_sample_weight(self):
    model = self._get_model()
    result = model.test_on_batch(
        self.x, self.y, sample_weight=self.sample_weight)
    self.assertAllClose(result, self.expected_batch_result_with_weights, 1e-3)

  def test_fit_generator(self):
    model = self._get_model()
    history = model.fit_generator(
        self._custom_generator(), steps_per_epoch=2, epochs=2)
    for key, value in self.expected_fit_result.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_fit_generator_with_sample_weight(self):
    model = self._get_model()
    history = model.fit_generator(
        self._custom_generator(sample_weight=self.sample_weight),
        steps_per_epoch=2,
        epochs=2)
    for key, value in self.expected_fit_result_with_weights.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_fit_generator_with_class_weight(self):
    model = self._get_model()
    history = model.fit_generator(
        self._custom_generator(),
        steps_per_epoch=2,
        epochs=2,
        class_weight=self.class_weight)
    for key, value in self.expected_fit_result_with_weights.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_eval_generator(self):
    model = self._get_model()
    eval_result = model.evaluate_generator(self._custom_generator(), steps=2)
    self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)

  def test_eval_generator_with_sample_weight(self):
    model = self._get_model()
    eval_result = model.evaluate_generator(
        self._custom_generator(sample_weight=self.sample_weight), steps=2)
    self.assertAllClose(eval_result, self.expected_batch_result_with_weights,
                        1e-3)


@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
@keras_parameterized.run_all_keras_modes
@parameterized.parameters([
    loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE,
    loss_reduction.ReductionV2.AUTO,
    loss_reduction.ReductionV2.SUM
])
class TestOutputLossMetrics(keras_parameterized.TestCase):

  def _get_compiled_multi_io_model(self, loss):
    model = get_multi_io_model()
    model.compile(
        optimizer='rmsprop',
        loss=loss,
        run_eagerly=testing_utils.should_run_eagerly(),
        experimental_run_tf_function=testing_utils.should_run_tf_function())
    return model

  def setUp(self):
    super(TestOutputLossMetrics, self).setUp()
    self.x = np.asarray([[1.], [2.], [3.], [4.]])
    self.y1 = np.asarray([[2.], [4.], [6.], [8.]])
    self.y2 = np.asarray([[1.], [2.], [3.], [4.]])
    self.sample_weight_1 = np.asarray([2., 3., 4., 5.])
    self.sample_weight_2 = np.asarray([3.5, 2.5, 1.5, 0.5])

    # y_true = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]

    # Loss `output_1`:
    #   Per-sample weighted losses
    #   Batch 1 = [(3 - 2)^2 * 2, (6 - 4)^2 * 3] = [2, 12]
    #   Batch 2 = [(9 - 6)^2 * 4, (12 - 8)^2 * 5] = [36, 80]

    #   Result (reduction=SUM) = ((2 + 12) + (36 + 80))/2 = 65
    #   Result (reduction=SUM_OVER_BATCH_SIZE/AUTO/NONE) = 130 / 4 = 32.5

    # Loss `output_2`:
    #   Per-sample weighted losses
    #   Batch 1 = [(3 - 1)^2 * 3.5, (6 - 2)^2 * 2.5] = [14, 40]
    #   Batch 2 = [(9 - 3)^2 * 1.5, (12 - 4)^2 * 0.5] = [54, 32]

    #   Result (reduction=SUM) = ((14 + 40) + (54 + 32))/2 = 70
    #   Result (reduction=SUM_OVER_BATCH_SIZE/AUTO/NONE) = 140 / 4 = 35

    # When reduction is `NONE`, the loss value passed to the optimizer is a
    # vector (one value per sample), but the reported loss is a scalar: the
    # average of all the values across the batch vectors.

    # Total loss = Output_loss_1 + Output_loss_2
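
    # A minimal numpy sketch (illustrative only) of how the two reduction
    # results above come about for `output_1`; `output_2` is analogous.
    per_sample_1 = (
        np.squeeze((3. * self.x - self.y1) ** 2) * self.sample_weight_1)
    # SUM: average of per-batch sums -> (14 + 116) / 2 = 65.
    self.sum_reduction_sketch = np.mean(
        [per_sample_1[:2].sum(), per_sample_1[2:].sum()])
    # SUM_OVER_BATCH_SIZE (and AUTO): total / num samples -> 130 / 4 = 32.5.
    self.sum_over_batch_size_sketch = per_sample_1.sum() / 4.
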
    sum_over_batch_size_fit_result = {
        'loss': [67.5, 67.5],
        'output_1_loss': [32.5, 32.5],
        'output_2_loss': [35, 35],
    }

    self.expected_fit_result = {
        loss_reduction.ReductionV2.NONE:
            sum_over_batch_size_fit_result,
        loss_reduction.ReductionV2.SUM: {
            'loss': [135, 135],
            'output_1_loss': [65, 65],
            'output_2_loss': [70, 70],
        },
        loss_reduction.ReductionV2.AUTO:
            sum_over_batch_size_fit_result,
        loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE:
            sum_over_batch_size_fit_result,
    }

    # In the order: 'loss', 'output_1_loss', 'output_2_loss',
    self.expected_batch_result = {
        loss_reduction.ReductionV2.NONE: [67.5, 32.5, 35],
        loss_reduction.ReductionV2.SUM: [135, 65, 70],
        loss_reduction.ReductionV2.AUTO: [67.5, 32.5, 35],
        loss_reduction.ReductionV2.SUM_OVER_BATCH_SIZE: [67.5, 32.5, 35],
    }

  def test_fit(self, reduction):
    model = self._get_compiled_multi_io_model(
        loss=losses.MeanSquaredError(reduction=reduction))
    history = model.fit([self.x, self.x], [self.y1, self.y2],
                        sample_weight={
                            'output_1': self.sample_weight_1,
                            'output_2': self.sample_weight_2,
                        },
                        batch_size=2,
                        epochs=2,
                        shuffle=False)
    for key, value in self.expected_fit_result[reduction].items():
      self.assertAllClose(history.history[key], value)

  def test_eval(self, reduction):
    model = self._get_compiled_multi_io_model(
        loss=losses.MeanSquaredError(reduction=reduction))
    eval_result = model.evaluate([self.x, self.x], [self.y1, self.y2],
                                 batch_size=2,
                                 sample_weight={
                                     'output_1': self.sample_weight_1,
                                     'output_2': self.sample_weight_2,
                                 })
    self.assertAllClose(eval_result, self.expected_batch_result[reduction])

  def test_train_on_batch(self, reduction):
    model = self._get_compiled_multi_io_model(
        loss=losses.MeanSquaredError(reduction=reduction))
    result = model.train_on_batch([self.x, self.x], [self.y1, self.y2],
                                  sample_weight={
                                      'output_1': self.sample_weight_1,
                                      'output_2': self.sample_weight_2,
                                  })

    expected_values = self.expected_batch_result[reduction]
    if reduction == loss_reduction.ReductionV2.SUM:
      # We are taking all the data as one batch, so undo the averaging here.
      expected_values = [x * 2 for x in self.expected_batch_result[reduction]]
    self.assertAllClose(result, expected_values)

  def test_test_on_batch(self, reduction):
    model = self._get_compiled_multi_io_model(
        loss=losses.MeanSquaredError(reduction=reduction))
    result = model.test_on_batch([self.x, self.x], [self.y1, self.y2],
                                 sample_weight={
                                     'output_1': self.sample_weight_1,
                                     'output_2': self.sample_weight_2,
                                 })
    expected_values = self.expected_batch_result[reduction]
    if reduction == loss_reduction.ReductionV2.SUM:
      # We are taking all the data as one batch, so undo the averaging here.
      expected_values = [x * 2 for x in self.expected_batch_result[reduction]]
    self.assertAllClose(result, expected_values)

  def test_fit_generator(self, reduction):
    model = self._get_compiled_multi_io_model(
        loss=losses.MeanSquaredError(reduction=reduction))
    history = model.fit_generator(
        custom_generator_multi_io(
            sample_weights=[self.sample_weight_1, self.sample_weight_2]),
        steps_per_epoch=2,
        epochs=2)
    for key, value in self.expected_fit_result[reduction].items():
      self.assertAllClose(history.history[key], value)

  def test_eval_generator(self, reduction):
    model = self._get_compiled_multi_io_model(
        loss=losses.MeanSquaredError(reduction=reduction))
    eval_result = model.evaluate_generator(
        custom_generator_multi_io(
            sample_weights=[self.sample_weight_1, self.sample_weight_2]),
        steps=2)
    self.assertAllClose(eval_result, self.expected_batch_result[reduction])


if __name__ == '__main__':
  test.main()