# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras training and evaluation routines for eager execution."""
# pylint: disable=protected-access

import numpy as np

from tensorflow.python.eager.backprop import GradientTape
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def _eager_loss_fn(outputs, targets, loss_fn, output_name):
  """Computes the loss for a single output under a named scope."""
  with backend.name_scope(output_name + '_loss'):
    loss = loss_fn(targets, outputs)
  return loss


def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None):
  """Calculates the metrics for each output of the given model.

  Args:
      model: The model on which metrics are being calculated.
      outputs: The outputs of the given model.
      targets: The targets of the given model.
      sample_weights: Optional list of sample weights for each output.
      masks: Optional list of masks for each output.

  Returns:
      The metric results for each output of the model.
  """
  outputs = nest.flatten(outputs)
  targets = nest.flatten(targets)
  # Invoke all (weighted and unweighted) metrics.
  metric_results = []
  if targets:
    # Insert None values corresponding to the targets that need to be skipped
    # for the model.
    if len(model._targets) != len(targets):
      new_targets = [
          None if t is None else targets.pop(0) for t in model._targets
      ]
      targets = new_targets

    metric_results = model._handle_metrics(
        outputs,
        targets=targets,
        sample_weights=sample_weights,
        masks=masks,
        return_weighted_and_unweighted_metrics=True,
        skip_target_masks=model._prepare_skip_target_masks())

  # Add metric results from the `add_metric` metrics.
  metric_results.extend([
      m.result()
      for m in model.metrics
      if m not in model._compile_metric_functions
  ])
  return metric_results

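# A minimal usage sketch for the helper above (not part of the original
# module). The names `model`, `x`, and `y` are assumptions: a model compiled
# through this v1 eager path, plus one batch of inputs and targets.
#
#   outs = nest.flatten(model(x))
#   metric_values = _eager_metrics_fn(
#       model, outs, y, sample_weights=None, masks=None)
#   # -> one result tensor per compiled metric, plus any `add_metric` results.
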

def _model_loss(model,
                inputs,
                targets,
                output_loss_metrics=None,
                sample_weights=None,
                training=False):
  """Calculates the loss for a given model.

  Args:
      model: The model for which the loss is calculated.
      inputs: Either a dictionary of inputs to the model or a list of input
        arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Whether the model should be run in inference or training mode.

  Returns:
      The model outputs, the total loss, the loss value computed with the
      specified loss function for each output, and the mask for each output.
      The total loss includes regularization losses and applies masking and
      sample weighting to the loss value.
  """
  # TODO(psv): Dedup code here with graph mode prepare_total_loss() fn.
  # Used to keep track of the total loss value (stateless).
  # e.g., total_loss = loss_weight_1 * output_1_loss_fn(...) +
  #                    loss_weight_2 * output_2_loss_fn(...) +
  #                    layer losses.
  total_loss = 0
  kwargs = {}
  if model._expects_training_arg:
    kwargs['training'] = training
  if len(inputs) == 1 and not isinstance(inputs, dict):
    inputs = inputs[0]

  # Allow mixed `NumPy` and `EagerTensor` input here.
  if any(
      isinstance(input_t, (np.ndarray, float, int))
      for input_t in nest.flatten(inputs)):
    inputs = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch, inputs)

  outs = model(inputs, **kwargs)
  outs = nest.flatten(outs)

  if targets:
    targets = training_utils_v1.cast_if_floating_dtype_and_mismatch(
        targets, outs)
  # TODO(sallymatson/psv): check if we should do same mismatch fix for weights
  if sample_weights:
    sample_weights = [
        training_utils_v1.cast_if_floating_dtype(
            ops.convert_to_tensor_v2_with_dispatch(val))
        if val is not None else None for val in sample_weights
    ]

  masks = [getattr(t, '_keras_mask', None) for t in outs]
  targets = nest.flatten(targets)

  # Used to keep track of individual output losses.
  output_losses = []

  with backend.name_scope('loss'):
    loss_fns = [
        loss_fn for loss_fn in model.loss_functions if loss_fn is not None
    ]
    custom_losses = model.losses  # Regularization losses

    if not loss_fns and not custom_losses:
      if training:
        raise ValueError('The model cannot be trained '
                         'because it has no loss to optimize.')
      else:
        raise ValueError('The model cannot be evaluated '
                         'because it has no loss to compute.')

    for i, loss_fn in enumerate(loss_fns):
      weights = sample_weights[i] if sample_weights else None
      mask = masks[i]
      with backend.name_scope(model.output_names[i] + '_loss'):
        if mask is not None:
          mask = math_ops.cast(mask, outs[i].dtype)
          # Update weights with mask.
          if weights is None:
            weights = mask
          else:
            # Update dimensions of weights to match with mask if possible.
            weights = math_ops.cast(weights, outs[i].dtype)
            mask, _, weights = (
                losses_utils.squeeze_or_expand_dimensions(
                    mask, sample_weight=weights))
            weights *= mask

        if hasattr(loss_fn, 'reduction'):
          per_sample_losses = loss_fn.call(targets[i], outs[i])
          weighted_losses = losses_utils.compute_weighted_loss(
              per_sample_losses,
              sample_weight=weights,
              reduction=losses_utils.ReductionV2.NONE)
          loss_reduction = loss_fn.reduction

          # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
          # compile use cases.
          if loss_reduction == losses_utils.ReductionV2.AUTO:
            loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

          # Compute the stateless loss value.
          output_loss = losses_utils.reduce_weighted_loss(
              weighted_losses, reduction=loss_reduction)
        else:
          # Compute the stateless loss value for a custom loss class.
          # Here we assume that the class takes care of loss reduction,
          # because if the class returned a vector value we could not
          # differentiate between the use case where a custom optimizer
          # expects a vector loss value and an unreduced per-sample loss
          # value.
          output_loss = loss_fn(targets[i], outs[i], sample_weight=weights)
          loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

      # If the model has a single output, we don't append the loss metric
      # associated with each model output. When there are multiple outputs,
      # each output's loss is calculated and returned as part of the
      # loss_metrics.
      if len(model.outputs) > 1:
        # Keep track of the stateful output loss result.
        output_losses.append(output_loss_metrics[i](output_loss))

      # Scale output loss for distribution. For custom losses we assume
      # reduction was mean.
      if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
        output_loss = losses_utils.scale_loss_for_distribution(output_loss)
      total_loss += model._loss_weights_list[i] * output_loss

    # Add regularization losses.
    if custom_losses:
      total_loss += losses_utils.scale_loss_for_distribution(
          math_ops.add_n(custom_losses))
  return outs, total_loss, output_losses, masks
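
# Worked example (an illustrative sketch, not part of the original module):
# for a model compiled with two outputs, losses ['mse', 'mae'], and
# loss_weights=[0.3, 0.7], the composition above roughly amounts to
#
#   total_loss = 0.3 * reduce_mean(mse(targets[0], outs[0]))
#              + 0.7 * reduce_mean(mae(targets[1], outs[1]))
#              + add_n(model.losses)   # regularization terms, if any
#
# with masking and sample weights folded into each per-output reduction.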


def _process_single_batch(model,
                          inputs,
                          targets,
                          output_loss_metrics=None,
                          sample_weights=None,
                          training=False):
  """Calculates the loss and gradient for one input batch.

  The model weights are updated if training is set to True.

  Args:
      model: Model whose loss has to be calculated.
      inputs: List of input arrays.
      targets: List of target arrays.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.
      sample_weights: Optional list of sample weight arrays.
      training: Boolean indicating whether the model weights should be
        updated. 'fit' methods set this to True; 'evaluate' methods set it
        to False.

  Returns:
      The output of the model, the total loss, and the loss and the mask
      associated with each output.

  Raises:
      ValueError: If the model has no loss to optimize.
  """
  with backend.eager_learning_phase_scope(1 if training else 0), \
      training_utils.RespectCompiledTrainableState(model):
    with GradientTape() as tape:
      outs, total_loss, output_losses, masks = (
          _model_loss(
              model,
              inputs,
              targets,
              output_loss_metrics=output_loss_metrics,
              sample_weights=sample_weights,
              training=training))
      if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
        scaled_total_loss = model.optimizer.get_scaled_loss(total_loss)
      else:
        scaled_total_loss = total_loss
    if training:
      trainable_weights = model.trainable_weights
      if trainable_weights:
        # TODO(tanzheny) b/132690565: Provide mechanism for user to override
        # model.train_on_batch.
        if hasattr(model, '_backwards'):
          model._backwards(tape, scaled_total_loss)
        else:
          grads = tape.gradient(scaled_total_loss, trainable_weights)
          if isinstance(model.optimizer,
                        loss_scale_optimizer.LossScaleOptimizer):
            grads = model.optimizer.get_unscaled_gradients(grads)
          model.optimizer.apply_gradients(zip(grads, trainable_weights))
      else:
        logging.warning('The list of trainable weights is empty. Make sure that'
                        ' you are not setting model.trainable to False before '
                        'compiling the model.')
    return outs, total_loss, output_losses, masks
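
# The loss-scaling branches above follow the standard mixed-precision recipe.
# A self-contained sketch of that recipe with the public TF 2.x API; the
# names `model`, `compute_loss`, `x`, and `y` are assumptions:
#
#   import tensorflow as tf
#   optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
#       tf.keras.optimizers.SGD())
#   with tf.GradientTape() as tape:
#     loss = compute_loss(model, x, y)
#     scaled_loss = optimizer.get_scaled_loss(loss)
#   scaled_grads = tape.gradient(scaled_loss, model.trainable_weights)
#   grads = optimizer.get_unscaled_gradients(scaled_grads)
#   optimizer.apply_gradients(zip(grads, model.trainable_weights))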


def train_on_batch(model,
                   inputs,
                   targets,
                   sample_weights=None,
                   output_loss_metrics=None):
  """Calculates the loss and gradient updates for one input batch.

  Args:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.

  Returns:
      Dict with three items:
        'total_loss': list with a single tensor for the overall loss,
        'output_losses': list of tensors for the loss corresponding to each
          model output. Can be an empty list when the model has only one
          output.
        'metrics': list of tensors, one for each specified metric.
  """
  inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)
  outs, total_loss, output_losses, masks = (
      _process_single_batch(
          model,
          inputs,
          targets,
          sample_weights=sample_weights,
          training=True,
          output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)
  return {'total_loss': total_loss,
          'output_losses': output_losses,
          'metrics': metrics_results}
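
# Usage sketch (assumed names: a model compiled through this v1 eager path,
# plus one batch `x`, `y`):
#
#   results = train_on_batch(model, x, y)
#   results['total_loss']     # [scalar tensor]
#   results['output_losses']  # one tensor per output if the model has > 1
#   results['metrics']        # one tensor per compiled metric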


def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
  """Calculates the loss for one input batch.

  Args:
      model: Model whose loss has to be calculated.
      inputs: Input batch data.
      targets: Target batch data.
      sample_weights: Sample weight batch data.
      output_loss_metrics: List of metrics that are used to aggregate output
        loss values.

  Returns:
      Dict with three items:
        'total_loss': single tensor for the overall loss,
        'output_losses': list of tensors for the loss corresponding to each
          model output. Can be an empty list when the model has only one
          output.
        'metrics': list of tensors, one for each specified metric.
  """
  inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)

  with backend.eager_learning_phase_scope(0):
    outs, total_loss, output_losses, masks = (
        _model_loss(
            model,
            inputs,
            targets,
            sample_weights=sample_weights,
            training=False,
            output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)

  return {'total_loss': total_loss,
          'output_losses': output_losses,
          'metrics': metrics_results}

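# Evaluation mirrors the training sketch above, minus the gradient step
# (assumed names as before):
#
#   eval_results = test_on_batch(model, x, y)
#   # same dict structure: 'total_loss', 'output_losses', 'metrics'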