# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras training and evaluation routines for eager execution."""
# pylint: disable=protected-access

import numpy as np

from tensorflow.python.eager.backprop import GradientTape
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def _eager_loss_fn(outputs, targets, loss_fn, output_name):
  """Computes the loss for a single output under a named scope."""
  with backend.name_scope(output_name + '_loss'):
    loss = loss_fn(targets, outputs)
  return loss


def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None):
  """Calculates the metrics for each output of the given model.

  Args:
    model: The model on which metrics are being calculated.
    outputs: The outputs of the given model.
    targets: The predictions or targets of the given model.
    sample_weights: Optional list of sample weights for each output.
    masks: Optional list of masks for each output.

  Returns:
    Returns the metric results for each output of the model.
  """
  outputs = nest.flatten(outputs)
  targets = nest.flatten(targets)
  # Invoke all (weighted and unweighted) metrics.
  metric_results = []
  if targets:
    # Insert None values corresponding to the targets that need to be skipped
    # on the model.
    if len(model._targets) != len(targets):
      new_targets = [
          None if t is None else targets.pop(0) for t in model._targets
      ]
      targets = new_targets

    metric_results = model._handle_metrics(
        outputs,
        targets=targets,
        sample_weights=sample_weights,
        masks=masks,
        return_weighted_and_unweighted_metrics=True,
        skip_target_masks=model._prepare_skip_target_masks())

  # Add metric results from the `add_metric` metrics.
  metric_results.extend([
      m.result()
      for m in model.metrics
      if m not in model._compile_metric_functions
  ])
  return metric_results
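
# Illustrative sketch (hypothetical values, not part of the library API): the
# alignment step in `_eager_metrics_fn` above re-inserts `None` placeholders
# so that the flattened `targets` line up positionally with `model._targets`
# when some targets are skipped:
#
#   model_targets = ['t0', None, 't2']  # assumed; `None` marks a skipped slot
#   targets = ['y0', 'y2']              # assumed flattened user targets
#   aligned = [None if t is None else targets.pop(0) for t in model_targets]
#   assert aligned == ['y0', None, 'y2']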


def _model_loss(model,
                inputs,
                targets,
                output_loss_metrics=None,
                sample_weights=None,
                training=False):
  """Calculates the loss for a given model.

  Args:
    model: The model on which the loss is being calculated.
    inputs: Either a dictionary of inputs to the model or a list of input
      arrays.
    targets: List of target arrays.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.
    sample_weights: Optional list of sample weight arrays.
    training: Whether the model should be run in inference or training mode.

  Returns:
    The model outputs, the total loss, the loss values calculated using the
    specified loss functions, and the masks for each output. The total loss
    includes regularization losses and applies masking and sample weighting
    to the loss value.
  """
  # TODO(psv): Dedup code here with graph mode prepare_total_loss() fn.
  # Used to keep track of the total loss value (stateless).
  # e.g., total_loss = loss_weight_1 * output_1_loss_fn(...) +
  #                    loss_weight_2 * output_2_loss_fn(...) +
  #                    layer losses.
  total_loss = 0
  kwargs = {}
  if model._expects_training_arg:
    kwargs['training'] = training
  if len(inputs) == 1 and not isinstance(inputs, dict):
    inputs = inputs[0]

  # Allow mixed `NumPy` and `EagerTensor` input here.
  if any(
      isinstance(input_t, (np.ndarray, float, int))
      for input_t in nest.flatten(inputs)):
    inputs = nest.map_structure(ops.convert_to_tensor_v2_with_dispatch, inputs)

  outs = model(inputs, **kwargs)
  outs = nest.flatten(outs)

  if targets:
    targets = training_utils_v1.cast_if_floating_dtype_and_mismatch(
        targets, outs)
  # TODO(sallymatson/psv): check if we should do same mismatch fix for weights
  if sample_weights:
    sample_weights = [
        training_utils_v1.cast_if_floating_dtype(
            ops.convert_to_tensor_v2_with_dispatch(val))
        if val is not None else None for val in sample_weights
    ]

  masks = [getattr(t, '_keras_mask', None) for t in outs]
  targets = nest.flatten(targets)

  # Used to keep track of individual output losses.
  output_losses = []

  with backend.name_scope('loss'):
    loss_fns = [
        loss_fn for loss_fn in model.loss_functions if loss_fn is not None
    ]
    custom_losses = model.losses  # Regularization losses

    if not loss_fns and not custom_losses:
      if training:
        raise ValueError('The model cannot be trained '
                         'because it has no loss to optimize.')
      else:
        raise ValueError('The model cannot be evaluated '
                         'because it has no loss to compute.')

    for i, loss_fn in enumerate(loss_fns):
      weights = sample_weights[i] if sample_weights else None
      mask = masks[i]
      with backend.name_scope(model.output_names[i] + '_loss'):
        if mask is not None:
          mask = math_ops.cast(mask, outs[i].dtype)
          # Update weights with mask.
          if weights is None:
            weights = mask
          else:
            # Update dimensions of weights to match with mask if possible.
            weights = math_ops.cast(weights, outs[i].dtype)
            mask, _, weights = (
                losses_utils.squeeze_or_expand_dimensions(
                    mask, sample_weight=weights))
            weights *= mask

        if hasattr(loss_fn, 'reduction'):
          per_sample_losses = loss_fn.call(targets[i], outs[i])
          weighted_losses = losses_utils.compute_weighted_loss(
              per_sample_losses,
              sample_weight=weights,
              reduction=losses_utils.ReductionV2.NONE)
          loss_reduction = loss_fn.reduction

          # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all
          # compile use cases.
          if loss_reduction == losses_utils.ReductionV2.AUTO:
            loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

          # Compute the stateless loss value.
          output_loss = losses_utils.reduce_weighted_loss(
              weighted_losses, reduction=loss_reduction)
        else:
          # Compute the stateless loss value for a custom loss class.
          # Here we assume that the class handles loss reduction, because if
          # it returned a vector value we could not differentiate between a
          # custom optimizer that expects a vector loss value and an
          # unreduced per-sample loss value.
          output_loss = loss_fn(targets[i], outs[i], sample_weight=weights)
          loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE

        # If the number of outputs is 1 then we don't append the loss metric
        # associated with each model output. When there are multiple outputs
        # associated with a model, each output's loss is calculated and
        # returned as part of the loss_metrics.
        if len(model.outputs) > 1:
          # Keep track of the stateful output loss result.
          output_losses.append(output_loss_metrics[i](output_loss))

        # Scale output loss for distribution. For custom losses we assume
        # reduction was mean.
        if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE:
          output_loss = losses_utils.scale_loss_for_distribution(output_loss)
        total_loss += model._loss_weights_list[i] * output_loss

    # Add regularization losses
    if custom_losses:
      total_loss += losses_utils.scale_loss_for_distribution(
          math_ops.add_n(custom_losses))
  return outs, total_loss, output_losses, masks
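
# Illustrative sketch (hypothetical shapes, not part of the library API): the
# mask handling in `_model_loss` folds the output's `_keras_mask` into the
# sample weights, so masked-out samples contribute zero loss:
#
#   import tensorflow as tf
#   from tensorflow.python.keras.utils import losses_utils
#
#   mask = tf.constant([1., 1., 0.])     # assumed mask; 0 marks padding
#   weights = tf.constant([2., 3., 4.])  # assumed per-sample weights
#   mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
#       mask, sample_weight=weights)
#   weights = weights * mask             # -> [2., 3., 0.]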


def _process_single_batch(model,
                          inputs,
                          targets,
                          output_loss_metrics=None,
                          sample_weights=None,
                          training=False):
  """Calculate the loss and gradient for one input batch.

  The model weights are updated if training is set to True.

  Args:
    model: Model whose loss has to be calculated.
    inputs: List of input arrays.
    targets: List of target arrays.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.
    sample_weights: Optional list of sample weight arrays.
    training: Whether the weights of the model should be updated. 'fit'
      methods will set this to True while 'evaluate' methods will set it to
      False.

  Returns:
    The outputs of the model, the total loss, and the loss and mask
    associated with each output.

  Raises:
    ValueError: If the model has no loss to optimize.
  """
  with backend.eager_learning_phase_scope(1 if training else 0), \
      training_utils.RespectCompiledTrainableState(model):
    with GradientTape() as tape:
      outs, total_loss, output_losses, masks = (
          _model_loss(
              model,
              inputs,
              targets,
              output_loss_metrics=output_loss_metrics,
              sample_weights=sample_weights,
              training=training))
      if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
        scaled_total_loss = model.optimizer.get_scaled_loss(total_loss)
      else:
        scaled_total_loss = total_loss
    if training:
      trainable_weights = model.trainable_weights
      if trainable_weights:
        # TODO(tanzheny) b/132690565: Provide mechanism for user to override
        # model.train_on_batch.
        if hasattr(model, '_backwards'):
          model._backwards(tape, scaled_total_loss)
        else:
          grads = tape.gradient(scaled_total_loss, trainable_weights)
          if isinstance(model.optimizer,
                        loss_scale_optimizer.LossScaleOptimizer):
            grads = model.optimizer.get_unscaled_gradients(grads)
          model.optimizer.apply_gradients(zip(grads, trainable_weights))
      else:
        logging.warning('The list of trainable weights is empty. Make sure '
                        'that you are not setting model.trainable to False '
                        'before compiling the model.')
    return outs, total_loss, output_losses, masks
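
# Illustrative sketch (hypothetical `compute_loss` helper, not part of the
# library API): the update performed by `_process_single_batch` follows the
# standard eager tape pattern, with optional loss (un)scaling when a
# `LossScaleOptimizer` is used for mixed precision:
#
#   with GradientTape() as tape:
#     loss = compute_loss(model(x, training=True), y)   # assumed helper
#     scaled_loss = optimizer.get_scaled_loss(loss)     # LossScaleOptimizer only
#   grads = tape.gradient(scaled_loss, model.trainable_weights)
#   grads = optimizer.get_unscaled_gradients(grads)     # LossScaleOptimizer only
#   optimizer.apply_gradients(zip(grads, model.trainable_weights))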


def train_on_batch(model,
                   inputs,
                   targets,
                   sample_weights=None,
                   output_loss_metrics=None):
  """Calculates the loss and gradient updates for one input batch.

  Args:
    model: Model whose loss has to be calculated.
    inputs: Input batch data.
    targets: Target batch data.
    sample_weights: Sample weight batch data.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.

  Returns:
    Dict with three items:
      'total_loss': list with a single tensor for the overall loss,
      'output_losses': list of tensors for the loss corresponding to each
        model output. Could be an empty list when the model has only one
        output.
      'metrics': list of tensors for the metrics specified.
  """
  inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)
  outs, total_loss, output_losses, masks = (
      _process_single_batch(
          model,
          inputs,
          targets,
          sample_weights=sample_weights,
          training=True,
          output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)
  return {'total_loss': total_loss,
          'output_losses': output_losses,
          'metrics': metrics_results}


def test_on_batch(model,
                  inputs,
                  targets,
                  sample_weights=None,
                  output_loss_metrics=None):
  """Calculates the loss for one input batch.

  Args:
    model: Model whose loss has to be calculated.
    inputs: Input batch data.
    targets: Target batch data.
    sample_weights: Sample weight batch data.
    output_loss_metrics: List of metrics that are used to aggregate output
      loss values.

  Returns:
    Dict with three items:
      'total_loss': list with a single tensor for the overall loss,
      'output_losses': list of tensors for the loss corresponding to each
        model output. Could be an empty list when the model has only one
        output.
      'metrics': list of tensors for the metrics specified.
  """
  inputs = training_utils_v1.cast_to_model_input_dtypes(inputs, model)

  with backend.eager_learning_phase_scope(0):
    outs, total_loss, output_losses, masks = (
        _model_loss(
            model,
            inputs,
            targets,
            sample_weights=sample_weights,
            training=False,
            output_loss_metrics=output_loss_metrics))
  if not isinstance(outs, list):
    outs = [outs]
  metrics_results = _eager_metrics_fn(
      model, outs, targets, sample_weights=sample_weights, masks=masks)
  total_loss = nest.flatten(total_loss)

  return {'total_loss': total_loss,
          'output_losses': output_losses,
          'metrics': metrics_results}
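
# Illustrative usage sketch (hypothetical `model`, `inputs`, `targets`, and
# `output_loss_metrics`; these routines back the v1 `Model.train_on_batch` /
# `Model.test_on_batch` paths in eager mode and are not public API):
#
#   results = train_on_batch(model, inputs, targets,
#                            output_loss_metrics=output_loss_metrics)
#   results['total_loss']     # list with a single overall-loss tensor
#   results['output_losses']  # per-output losses; empty for single-output
#   results['metrics']        # metric result tensors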