# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the Policy class for mixed precision training."""

import contextlib

from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision import device_compatibility_check
from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util.tf_export import keras_export


# pylint: disable=g-classes-have-attributes
@keras_export('keras.mixed_precision.Policy', v1=[])
class Policy(object):
  """A dtype policy for a Keras layer.

  A dtype policy determines a layer's computation and variable dtypes. Each
  layer has a policy. Policies can be passed to the `dtype` argument of layer
  constructors, or a global policy can be set with
  `tf.keras.mixed_precision.set_global_policy`.

  Args:
    name: The policy name, which determines the compute and variable dtypes.
      Can be any dtype name, such as `'float32'` or `'float64'`, which causes
      both the compute and variable dtypes to be that dtype. Can also be the
      string `'mixed_float16'` or `'mixed_bfloat16'`, which causes the compute
      dtype to be float16 or bfloat16 and the variable dtype to be float32.

  Typically you only need to interact with dtype policies when using mixed
  precision, which is the use of float16 or bfloat16 for computations and
  float32 for variables. This is why the term `mixed_precision` appears in the
  API name. Mixed precision can be enabled by passing `'mixed_float16'` or
  `'mixed_bfloat16'` to `tf.keras.mixed_precision.set_global_policy`. See [the
  mixed precision guide](https://www.tensorflow.org/guide/keras/mixed_precision)
  for more information on how to use mixed precision.

  >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
  >>> layer1 = tf.keras.layers.Dense(10)
  >>> layer1.dtype_policy  # `layer1` will automatically use mixed precision
  <Policy "mixed_float16">
  >>> # Can optionally override layer to use float32 instead of mixed precision.
  >>> layer2 = tf.keras.layers.Dense(10, dtype='float32')
  >>> layer2.dtype_policy
  <Policy "float32">
  >>> # Set policy back to initial float32 for future examples.
  >>> tf.keras.mixed_precision.set_global_policy('float32')

  In the example above, passing `dtype='float32'` to the layer is equivalent to
  passing `dtype=tf.keras.mixed_precision.Policy('float32')`. In general,
  passing a dtype policy name to a layer is equivalent to passing the
  corresponding policy, so it is never necessary to explicitly construct a
  `Policy` object.

  Note: `Model.compile` will automatically wrap an optimizer with a
  `tf.keras.mixed_precision.LossScaleOptimizer` if you use the `'mixed_float16'`
  policy. If you use a custom training loop instead of calling `Model.compile`,
  you should explicitly use a `tf.keras.mixed_precision.LossScaleOptimizer` to
  avoid numeric underflow with float16.
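
  For example, a minimal sketch of wrapping an optimizer for a custom training
  loop (the loop itself, `loss`, `tape`, and `model` are assumed to already
  exist):

  >>> opt = tf.keras.optimizers.SGD()
  >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
  >>> # In the training step, scale the loss, compute gradients of the scaled
  >>> # loss, then unscale the gradients before applying them:
  >>> # scaled_loss = opt.get_scaled_loss(loss)
  >>> # scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  >>> # grads = opt.get_unscaled_gradients(scaled_grads)
  >>> # opt.apply_gradients(zip(grads, model.trainable_variables))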

  ### How a layer uses its policy's compute dtype

  A layer casts its inputs to its compute dtype. This causes the layer's
  computations and output to also be in the compute dtype. For example:

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> # `layer`'s policy defaults to float32.
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> layer.compute_dtype  # Equivalent to layer.dtype_policy.compute_dtype
  'float32'
  >>> # `layer` casts its inputs to its compute dtype and does computations in
  >>> # that dtype.
  >>> y = layer(x)
  >>> y.dtype
  tf.float32

  Note that the base `tf.keras.layers.Layer` class inserts the casts. If you
  subclass it to write your own layer, you do not have to insert any casts.

  Currently, only tensors in the first argument to the layer's `call` method are
  cast (although this will likely be changed in a future minor release). For
  example:

  >>> class MyLayer(tf.keras.layers.Layer):
  ...   # Bug! `b` will not be cast.
  ...   def call(self, a, b):
  ...     return a + 1., b + 1.
  >>> a = tf.constant(1., dtype="float32")
  >>> b = tf.constant(1., dtype="float32")
  >>> layer = MyLayer(dtype="float64")
  >>> x, y = layer(a, b)
  >>> x.dtype
  tf.float64
  >>> y.dtype
  tf.float32

  If writing your own layer with multiple inputs, you should either explicitly
  cast other tensors to `self.compute_dtype` in `call` or accept all tensors in
  the first argument as a list.
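
  For example, a sketch of the explicit-cast approach (the hypothetical
  `MyFixedLayer` mirrors `MyLayer` above):

  >>> class MyFixedLayer(tf.keras.layers.Layer):
  ...   def call(self, a, b):
  ...     b = tf.cast(b, self.compute_dtype)  # cast `b` manually
  ...     return a + 1., b + 1.
  >>> layer = MyFixedLayer(dtype="float64")
  >>> x, y = layer(a, b)
  >>> y.dtype
  tf.float64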

  The casting only occurs in TensorFlow 2. If
  `tf.compat.v1.disable_v2_behavior()` has been called, you can enable the
  casting behavior with `tf.compat.v1.keras.layers.enable_v2_dtype_behavior()`.

  ### How a layer uses its policy's variable dtype

  The default dtype of variables created by `tf.keras.layers.Layer.add_weight`
  is the layer's policy's variable dtype.

  If a layer's compute and variable dtypes differ, `add_weight` will wrap
  floating-point variables with a special wrapper called an `AutoCastVariable`.
  `AutoCastVariable` is identical to the original variable except it casts
  itself to the layer's compute dtype when used within `Layer.call`. This means
  if you are writing a layer, you do not have to explicitly cast the variables
  to the layer's compute dtype. For example:

  >>> class SimpleDense(tf.keras.layers.Layer):
  ...
  ...   def build(self, input_shape):
  ...     # With mixed precision, self.kernel is a float32 AutoCastVariable
  ...     self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
  ...
  ...   def call(self, inputs):
  ...     # With mixed precision, self.kernel will be cast to float16
  ...     return tf.linalg.matmul(inputs, self.kernel)
  ...
  >>> layer = SimpleDense(dtype='mixed_float16')
  >>> y = layer(tf.ones((10, 10)))
  >>> y.dtype
  tf.float16
  >>> layer.kernel.dtype
  tf.float32

  A layer author can prevent a variable from being wrapped with an
  `AutoCastVariable` by passing `experimental_autocast=False` to `add_weight`,
  which is useful if the float32 value of the variable must be accessed within
  the layer.
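
  For example, a minimal sketch (the layer and variable names are illustrative):

  >>> class NonAutoCastLayer(tf.keras.layers.Layer):
  ...   def build(self, input_shape):
  ...     # Not wrapped: `self.scale` stays float32 inside `call`.
  ...     self.scale = self.add_weight('scale', shape=(),
  ...                                  experimental_autocast=False)
  ...   def call(self, inputs):
  ...     # The float32 value is available here; cast manually when mixing.
  ...     return inputs * tf.cast(self.scale, inputs.dtype)
  >>> layer = NonAutoCastLayer(dtype='mixed_float16')
  >>> y = layer(tf.ones((2,)))
  >>> y.dtype
  tf.float16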

  ### How to write a layer that supports mixed precision and float64

  For the most part, layers will automatically support mixed precision and
  float64 without any additional work, because the base layer automatically
  casts inputs, creates variables of the correct type, and, in the case of
  mixed precision, wraps variables with `AutoCastVariables`.

  The primary case where you need extra work to support mixed precision or
  float64 is when you create a new tensor, such as with `tf.ones` or
  `tf.random.normal`. In such cases, you must create the tensor with the
  correct dtype. For example, if you call `tf.random.normal`, you must pass the
  compute dtype, which is the dtype the inputs have been cast to:

  >>> class AddRandom(tf.keras.layers.Layer):
  ...
  ...   def call(self, inputs):
  ...     # We must pass `dtype=inputs.dtype`, otherwise a TypeError may
  ...     # occur when adding `inputs` to `rand`.
  ...     rand = tf.random.normal(shape=inputs.shape, dtype=inputs.dtype)
  ...     return inputs + rand
  >>> layer = AddRandom(dtype='mixed_float16')
  >>> y = layer(x)
  >>> y.dtype
  tf.float16

  If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a
  `TypeError` would have occurred. This is because `tf.random.normal`'s
  dtype defaults to `"float32"`, but the input dtype is float16. You cannot add
  a float32 tensor to a float16 tensor.
  """

  def __init__(self, name):
    if isinstance(name, dtypes.DType):
      raise TypeError("'name' must be a string, not a DType. "
                      "Instead, pass DType.name. Got: %s" % (name.name,))
    elif not isinstance(name, str):
      raise TypeError("'name' must be a string, but got: %s" % (name,))
    self._name = name
    self._compute_dtype, self._variable_dtype = self._parse_name(name)
    if name in ('mixed_float16', 'mixed_bfloat16'):
      device_compatibility_check.log_device_compatibility_check(name)

  def _parse_name(self, name):
    """Parses a Policy name into a compute and variable dtype.

    Args:
      name: The name of the policy.

    Returns:
      The (compute_dtype, variable_dtype) pair.
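
    For example, `_parse_name('mixed_float16')` returns
    `('float16', 'float32')`, and `_parse_name('float64')` returns
    `('float64', 'float64')`.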
206    """
207    if name.endswith('_float32_vars'):
208      error_msg = ('Policies ending in \'_float32_vars\' have been removed '
209                   'from TensorFlow.')
210      if name in ('infer_float32_vars', 'infer_with_float32_vars'):
211        error_msg += (' Please use the \'mixed_float16\' or \'mixed_bfloat16\' '
212                      'policy instead.')
213      elif name == 'float16_with_float32_vars':
214        error_msg += (' Please use the \'mixed_float16\' policy instead.')
215      elif name == 'bfloat16_with_float32_vars':
216        error_msg += (' Please use the \'mixed_bfloat16\' policy instead.')
217      error_msg += ' Got policy name: \'%s\'' % name
218      raise ValueError(error_msg)
219
220    if name == 'mixed_float16':
221      return 'float16', 'float32'
222    elif name == 'mixed_bfloat16':
223      return 'bfloat16', 'float32'
224    elif name == '_infer':
225      # The "_infer" policy exists only for compatibility with TF 1, where
226      # "_infer" is the default. The behavior matches the behavior of TF 1's
227      # behavior before policies were introduced. With "_infer", the computation
228      # and variable dtype are inferred from the first input the first time the
229      # layer is called. Once the layer is called for the first time, the
230      # layer's policy will change to the dtype of the first input, and it will
231      # no longer have the "_infer" policy.
232      #
233      # The infer policy should be considered an implementation detail and may
234      # be removed in the future.
235      return None, None
236
237    try:
238      dtype = dtypes.as_dtype(name).name
239    except TypeError:
240      error = ("Cannot convert value %s to a mixed precision Policy. "
241               "Valid policies include 'mixed_float16', 'mixed_bfloat16', "
242               "and the name of any dtype such as 'float32'." % (name,))
243      raise ValueError(error)
244    return dtype, dtype

  @property
  def variable_dtype(self):
    """The variable dtype of this policy.

    This is the dtype layers will create their variables in, unless a layer
    explicitly chooses a different dtype. If this is different from
    `Policy.compute_dtype`, layers will cast variables to the compute dtype to
    avoid type errors.

    Variable regularizers are run in the variable dtype, not the compute dtype.
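
    For example, a minimal sketch:

    >>> tf.keras.mixed_precision.Policy('mixed_float16').variable_dtype
    'float32'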

    Returns:
      The variable dtype of this policy, as a string.
    """
    return self._variable_dtype

  @property
  def compute_dtype(self):
    """The compute dtype of this policy.

    This is the dtype layers will do their computations in. Typically layers
    output tensors with the compute dtype as well.

    Note that even if the compute dtype is float16 or bfloat16, hardware devices
    may not do individual adds, multiplies, and other fundamental operations in
    float16 or bfloat16, but instead may do some of them in float32 for numeric
    stability. The compute dtype is the dtype of the inputs and outputs of the
    TensorFlow ops that the layer executes. Internally, many TensorFlow ops will
    do certain internal calculations in float32 or some other device-internal
    intermediate format with higher precision than float16/bfloat16, to increase
    numeric stability.

    For example, a `tf.keras.layers.Dense` layer, when run on a GPU with a
    float16 compute dtype, will pass float16 inputs to `tf.linalg.matmul`. But
    `tf.linalg.matmul` will use float32 intermediate math. The performance
    benefit of float16 is still apparent, due to increased memory bandwidth and
    the fact that modern GPUs have specialized hardware for computing matmuls
    on float16 inputs while still keeping intermediate computations in float32.
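
    For example, a minimal sketch:

    >>> tf.keras.mixed_precision.Policy('mixed_float16').compute_dtype
    'float16'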

    Returns:
      The compute dtype of this policy, as a string.
    """
    return self._compute_dtype

  @property
  def name(self):
    """Returns the name of this policy."""
    return self._name

  def __repr__(self):
    return '<Policy "%s">' % self._name

  def get_config(self):
    return {'name': self.name}

  @classmethod
  def from_config(cls, config, custom_objects=None):
    del custom_objects
    if 'loss_scale' in config:
      config = config.copy()
      # Policy.get_config in TensorFlow 2.3 and below had a loss_scale. We
      # silently drop it.
      del config['loss_scale']
    return cls(**config)


@keras_export('keras.mixed_precision.experimental.Policy', v1=[])
class PolicyV1(Policy):
  """A deprecated dtype policy for a Keras layer.

  Warning: This class is now deprecated and will be removed soon. Please use the
  non-experimental class `tf.keras.mixed_precision.Policy` instead.

  The difference between this class and the non-experimental class is that this
  class has a `loss_scale` field and the non-experimental class does not. The
  loss scale is only used by `tf.keras.Model.compile`, which automatically wraps
  the optimizer with a `LossScaleOptimizer` if the optimizer is not already a
  `LossScaleOptimizer`. For the non-experimental Policy class, `Model.compile`
  instead wraps the optimizer with a `LossScaleOptimizer` if `Policy.name` is
  "mixed_float16".

  When deserializing objects with an experimental policy using functions like
  `tf.keras.utils.deserialize_keras_object`, the policy will be deserialized as
  the non-experimental `tf.keras.mixed_precision.Policy`, and the loss scale
  will silently be dropped. This is so that SavedModels that are generated
  with an experimental policy can be restored after the experimental policy is
  removed.
  """

  def __init__(self, name, loss_scale='auto'):
    """Constructs the policy.

    The `name` argument determines the compute and variable dtypes and the
    default loss scale; it has no additional effect on the Policy. The compute
    and variable dtypes can only be specified through `name`, and cannot be
    specified directly.

    Args:
      name: A string. Can be one of the following values:
        * Any dtype name, such as 'float32' or 'float64'. Both the variable and
          compute dtypes will be that dtype.
        * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
          bfloat16, while the variable dtype is float32. With 'mixed_float16',
          a dynamic loss scale is used. These policies are used for mixed
          precision training.
      loss_scale: A `tf.compat.v1.mixed_precision.LossScale`, an int (which
        uses a `FixedLossScale`), the string "dynamic" (which uses a
        `DynamicLossScale`), or None (which uses no loss scale). Defaults to
        `"auto"`. In the `"auto"` case: 1) if `name` is `"mixed_float16"`, then
        use `loss_scale="dynamic"`. 2) otherwise, do not use a loss scale. Only
        `tf.keras.Model`s, not layers, use the loss scale, and it is only used
        during `Model.fit`, `Model.train_on_batch`, and other similar methods.
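
    For example, a minimal sketch of the `"auto"` default:

    >>> # With 'mixed_float16', a dynamic loss scale is used by default:
    >>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
    >>> # With any other name, no loss scale is used by default:
    >>> policy = tf.keras.mixed_precision.experimental.Policy('float32')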
358    """
359    super(PolicyV1, self).__init__(name)
360    if loss_scale == 'auto':
361      loss_scale = 'dynamic' if name == 'mixed_float16' else None
362      self._using_default_loss_scale = True
363    else:
364      self._using_default_loss_scale = False
365    if loss_scale and self._compute_dtype not in (None, 'float16'):
366      tf_logging.warning(
367          'Creating a Policy with a loss scale is only useful for '
368          'float16 policies. You passed loss_scale=%r for policy '
369          '%s. Consider not passing any loss_scale instead.' %
370          (loss_scale, name))
371    self._loss_scale = keras_loss_scale_module.get(loss_scale)
372
373  @property
374  def loss_scale(self):
375    """Returns the loss scale of this Policy.
376
377    Returns:
378      A `tf.compat.v1.mixed_precision.experimental.LossScale`, or None.
379    """
380    return self._loss_scale
381
382  def __repr__(self):
383    return '<PolicyV1 "%s", loss_scale=%s>' % (self._name, self.loss_scale)
384
385  def get_config(self):
386    config = {
387        'name': self.name
388    }
389    if not self._using_default_loss_scale:
390      # We only include the loss scale if the default loss scale is not used.
391      # This allows us to change the loss scale config format without breaking
392      # users who use the default loss scale.
393      config['loss_scale'] = keras_loss_scale_module.serialize(self.loss_scale)
394    return config
395
396  @classmethod
397  def from_config(cls, config, custom_objects=None):
398    if 'loss_scale' in config and isinstance(config['loss_scale'], dict):
399      config = config.copy()
400      config['loss_scale'] = keras_loss_scale_module.deserialize(
401          config['loss_scale'], custom_objects=custom_objects)
402    return cls(**config)
403
404
405# The current global policy in effect. If None, it means the current value of
406# floatx should be used as the policy if the V2 dtype behavior is enabled,
407# or "_infer" otherwise.
408# TODO(reedwm): Make this thread local?
409_global_policy = None


@keras_export('keras.mixed_precision.global_policy',
              'keras.mixed_precision.experimental.global_policy', v1=[])
def global_policy():
  """Returns the global dtype policy.

  The global policy is the default `tf.keras.mixed_precision.Policy` used for
  layers, if no policy is passed to the layer constructor. If no policy has been
  set with `keras.mixed_precision.set_global_policy`, this will return a policy
  constructed from `tf.keras.backend.floatx()` (floatx defaults to float32).

  >>> tf.keras.mixed_precision.global_policy()
  <Policy "float32">
  >>> tf.keras.layers.Dense(10).dtype_policy  # Defaults to the global policy
  <Policy "float32">

  If TensorFlow 2 behavior has been disabled with
  `tf.compat.v1.disable_v2_behavior()`, this will instead return a special
  "_infer" policy which infers the dtype from the dtype of the first input the
  first time the layer is called. This behavior matches the behavior that
  existed in TensorFlow 1.

  See `tf.keras.mixed_precision.Policy` for more information on policies.

  Returns:
    The global Policy.
  """
  if _global_policy is None:
    if base_layer_utils.v2_dtype_behavior_enabled():
      return Policy(backend.floatx())
    else:
      return Policy('_infer')
  return _global_policy


def _check_if_mixed_precision_graph_rewrite_is_enabled(policy):
  if mixed_precision_global_state.is_mixed_precision_graph_rewrite_enabled():
    raise ValueError(
        'The global dtype policy cannot be set to "{policy.name}", because the '
        'mixed precision graph rewrite has already been enabled.\n'
        'At most, one of the following can be called:\n\n'
        '  1. tf.compat.v1.train.enable_mixed_precision_graph_rewrite() '
        '(You called this first)\n'
        '  2. tf.keras.mixed_precision.experimental.set_policy() with a mixed '
        'precision policy (You called this second)\n\n'
        'You called both functions, which is an error, because both functions '
        'enable you to use mixed precision. If in doubt which function to use, '
        'use the second, as it supports Eager execution and is more '
        'customizable.'.format(policy=policy))


@keras_export('keras.mixed_precision.set_global_policy',
              'keras.mixed_precision.experimental.set_policy', v1=[])
def set_policy(policy):
  """Sets the global dtype policy.

  The global policy is the default `tf.keras.mixed_precision.Policy` used for
  layers, if no policy is passed to the layer constructor.

  >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
  >>> tf.keras.mixed_precision.global_policy()
  <Policy "mixed_float16">
  >>> tf.keras.layers.Dense(10).dtype_policy
  <Policy "mixed_float16">
  >>> # The global policy is not used if a policy is directly passed to the
  >>> # constructor:
  >>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
  <Policy "float64">
  >>> tf.keras.mixed_precision.set_global_policy('float32')

  If no global policy is set, layers will instead default to a Policy
  constructed from `tf.keras.backend.floatx()`.

  To use mixed precision, the global policy should be set to `'mixed_float16'`
  or `'mixed_bfloat16'`, so that every layer uses a 16-bit compute dtype and
  float32 variable dtype by default.

  Only floating point policies can be set as the global policy, such as
  `'float32'` and `'mixed_float16'`. Non-floating point policies such as
  `'int32'` and `'complex64'` cannot be set as the global policy because most
  layers do not support such policies.
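
  For example, a sketch of the failure mode:

  >>> try:
  ...   tf.keras.mixed_precision.set_global_policy('int32')
  ... except ValueError:
  ...   print('int32 rejected: not a floating-point policy')
  int32 rejected: not a floating-point policy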

  See `tf.keras.mixed_precision.Policy` for more information.

  Args:
    policy: A Policy, or a string that will be converted to a Policy. Can also
      be None, in which case the global policy will be constructed from
      `tf.keras.backend.floatx()`.
  """
  global _global_policy
  if not base_layer_utils.v2_dtype_behavior_enabled():
    raise ValueError('The global policy can only be set in TensorFlow 2 or if '
                     'V2 dtype behavior has been set. To enable V2 dtype '
                     'behavior, call '
                     '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"')
  if policy is not None and not isinstance(policy, Policy):
    policy = Policy(policy)
  is_mixed_policy = (policy is not None and
                     policy.compute_dtype != policy.variable_dtype)
  if is_mixed_policy:
    _check_if_mixed_precision_graph_rewrite_is_enabled(policy)
  if (policy is not None and policy.compute_dtype is not None and
      not dtypes.as_dtype(policy.compute_dtype).is_floating):
    raise ValueError('set_policy can only be used to set the global policy to '
                     'floating-point policies, such as "float32" and '
                     '"mixed_float16", but got policy: %s'
                     % (policy.name,))
  _global_policy = policy
  mixed_precision_global_state.set_using_mixed_precision_policy(is_mixed_policy)


# TODO(reedwm): Make this thread local
@contextlib.contextmanager
def policy_scope(policy):
  """A context manager that sets the global Policy under it.

  Args:
    policy: A Policy, or a string that will be converted to a Policy.

  Yields:
    Nothing.
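
  For example, a minimal sketch (using this module's `policy_scope` directly):

  >>> with policy_scope('mixed_float16'):
  ...   layer = tf.keras.layers.Dense(10)  # constructed under the scoped policy
  >>> layer.dtype_policy
  <Policy "mixed_float16">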
531  """
532  old_policy = _global_policy
533  try:
534    set_policy(policy)
535    yield
536  finally:
537    set_policy(old_policy)
538
539
540def _is_convertible_to_dtype(dtype):
541  try:
542    dtypes.as_dtype(dtype)
543    return True
544  except TypeError:
545    return False
546
547
548def _policy_equivalent_to_dtype(policy):
549  """Returns True if the Policy is equivalent to a single dtype.
550
551  A policy is equivalent to a single dtype if the policy's compute and variable
552  dtypes are the same and the policy's type is Policy and not a subclass of
553  Policy (such as PolicyV1).
554
555  The "_infer" policy is considered equivalent to a single dtype.
556
557  Args:
558    policy: A Policy.
559
560  Returns:
561    True, if the policy is equivalent to a single dtype.
562  """
563  # We use type() instead of isinstance because a subclass of Policy is never
564  # equivalent to a dtype.
565  return (type(policy) == Policy and  # pylint: disable=unidiomatic-typecheck
566          list(policy.get_config().keys()) == ['name'] and
567          (policy.name == '_infer' or _is_convertible_to_dtype(policy.name)))
568
569
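# A sketch of serialize's behavior: serialize(Policy('float32')) returns the
# plain string 'float32', serialize(Policy('_infer')) returns None, and other
# policies round-trip through generic_utils.serialize_keras_object.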
def serialize(policy):
  if _policy_equivalent_to_dtype(policy):
    # We return either None or the policy name for compatibility with older
    # versions of Keras. If the policy name is returned, it is a dtype string
    # such as 'float32'.
    return None if policy.name == '_infer' else policy.name
  return generic_utils.serialize_keras_object(policy)


def deserialize(config, custom_objects=None):
  if isinstance(config, str) and _is_convertible_to_dtype(config):
    return Policy(config)
  if config is None:
    return Policy('_infer')
  module_objects = {'Policy': Policy, 'PolicyV1': Policy}
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=module_objects,
      custom_objects=custom_objects,
      printable_module_name='dtype policy')