# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the Policy class for mixed precision training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

import six

from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision import device_compatibility_check
from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util.tf_export import keras_export


# pylint: disable=g-classes-have-attributes
@keras_export('keras.mixed_precision.Policy', v1=[])
class Policy(object):
  """A dtype policy for a Keras layer.

  A dtype policy determines a layer's computation and variable dtypes. Each
  layer has a policy. Policies can be passed to the `dtype` argument of layer
  constructors, or a global policy can be set with
  `tf.keras.mixed_precision.set_global_policy`.

  Args:
    name: The policy name, which determines the compute and variable dtypes.
      Can be any dtype name, such as `'float32'` or `'float64'`, which causes
      both the compute and variable dtypes to be that dtype. Can also be the
      string `'mixed_float16'` or `'mixed_bfloat16'`, which causes the compute
      dtype to be float16 or bfloat16 and the variable dtype to be float32.

  Typically you only need to interact with dtype policies when using mixed
  precision, which is the use of float16 or bfloat16 for computations and
  float32 for variables. This is why the term `mixed_precision` appears in the
  API name. Mixed precision can be enabled by passing `'mixed_float16'` or
  `'mixed_bfloat16'` to `tf.keras.mixed_precision.set_global_policy`. See [the
  mixed precision guide](https://www.tensorflow.org/guide/keras/mixed_precision)
  for more information on how to use mixed precision.

  >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
  >>> layer1 = tf.keras.layers.Dense(10)
  >>> layer1.dtype_policy  # `layer1` will automatically use mixed precision
  <Policy "mixed_float16">
  >>> # Can optionally override layer to use float32 instead of mixed precision.
  >>> layer2 = tf.keras.layers.Dense(10, dtype='float32')
  >>> layer2.dtype_policy
  <Policy "float32">
  >>> # Set policy back to initial float32 for future examples.
  >>> tf.keras.mixed_precision.set_global_policy('float32')

  In the example above, passing `dtype='float32'` to the layer is equivalent to
  passing `dtype=tf.keras.mixed_precision.Policy('float32')`. In general,
  passing a dtype to a layer is equivalent to passing the corresponding policy,
  so it is never necessary to explicitly construct a `Policy` object.
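
  For instance, these two layers end up with identical policies:

  >>> policy = tf.keras.mixed_precision.Policy('float64')
  >>> tf.keras.layers.Dense(5, dtype=policy).dtype_policy
  <Policy "float64">
  >>> tf.keras.layers.Dense(5, dtype='float64').dtype_policy
  <Policy "float64">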

  Note: `Model.compile` will automatically wrap an optimizer with a
  `tf.keras.mixed_precision.LossScaleOptimizer` if you use the `'mixed_float16'`
  policy. If you use a custom training loop instead of calling `Model.compile`,
  you should explicitly use a `tf.keras.mixed_precision.LossScaleOptimizer` to
  avoid numeric underflow with float16.
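
  For example, in a custom training loop you might start by wrapping the
  optimizer (a minimal sketch; any optimizer can stand in for `SGD`):

  >>> opt = tf.keras.optimizers.SGD()
  >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)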

  ### How a layer uses its policy's compute dtype

  A layer casts its inputs to its compute dtype. This causes the layer's
  computations and output to also be in the compute dtype. For example:

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> # `layer`'s policy defaults to float32.
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> layer.compute_dtype  # Equivalent to layer.dtype_policy.compute_dtype
  'float32'
  >>> # `layer` casts its inputs to its compute dtype and does computations in
  >>> # that dtype.
  >>> y = layer(x)
  >>> y.dtype
  tf.float32

  Note that the base `tf.keras.layers.Layer` class inserts the casts. When
  subclassing your own layer, you do not have to insert any casts yourself.

  Currently, only tensors in the first argument to the layer's `call` method
  are cast (although this will likely change in a future minor release). For
  example:

  >>> class MyLayer(tf.keras.layers.Layer):
  ...   # Bug! `b` will not be cast.
  ...   def call(self, a, b):
  ...     return a + 1., b + 1.
  >>> a = tf.constant(1., dtype="float32")
  >>> b = tf.constant(1., dtype="float32")
  >>> layer = MyLayer(dtype="float64")
  >>> x, y = layer(a, b)
  >>> x.dtype
  tf.float64
  >>> y.dtype
  tf.float32

  If writing your own layer with multiple inputs, you should either explicitly
  cast other tensors to `self.compute_dtype` in `call` or accept all tensors in
  the first argument as a list.

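  For example, a sketch of a fixed version of `MyLayer` above (the name
  `MyFixedLayer` is illustrative) that explicitly casts `b`:

  >>> class MyFixedLayer(tf.keras.layers.Layer):
  ...   def call(self, a, b):
  ...     # Cast `b` manually, since only the first argument is cast
  ...     # automatically.
  ...     b = tf.cast(b, self.compute_dtype)
  ...     return a + 1., b + 1.
  >>> x, y = MyFixedLayer(dtype="float64")(a, b)
  >>> y.dtype
  tf.float64
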
  The casting only occurs in TensorFlow 2. If
  `tf.compat.v1.disable_v2_behavior()` has been called, you can enable the
  casting behavior with `tf.compat.v1.keras.layers.enable_v2_dtype_behavior()`.

  ### How a layer uses its policy's variable dtype

  The default dtype of variables created by `tf.keras.layers.Layer.add_weight`
  is the layer's policy's variable dtype.

  If a layer's compute and variable dtypes differ, `add_weight` will wrap
  floating-point variables with a special wrapper called an `AutoCastVariable`.
  `AutoCastVariable` is identical to the original variable except it casts
  itself to the layer's compute dtype when used within `Layer.call`. This means
  if you are writing a layer, you do not have to explicitly cast the variables
  to the layer's compute dtype. For example:

  >>> class SimpleDense(tf.keras.layers.Layer):
  ...
  ...   def build(self, input_shape):
  ...     # With mixed precision, self.kernel is a float32 AutoCastVariable
  ...     self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
  ...
  ...   def call(self, inputs):
  ...     # With mixed precision, self.kernel will be cast to float16
  ...     return tf.linalg.matmul(inputs, self.kernel)
  ...
  >>> dtype_policy = tf.keras.mixed_precision.Policy('mixed_float16')
  >>> layer = SimpleDense(dtype=dtype_policy)
  >>> y = layer(tf.ones((10, 10)))
  >>> y.dtype
  tf.float16
  >>> layer.kernel.dtype
  tf.float32

  A layer author can prevent a variable from being wrapped with an
  `AutoCastVariable` by passing `experimental_autocast=False` to `add_weight`,
  which is useful if the float32 value of the variable must be accessed within
  the layer.
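
  For example, an illustrative sketch (the layer name `MyScale` is made up for
  this example):

  >>> class MyScale(tf.keras.layers.Layer):
  ...
  ...   def build(self, input_shape):
  ...     # `self.scale` stays a plain float32 variable instead of being
  ...     # wrapped in an AutoCastVariable.
  ...     self.scale = self.add_weight('scale', (),
  ...                                  experimental_autocast=False)
  ...
  ...   def call(self, inputs):
  ...     # The float32 value is accessible here, but we must cast manually.
  ...     return inputs * tf.cast(self.scale, inputs.dtype)
  >>> layer = MyScale(dtype=tf.keras.mixed_precision.Policy('mixed_float16'))
  >>> layer(tf.ones((4,))).dtype
  tf.float16
  >>> layer.scale.dtype
  tf.float32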

  ### How to write a layer that supports mixed precision and float64

  For the most part, layers will automatically support mixed precision and
  float64 without any additional work, because the base layer automatically
  casts inputs, creates variables of the correct type, and, in the case of
  mixed precision, wraps variables with `AutoCastVariables`.

  The primary case where you need extra work to support mixed precision or
  float64 is when you create a new tensor, such as with `tf.ones` or
  `tf.random.normal`. In such cases, you must create the tensor of the correct
  dtype. For example, if you call `tf.random.normal`, you must pass the compute
  dtype, which is the dtype the inputs have been cast to:

  >>> class AddRandom(tf.keras.layers.Layer):
  ...
  ...   def call(self, inputs):
  ...     # We must pass `dtype=inputs.dtype`, otherwise a TypeError may
  ...     # occur when adding `inputs` to `rand`.
  ...     rand = tf.random.normal(shape=inputs.shape, dtype=inputs.dtype)
  ...     return inputs + rand

  >>> dtype_policy = tf.keras.mixed_precision.Policy('mixed_float16')
  >>> layer = AddRandom(dtype=dtype_policy)
  >>> y = layer(x)
  >>> y.dtype
  tf.float16

  If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a `TypeError`
  would have occurred, because `tf.random.normal`'s dtype defaults to
  `"float32"` while the input dtype is float16. You cannot add a float32 tensor
  to a float16 tensor.
  """

  def __init__(self, name):
    if isinstance(name, dtypes.DType):
      raise TypeError("'name' must be a string, not a DType. "
                      "Instead, pass DType.name. Got: %s" % (name.name,))
    elif not isinstance(name, six.string_types):
      raise TypeError("'name' must be a string, but got: %s" % (name,))
    self._name = name
    self._compute_dtype, self._variable_dtype = self._parse_name(name)
    if name in ('mixed_float16', 'mixed_bfloat16'):
      device_compatibility_check.log_device_compatibility_check(name)

  def _parse_name(self, name):
    """Parses a Policy name into a compute and variable dtype.

    Args:
      name: The name of the policy.

    Returns:
      The (compute_dtype, variable_dtype) pair.
    """
    if name.endswith('_float32_vars'):
      error_msg = ('Policies ending in \'_float32_vars\' have been removed '
                   'from TensorFlow.')
      if name in ('infer_float32_vars', 'infer_with_float32_vars'):
        error_msg += (' Please use the \'mixed_float16\' or \'mixed_bfloat16\' '
                      'policy instead.')
      elif name == 'float16_with_float32_vars':
        error_msg += (' Please use the \'mixed_float16\' policy instead.')
      elif name == 'bfloat16_with_float32_vars':
        error_msg += (' Please use the \'mixed_bfloat16\' policy instead.')
      error_msg += ' Got policy name: \'%s\'' % name
      raise ValueError(error_msg)

    if name == 'mixed_float16':
      return 'float16', 'float32'
    elif name == 'mixed_bfloat16':
      return 'bfloat16', 'float32'
    elif name == '_infer':
      # The "_infer" policy exists only for compatibility with TF 1, where
      # "_infer" is the default. It matches TF 1's behavior before policies
      # were introduced. With "_infer", the computation and variable dtypes
      # are inferred from the first input the first time the layer is called.
      # Once the layer is called for the first time, the layer's policy will
      # change to the dtype of that first input, and it will no longer have
      # the "_infer" policy.
      #
      # The infer policy should be considered an implementation detail and may
      # be removed in the future.
      return None, None

    try:
      dtype = dtypes.as_dtype(name).name
    except TypeError:
      error = ("Cannot convert value %s to a mixed precision Policy. "
               "Valid policies include 'mixed_float16', 'mixed_bfloat16', "
               "and the name of any dtype such as 'float32'." % (name,))
      # six.raise_from suppresses the original TypeError from being raised
      six.raise_from(ValueError(error), None)
    return dtype, dtype

  @property
  def variable_dtype(self):
    """The variable dtype of this policy.

    This is the dtype layers will create their variables in, unless a layer
    explicitly chooses a different dtype. If this is different from
    `Policy.compute_dtype`, layers will cast variables to the compute dtype to
    avoid type errors.

    Variable regularizers are run in the variable dtype, not the compute dtype.

    Returns:
      The variable dtype of this policy, as a string.
    """
    return self._variable_dtype

  @property
  def compute_dtype(self):
    """The compute dtype of this policy.

    This is the dtype layers will do their computations in. Typically layers
    output tensors with the compute dtype as well.

    Note that even if the compute dtype is float16 or bfloat16, hardware devices
    may not do individual adds, multiplies, and other fundamental operations in
    float16 or bfloat16, but instead may do some of them in float32 for numeric
    stability. The compute dtype is the dtype of the inputs and outputs of the
    TensorFlow ops that the layer executes. Internally, many TensorFlow ops will
    do certain internal calculations in float32 or some other device-internal
    intermediate format with higher precision than float16/bfloat16, to increase
    numeric stability.

    For example, a `tf.keras.layers.Dense` layer, when run on a GPU with a
    float16 compute dtype, will pass float16 inputs to `tf.linalg.matmul`, but
    `tf.linalg.matmul` will use float32 intermediate math. The performance
    benefit of float16 is still apparent, due to increased memory bandwidth and
    the fact that modern GPUs have specialized hardware for computing matmuls on
    float16 inputs while still keeping intermediate computations in float32.

    Returns:
      The compute dtype of this policy, as a string.
    """
    return self._compute_dtype
  @property
  def name(self):
    """Returns the name of this policy."""
    return self._name

  def __repr__(self):
    return '<Policy "%s">' % self._name

  def get_config(self):
    return {'name': self.name}

  @classmethod
  def from_config(cls, config, custom_objects=None):
    del custom_objects
    if 'loss_scale' in config:
      config = config.copy()
      # Policy.get_config in TensorFlow 2.3 and below had a loss_scale. We
      # silently drop it.
      del config['loss_scale']
    return cls(**config)


@keras_export('keras.mixed_precision.experimental.Policy', v1=[])
class PolicyV1(Policy):
  """A deprecated dtype policy for a Keras layer.

  Warning: This class is now deprecated and will be removed soon. Please use the
  non-experimental class `tf.keras.mixed_precision.Policy` instead.

  The difference between this class and the non-experimental class is that this
  class has a `loss_scale` field and the non-experimental class does not. The
  loss scale is only used by `tf.keras.Model.compile`, which automatically wraps
  the optimizer with a `LossScaleOptimizer` if the optimizer is not already a
  `LossScaleOptimizer`. For the non-experimental Policy class, `Model.compile`
  instead wraps the optimizer with a `LossScaleOptimizer` if `Policy.name` is
  "mixed_float16".

  When deserializing objects with an experimental policy using functions like
  `tf.keras.utils.deserialize_keras_object`, the policy will be deserialized as
  the non-experimental `tf.keras.mixed_precision.Policy`, and the loss scale
  will silently be dropped. This is so that SavedModels that are generated
  with an experimental policy can be restored after the experimental policy is
  removed.
  """

  def __init__(self, name, loss_scale='auto'):
    """Constructs the policy.

    The `name` argument determines the compute and variable dtypes and the
    default loss scale; it has no additional effect on the Policy. The compute
    and variable dtypes can only be specified through `name`, and cannot be
    specified directly.

    Args:
      name: A string. Can be one of the following values:
        * Any dtype name, such as 'float32' or 'float64'. Both the variable and
          compute dtypes will be that dtype.
        * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
          bfloat16, while the variable dtype is float32. With 'mixed_float16',
          a dynamic loss scale is used. These policies are used for mixed
          precision training.
      loss_scale: A `tf.compat.v1.mixed_precision.LossScale`, an int (which
        uses a `FixedLossScale`), the string "dynamic" (which uses a
        `DynamicLossScale`), or None (which uses no loss scale). Defaults to
        `"auto"`. In the `"auto"` case: 1) if `name` is `"mixed_float16"`, then
        use `loss_scale="dynamic"`. 2) otherwise, do not use a loss scale. Only
        `tf.keras.Model`s, not layers, use the loss scale, and it is only used
        during `Model.fit`, `Model.train_on_batch`, and other similar methods.
    """
    super(PolicyV1, self).__init__(name)
    if loss_scale == 'auto':
      loss_scale = 'dynamic' if name == 'mixed_float16' else None
      self._using_default_loss_scale = True
    else:
      self._using_default_loss_scale = False
    if loss_scale and self._compute_dtype not in (None, 'float16'):
      tf_logging.warn('Creating a Policy with a loss scale is only useful for '
                      'float16 policies. You passed loss_scale=%r for policy '
                      '%s. Consider not passing any loss_scale instead.' %
                      (loss_scale, name))
    self._loss_scale = keras_loss_scale_module.get(loss_scale)

  @property
  def loss_scale(self):
    """Returns the loss scale of this Policy.

    Returns:
      A `tf.compat.v1.mixed_precision.experimental.LossScale`, or None.
    """
    return self._loss_scale

  def __repr__(self):
    return '<PolicyV1 "%s", loss_scale=%s>' % (self._name, self.loss_scale)

  def get_config(self):
    config = {
        'name': self.name
    }
    if not self._using_default_loss_scale:
      # We only include the loss scale if the default loss scale is not used.
      # This allows us to change the loss scale config format without breaking
      # users who use the default loss scale.
      config['loss_scale'] = keras_loss_scale_module.serialize(self.loss_scale)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    if 'loss_scale' in config and isinstance(config['loss_scale'], dict):
      config = config.copy()
      config['loss_scale'] = keras_loss_scale_module.deserialize(
          config['loss_scale'], custom_objects=custom_objects)
    return cls(**config)


# The current global policy in effect. If None, it means the current value of
# floatx should be used as the policy if the V2 dtype behavior is enabled,
# or "_infer" otherwise.
# TODO(reedwm): Make this thread local?
_global_policy = None


@keras_export('keras.mixed_precision.global_policy',
              'keras.mixed_precision.experimental.global_policy', v1=[])
def global_policy():
  """Returns the global dtype policy.

  The global policy is the default `tf.keras.mixed_precision.Policy` used for
  layers, if no policy is passed to the layer constructor. If no policy has been
  set with `keras.mixed_precision.set_global_policy`, this will return a policy
  constructed from `tf.keras.backend.floatx()` (floatx defaults to float32).

  >>> tf.keras.mixed_precision.global_policy()
  <Policy "float32">
  >>> tf.keras.layers.Dense(10).dtype_policy  # Defaults to the global policy
  <Policy "float32">

  If TensorFlow 2 behavior has been disabled with
  `tf.compat.v1.disable_v2_behavior()`, this will instead return a special
  "_infer" policy which infers the dtype from the dtype of the first input the
  first time the layer is called. This behavior matches the behavior that
  existed in TensorFlow 1.

  See `tf.keras.mixed_precision.Policy` for more information on policies.

  Returns:
    The global Policy.
  """
  if _global_policy is None:
    if base_layer_utils.v2_dtype_behavior_enabled():
      return Policy(backend.floatx())
    else:
      return Policy('_infer')
  return _global_policy


def _check_if_mixed_precision_graph_rewrite_is_enabled(policy):
  if mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
    raise ValueError(
        'The global dtype policy cannot be set to "{policy.name}", because the '
        'mixed precision graph rewrite has already been enabled.\n'
        'At most, one of the following can be called:\n\n'
        '  1. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
        '(You called this first)\n'
        '  2. tf.keras.mixed_precision.experimental.set_policy() with a mixed '
        'precision policy (You called this second)\n\n'
        'You called both functions, which is an error, because both functions '
        'enable you to use mixed precision. If in doubt which function to use, '
        'use the second, as it supports Eager execution and is more '
        'customizable.'.format(policy=policy))


@keras_export('keras.mixed_precision.set_global_policy',
              'keras.mixed_precision.experimental.set_policy', v1=[])
def set_policy(policy):
  """Sets the global dtype policy.

  The global policy is the default `tf.keras.mixed_precision.Policy` used for
  layers, if no policy is passed to the layer constructor.

  >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
  >>> tf.keras.mixed_precision.global_policy()
  <Policy "mixed_float16">
  >>> tf.keras.layers.Dense(10).dtype_policy
  <Policy "mixed_float16">
  >>> # Global policy is not used if a policy is directly passed to constructor
  >>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
  <Policy "float64">
  >>> tf.keras.mixed_precision.set_global_policy('float32')

  If no global policy is set, layers will instead default to a Policy
  constructed from `tf.keras.backend.floatx()`.

  To use mixed precision, the global policy should be set to `'mixed_float16'`
  or `'mixed_bfloat16'`, so that every layer uses a 16-bit compute dtype and
  float32 variable dtype by default.

  Only floating point policies can be set as the global policy, such as
  `'float32'` and `'mixed_float16'`. Non-floating point policies such as
  `'int32'` and `'complex64'` cannot be set as the global policy because most
  layers do not support such policies.

  See `tf.keras.mixed_precision.Policy` for more information.

  Args:
    policy: A Policy, or a string that will be converted to a Policy. Can also
      be None, in which case the global policy will be constructed from
      `tf.keras.backend.floatx()`.
  """
  global _global_policy
  if not base_layer_utils.v2_dtype_behavior_enabled():
    raise ValueError('The global policy can only be set in TensorFlow 2 or if '
                     'V2 dtype behavior has been set. To enable V2 dtype '
                     'behavior, call '
                     '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"')
  if policy is not None and not isinstance(policy, Policy):
    policy = Policy(policy)
  is_mixed_policy = (policy is not None and
                     policy.compute_dtype != policy.variable_dtype)
  if is_mixed_policy:
    _check_if_mixed_precision_graph_rewrite_is_enabled(policy)
  if (policy is not None and policy.compute_dtype is not None and
      not dtypes.as_dtype(policy.compute_dtype).is_floating):
    raise ValueError('set_policy can only be used to set the global policy to '
                     'floating-point policies, such as "float32" and '
                     '"mixed_float16", but got policy: %s'
                     % (policy.name,))
  _global_policy = policy
  mixed_precision_global_state.using_mixed_precision_policy = is_mixed_policy


# TODO(reedwm): Make this thread local
@contextlib.contextmanager
def policy_scope(policy):
  """A context manager that sets the global Policy under it.

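  For example (a small sketch; the previous global policy is restored when the
  scope exits):

  >>> with policy_scope('mixed_float16'):
  ...   layer = tf.keras.layers.Dense(10)
  >>> layer.dtype_policy
  <Policy "mixed_float16">
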
  Args:
    policy: A Policy, or a string that will be converted to a Policy.

  Yields:
    Nothing.
  """
  old_policy = _global_policy
  try:
    set_policy(policy)
    yield
  finally:
    set_policy(old_policy)


def _is_convertible_to_dtype(dtype):
  try:
    dtypes.as_dtype(dtype)
    return True
  except TypeError:
    return False


def _policy_equivalent_to_dtype(policy):
  """Returns True if the Policy is equivalent to a single dtype.

  A policy is equivalent to a single dtype if the policy's compute and variable
  dtypes are the same and the policy's type is Policy and not a subclass of
  Policy (such as PolicyV1).

  The "_infer" policy is considered equivalent to a single dtype.

  Args:
    policy: A Policy.

  Returns:
    True, if the policy is equivalent to a single dtype.
  """
  # We use type() instead of isinstance because a subclass of Policy is never
  # equivalent to a dtype.
  return (type(policy) == Policy and  # pylint: disable=unidiomatic-typecheck
          list(policy.get_config().keys()) == ['name'] and
          (policy.name == '_infer' or _is_convertible_to_dtype(policy.name)))


def serialize(policy):
  """Serializes a Policy into a config, or a dtype string when possible."""
  if _policy_equivalent_to_dtype(policy):
    # We return either None or the policy name for compatibility with older
    # versions of Keras. If the policy name is returned, it is a dtype string
    # such as 'float32'.
    return None if policy.name == '_infer' else policy.name
  return generic_utils.serialize_keras_object(policy)


def deserialize(config, custom_objects=None):
  """Deserializes a serialized policy (or a dtype string) into a Policy."""
  if isinstance(config, str) and _is_convertible_to_dtype(config):
    return Policy(config)
  if config is None:
    return Policy('_infer')
  module_objects = {'Policy': Policy, 'PolicyV1': Policy}
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=module_objects,
      custom_objects=custom_objects,
      printable_module_name='dtype policy')