1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains the Policy class for mixed precision training.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import contextlib 21 22from tensorflow.python.util.tf_export import keras_export 23 24 25@keras_export('keras.mixed_precision.experimental.Policy') 26class Policy(object): 27 """A mixed precision policy for a Keras layer. 28 29 A mixed precision policy determines the floating-point dtype that Keras layers 30 should create variables in. For non-default policies, if the variable dtype 31 does not match the input dtype, variables will automatically be casted to the 32 input dtype to avoid type errors. Policies can be passed to the 'dtype' 33 argument of layer constructors, or a global policy can be set with 34 'set_policy'. 35 36 In the near future, policies will also determine the computation dtype of 37 layers, as well as the loss scaling algorithm. 38 39 Policies are intended to enable mixed precision training, which require using 40 float32 variables and [b]float16 computations for most layers. The term "mixed 41 precision" refers to the use of both float16 (or bfloat16) and float32 in a 42 model. See https://arxiv.org/abs/1710.03740 for more information on mixed 43 precision training. 44 45 Policies are constructed by passing a string to the `name` constructor 46 argument. `name` determines the behavior of the policy. Currently, `name` can 47 be one of the following values. 48 49 * 'infer': Infer the variable and computation dtypes from the input dtype. 50 This is the default behavior. 51 * 'infer_float32_vars': Infer the computation dtypes from the input 52 dtype, but create variables in float32. Variables will be casted to the 53 computation dtype. This is intended to enable mixed precision. Users can 54 cast tensors to float16 before passing them to a layer, which causes the 55 layer to run it's computation in float16 while keeping variables in 56 float32. 57 58 To use mixed precision in a model, the 'infer_float32_vars' policy can be used 59 alongside float16 input tensors, which results in float16 computations and 60 float32 variables. For example: 61 62 ```python 63 tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars') 64 model = tf.keras.models.Sequential( 65 tf.keras.layers.Input((100,), dtype='float16'), 66 tf.keras.layers.Dense(10), 67 tf.keras.layers.Dense(10), 68 tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')), 69 tf.keras.layers.Activation('Softmax') 70 ) 71 ``` 72 73 Alternatively, the policy can be passed to individual layers instead of 74 setting the global policy with `set_policy`: 75 76 ```python 77 policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars') 78 model = tf.keras.models.Sequential( 79 tf.keras.layers.Input((100,), dtype='float16'), 80 tf.keras.layers.Dense(10, dtype=policy), 81 tf.keras.layers.Dense(10, dtype=policy), 82 tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')), 83 tf.keras.layers.Activation('Softmax') 84 ) 85 ``` 86 87 Note that a LossScaleOptimizer should also be used for mixed precision models 88 to avoid numerical underflow. See `LossScaleOptimizer`. 89 """ 90 91 def __init__(self, name): 92 self._name = name 93 if name == 'infer': 94 self._default_variable_dtype = None 95 elif name == 'infer_float32_vars': 96 self._default_variable_dtype = 'float32' 97 else: 98 raise ValueError('"name" argument to Policy constructor must be "infer" ' 99 'or "infer_float32_vars", but got: %s' % name) 100 101 @property 102 def name(self): 103 """Returns the name of the policy: "infer" or "infer_float32_vars.""" 104 return self._name 105 106 @property 107 def default_variable_dtype(self): 108 """Returns the default variable dtype of this policy. 109 110 This is the dtype layers will create their variables in, unless a layer 111 explicit chooses a different dtype. Layers will cast variables to the 112 appropriate dtype to avoid type errors. 113 114 Returns: 115 The default variable dtype of this policy, or None if the default variable 116 dtype should be derived from the inputs. 117 """ 118 return self._default_variable_dtype 119 120 @property 121 def should_cast_variables(self): 122 """Returns true if variables should be casted.""" 123 return self.default_variable_dtype is not None 124 125 # TODO(reedwm): Implement get_config/from_config. 126 127 128# TODO(reedwm): Make this thread local? 129_global_policy = Policy('infer') 130 131 132@keras_export('keras.mixed_precision.experimental.global_policy') 133def global_policy(): 134 """Returns the global Policy. 135 136 The global policy is the default policy used for layers, if no policy is 137 passed to the layer constructor. When TensorFlow starts, the global policy is 138 set to an "infer" policy, and can be changed with `set_policy`. 139 140 Returns: 141 The global Policy. 142 """ 143 return _global_policy 144 145 146@keras_export('keras.mixed_precision.experimental.set_policy') 147def set_policy(policy): 148 """Sets the global Policy.""" 149 global _global_policy 150 if not isinstance(policy, Policy): 151 policy = Policy(policy) 152 _global_policy = policy 153 154 155# TODO(reedwm): Make this thread local 156@contextlib.contextmanager 157def policy_scope(policy): 158 old_policy = _global_policy 159 try: 160 set_policy(policy) 161 yield 162 finally: 163 set_policy(old_policy) 164