# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training-related utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest


def slice_arrays(arrays, indices, contiguous=True):
  """Slices batches out of provided arrays (workaround for eager tensors).

  Unfortunately eager tensors don't have the same slicing behavior as
  Numpy arrays (they follow the same slicing behavior as symbolic TF tensors),
  hence we cannot use `generic_utils.slice_arrays` directly
  and we have to implement this workaround based on `concat`. This has a
  performance cost.

  Args:
    arrays: Single array or list of arrays. Any input that is not a `list`
      (including a tuple) is treated as a single array.
    indices: List of indices in the array that should be included in the output
      batch.
    contiguous: Boolean flag indicating whether the indices are contiguous.
      When True and the inputs are TF tensors, only `indices[0]` and
      `indices[-1]` are read and a single range slice is taken.

  Returns:
    Slice of data (either single array or list of arrays), in the same form
    (single vs. list) as `arrays`.
  """
  converted_to_list = False
  if not isinstance(arrays, list):
    converted_to_list = True
    arrays = [arrays]
  # If any element is a TF tensor, every element is sliced with tensor-style
  # indexing; otherwise defer to the generic (Numpy-oriented) helper.
  if any(tensor_util.is_tf_type(x) for x in arrays):
    if not contiguous:
      # Take one single-row slice per index and stitch them back together.
      entries = [[x[i:i + 1] for i in indices] for x in arrays]
      slices = [array_ops.concat(x, axis=0) for x in entries]
    else:
      # Contiguous indices collapse to one inclusive range slice.
      slices = [x[indices[0]:indices[-1] + 1] for x in arrays]
  else:
    slices = generic_utils.slice_arrays(arrays, indices)

  if converted_to_list:
    slices = slices[0]
  return slices


def handle_partial_sample_weights(outputs, sample_weights, sample_weight_modes,
                                  check_all_flat=False):
  """Adds 1.0 as sample weights for the outputs for which there is no weight.

  Args:
    outputs: List of model outputs.
    sample_weights: List of sample weight inputs, or None. A None entry marks
      an output that has no user-provided weight.
    sample_weight_modes: List of sample weight modes or None. An entry equal to
      `'temporal'` yields a filled-in weight of shape
      `(output_shape[0], output_shape[1])`; otherwise the shape is
      `(output_shape[0],)`.
    check_all_flat: Ensure that inputs are not nested structures. This is not
      a free check, so we may not want to run it eagerly every iteration.

  Returns:
    Tuple `(sample_weights, any_sample_weight, partial_sample_weight)`:
      - `sample_weights`: None if no weight was provided; the input
        `sample_weights` unchanged if every output has a weight; otherwise a
        tuple in which each missing entry is replaced by ones.
      - `any_sample_weight`: whether at least one weight was provided.
      - `partial_sample_weight`: whether some, but not all, weights were
        provided.
  """
  any_sample_weight = sample_weights is not None and any(
      w is not None for w in sample_weights)
  partial_sample_weight = any_sample_weight and any(
      w is None for w in sample_weights)

  if not any_sample_weight:
    return None, any_sample_weight, partial_sample_weight

  if not partial_sample_weight:
    return sample_weights, any_sample_weight, partial_sample_weight

  if check_all_flat:
    # Flattening a flat structure is a no-op, so these assertions fail exactly
    # when the corresponding input is nested.
    nest.assert_same_structure(
        list_to_tuple(sample_weights),
        list_to_tuple(nest.flatten(sample_weights)))
    nest.assert_same_structure(
        list_to_tuple(outputs),
        list_to_tuple(nest.flatten(outputs)))
    if sample_weight_modes is not None:
      nest.assert_same_structure(
          sample_weight_modes, nest.flatten(sample_weight_modes))

  new_sample_weights = []
  for i, sw in enumerate(sample_weights):
    if sw is None:
      output = outputs[i]
      # Match the container type of the output: Numpy outputs get Numpy ones
      # (static shape); anything else gets TF ones built from the dynamic
      # shape tensor.
      as_numpy = isinstance(output, np.ndarray)
      output_shape = output.shape if as_numpy else array_ops.shape(output)

      is_temporal = (
          sample_weight_modes is not None and
          sample_weight_modes[i] == 'temporal')
      sw_shape = (output_shape[0],
                  output_shape[1]) if is_temporal else (output_shape[0],)

      new_sample_weights.append(
          np.ones(sw_shape) if as_numpy else array_ops.ones(sw_shape))

    else:
      new_sample_weights.append(sw)
  return (list_to_tuple(new_sample_weights),
          any_sample_weight, partial_sample_weight)


