• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Contains the Policy class for mixed precision training."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import contextlib
21
22from tensorflow.python.util.tf_export import keras_export
23
24
@keras_export('keras.mixed_precision.experimental.Policy')
class Policy(object):
  """A mixed precision policy for a Keras layer.

  A mixed precision policy determines the floating-point dtype that Keras
  layers should create variables in. For non-default policies, if the variable
  dtype does not match the input dtype, variables will automatically be cast to
  the input dtype to avoid type errors. Policies can be passed to the 'dtype'
  argument of layer constructors, or a global policy can be set with
  'set_policy'.

  In the near future, policies will also determine the computation dtype of
  layers, as well as the loss scaling algorithm.

  Policies are intended to enable mixed precision training, which requires
  using float32 variables and [b]float16 computations for most layers. The term
  "mixed precision" refers to the use of both float16 (or bfloat16) and float32
  in a model. See https://arxiv.org/abs/1710.03740 for more information on
  mixed precision training.

  Policies are constructed by passing a string to the `name` constructor
  argument. `name` determines the behavior of the policy. Currently, `name` can
  be one of the following values.

    * 'infer': Infer the variable and computation dtypes from the input dtype.
      This is the default behavior.
    * 'infer_float32_vars': Infer the computation dtypes from the input
      dtype, but create variables in float32. Variables will be cast to the
      computation dtype. This is intended to enable mixed precision. Users can
      cast tensors to float16 before passing them to a layer, which causes the
      layer to run its computation in float16 while keeping variables in
      float32.

  To use mixed precision in a model, the 'infer_float32_vars' policy can be
  used alongside float16 input tensors, which results in float16 computations
  and float32 variables. For example:

  ```python
  tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars')
  model = tf.keras.models.Sequential(
      tf.keras.layers.Input((100,), dtype='float16'),
      tf.keras.layers.Dense(10),
      tf.keras.layers.Dense(10),
      tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')),
      tf.keras.layers.Activation('Softmax')
  )
  ```

  Alternatively, the policy can be passed to individual layers instead of
  setting the global policy with `set_policy`:

  ```python
  policy = tf.keras.mixed_precision.experimental.Policy('infer_float32_vars')
  model = tf.keras.models.Sequential(
      tf.keras.layers.Input((100,), dtype='float16'),
      tf.keras.layers.Dense(10, dtype=policy),
      tf.keras.layers.Dense(10, dtype=policy),
      tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')),
      tf.keras.layers.Activation('Softmax')
  )
  ```

  Note that a LossScaleOptimizer should also be used for mixed precision models
  to avoid numerical underflow. See `LossScaleOptimizer`.
  """

  def __init__(self, name):
    """Constructs the policy.

    Args:
      name: A string, either 'infer' or 'infer_float32_vars'. Determines the
        behavior of the policy, as described in the class docstring.

    Raises:
      ValueError: If `name` is not one of the supported values.
    """
    self._name = name
    if name == 'infer':
      # Variable dtype follows the input dtype; nothing is fixed up front.
      self._default_variable_dtype = None
    elif name == 'infer_float32_vars':
      self._default_variable_dtype = 'float32'
    else:
      raise ValueError('"name" argument to Policy constructor must be "infer" '
                       'or "infer_float32_vars", but got: %s' % name)

  @property
  def name(self):
    """Returns the name of the policy: "infer" or "infer_float32_vars"."""
    return self._name

  @property
  def default_variable_dtype(self):
    """Returns the default variable dtype of this policy.

    This is the dtype layers will create their variables in, unless a layer
    explicitly chooses a different dtype. Layers will cast variables to the
    appropriate dtype to avoid type errors.

    Returns:
      The default variable dtype of this policy, or None if the default
      variable dtype should be derived from the inputs.
    """
    return self._default_variable_dtype

  @property
  def should_cast_variables(self):
    """Returns True if variables should be cast to the computation dtype."""
    return self.default_variable_dtype is not None

  # TODO(reedwm): Implement get_config/from_config.
126
127
# TODO(reedwm): Make this thread local?
# Process-wide default policy, read by `global_policy` and replaced by
# `set_policy`. Starts out as an 'infer' policy.
_global_policy = Policy('infer')
130
131
@keras_export('keras.mixed_precision.experimental.global_policy')
def global_policy():
  """Returns the global Policy.

  The global policy is the default policy used for layers, if no policy is
  passed to the layer constructor. When TensorFlow starts, the global policy is
  set to an "infer" policy, and it can be changed with `set_policy`.

  Returns:
    The global Policy.
  """
  return _global_policy
144
145
@keras_export('keras.mixed_precision.experimental.set_policy')
def set_policy(policy):
  """Sets the global Policy.

  Args:
    policy: A Policy instance, or a string naming one (e.g. 'infer' or
      'infer_float32_vars'), from which a Policy will be constructed.
  """
  global _global_policy
  if isinstance(policy, Policy):
    _global_policy = policy
  else:
    # Anything that is not already a Policy (e.g. a policy-name string) is
    # coerced through the Policy constructor, which validates it.
    _global_policy = Policy(policy)
153
154
# TODO(reedwm): Make this thread local
@contextlib.contextmanager
def policy_scope(policy):
  """Context manager that temporarily sets the global Policy.

  On exit, the previous global policy is restored, even if an exception was
  raised inside the scope.

  Args:
    policy: A Policy instance, or a string naming one, to install as the
      global policy for the duration of the scope.

  Yields:
    Nothing.
  """
  previous = _global_policy
  try:
    set_policy(policy)
    yield
  finally:
    set_policy(previous)
164