class RespectCompiledTrainableState(object):
  """Set and restore trainable state if it has changed since compile.

  The keras API guarantees that the value of each Layer's `trainable` property
  at `Model.compile` time will be used when training that model. In order to
  respect this requirement, it may be necessary to set the trainable value of
  layers to their compile time values before beginning a training endpoint and
  restore the values before returning from said endpoint. This scope checks if
  any layer's trainable state has changed since Model compile, and performs this
  set and un-set bookkeeping.

  However, the trainable state of a layer changes quite infrequently, if ever,
  for many kinds of workflows. Moreover, updating every layer in a model is an
  expensive operation. As a result, we will only explicitly set and unset the
  trainable state of a model if a trainable value has changed since compile.
  """

  def __init__(self, model):
    self._model = model
    self._current_trainable_state = None
    self._compiled_trainable_state = None
    self._should_set_trainable = False

  def __enter__(self):
    self._current_trainable_state = self._model._get_trainable_state()  # pylint: disable=protected-access
    self._compiled_trainable_state = self._model._compiled_trainable_state  # pylint: disable=protected-access

    # Check to see if any layer's trainable state has changed since `compile`.
    # The flag is recomputed on every entry (not only ever set to True), so a
    # reused instance does not carry a stale decision from a previous scope.
    current_state = self._current_trainable_state
    self._should_set_trainable = any(
        layer in current_state and trainable != current_state[layer]
        for layer, trainable in self._compiled_trainable_state.items())

    # If so, restore the model to its compiled state.
    if self._should_set_trainable:
      self._model._set_trainable_state(self._compiled_trainable_state)  # pylint: disable=protected-access

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # If we set the values to their compiled state in __enter__, we need to
    # restore the original values before leaving the scope.
    if self._should_set_trainable:
      self._model._set_trainable_state(self._current_trainable_state)  # pylint: disable=protected-access
    return False  # False values do not suppress exceptions


# Allow use of methods not exposed to the user.
# pylint: disable=protected-access
def get_input_shape_and_dtype(layer):
  """Returns the batch input shape and dtype of `layer`, when it has them.

  Graph models (objects with a truthy `_is_graph_network`, or whose class is
  named `Sequential`) are unwrapped by repeatedly descending into their first
  entry of `.layers`. The shape and dtype reported are therefore those of the
  innermost first layer.

  Args:
    layer: Layer (or model) instance.

  Returns:
    Tuple `(input_shape, input_dtype)`. Both entries are None when the
    innermost first layer has no truthy `_batch_input_shape`.

  Raises:
    ValueError: If a graph model with no layers is encountered while
      unwrapping.
  """

  def _is_graph_model(candidate):
    # Functional networks flag themselves via `_is_graph_network`; Sequential
    # models are recognised by class name. Subclassed models may not have
    # been built yet, so they are not unwrapped.
    if getattr(candidate, '_is_graph_network', False):
      return True
    return candidate.__class__.__name__ == 'Sequential'

  innermost = layer
  while _is_graph_model(innermost):
    children = innermost.layers
    if not children:
      raise ValueError('An empty Model cannot be used as a Layer.')
    innermost = children[0]

  batch_input_shape = getattr(innermost, '_batch_input_shape', None)
  if not batch_input_shape:
    return None, None
  return batch_input_shape, innermost.dtype


# pylint: enable=protected-access


def get_static_batch_size(layer):
  """Returns the statically known batch size of `layer`, or None.

  Args:
    layer: a `Layer` instance.

  Returns:
    The first dimension of the layer's batch input shape, converted via
    `tensor_shape.Dimension(...).value`. Returns None when no batch input
    shape can be found.
  """
  shape, _ = get_input_shape_and_dtype(layer)
  if shape is None:
    return None
  return tensor_shape.Dimension(shape[0]).value


def list_to_tuple(maybe_list):
  """Converts a `list` into a `tuple`; returns any other value unchanged.

  Datasets will stack a list of tensors, so lists are switched to tuples.
  """
  return tuple(maybe_list) if isinstance(maybe_list, list) else maybe_list