1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains the Policy class for mixed precision training.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import contextlib 21 22import six 23 24from tensorflow.python.framework import dtypes 25from tensorflow.python.keras import backend 26from tensorflow.python.keras.engine import base_layer_utils 27from tensorflow.python.keras.mixed_precision import device_compatibility_check 28from tensorflow.python.keras.mixed_precision import loss_scale as keras_loss_scale_module 29from tensorflow.python.keras.utils import generic_utils 30from tensorflow.python.platform import tf_logging 31from tensorflow.python.training.experimental import mixed_precision_global_state 32from tensorflow.python.util.tf_export import keras_export 33 34 35# pylint: disable=g-classes-have-attributes 36@keras_export('keras.mixed_precision.Policy', v1=[]) 37class Policy(object): 38 """A dtype policy for a Keras layer. 39 40 A dtype policy determines a layer's computation and variable dtypes. Each 41 layer has a policy. Policies can be passed to the `dtype` argument of layer 42 constructors, or a global policy can be set with 43 `tf.keras.mixed_precision.set_global_policy`. 44 45 Args: 46 name: The policy name, which determines the compute and variable dtypes. Can 47 be any dtype name, such as `'float32'` or `'float64'`, which causes both 48 the compute and variable dtypes will be that dtype. Can also be the string 49 `'mixed_float16'` or `'mixed_bfloat16'`, which causes the compute dtype to 50 be float16 or bfloat16 and the variable dtype to be float32. 51 52 Typically you only need to interact with dtype policies when using mixed 53 precision, which is the use of float16 or bfloat16 for computations and 54 float32 for variables. This is why the term `mixed_precision` appears in the 55 API name. Mixed precision can be enabled by passing `'mixed_float16'` or 56 `'mixed_bfloat16'` to `tf.keras.mixed_precision.set_global_policy`. See [the 57 mixed precision guide](https://www.tensorflow.org/guide/keras/mixed_precision) 58 for more information on how to use mixed precision. 59 60 >>> tf.keras.mixed_precision.set_global_policy('mixed_float16') 61 >>> layer1 = tf.keras.layers.Dense(10) 62 >>> layer1.dtype_policy # `layer1` will automatically use mixed precision 63 <Policy "mixed_float16"> 64 >>> # Can optionally override layer to use float32 instead of mixed precision. 65 >>> layer2 = tf.keras.layers.Dense(10, dtype='float32') 66 >>> layer2.dtype_policy 67 <Policy "float32"> 68 >>> # Set policy back to initial float32 for future examples. 69 >>> tf.keras.mixed_precision.set_global_policy('float32') 70 71 In the example above, passing `dtype='float32'` to the layer is equivalent to 72 passing `dtype=tf.keras.mixed_precision.Policy('float32')`. In general, 73 passing a dtype to a layer is equivalent to passing the corresponding policy, 74 so it is never necessary to explicitly construct a `Policy` object. 75 76 Note: `Model.compile` will automatically wrap an optimizer with a 77 `tf.keras.mixed_precision.LossScaleOptimizer` if you use the `'mixed_float16'` 78 policy. If you use a custom training loop instead of calling `Model.compile`, 79 you should explicitly use a `tf.keras.mixed_precision.LossScaleOptimizer` to 80 avoid numeric underflow with float16. 81 82 ### How a layer uses its policy's compute dtype 83 84 A layer casts its inputs to its compute dtype. This causes the layer's 85 computations and output to also be in the compute dtype. For example: 86 87 >>> x = tf.ones((4, 4, 4, 4), dtype='float64') 88 >>> # `layer`'s policy defaults to float32. 89 >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2) 90 >>> layer.compute_dtype # Equivalent to layer.dtype_policy.compute_dtype 91 'float32' 92 >>> # `layer` casts its inputs to its compute dtype and does computations in 93 >>> # that dtype. 94 >>> y = layer(x) 95 >>> y.dtype 96 tf.float32 97 98 Note that the base `tf.keras.layers.Layer` class inserts the casts. If 99 subclassing your own layer, you do not have to insert any casts. 100 101 Currently, only tensors in the first argument to the layer's `call` method are 102 casted (although this will likely be changed in a future minor release). For 103 example: 104 105 >>> class MyLayer(tf.keras.layers.Layer): 106 ... # Bug! `b` will not be casted. 107 ... def call(self, a, b): 108 ... return a + 1., b + 1. 109 >>> a = tf.constant(1., dtype="float32") 110 >>> b = tf.constant(1., dtype="float32") 111 >>> layer = MyLayer(dtype="float64") 112 >>> x, y = layer(a, b) 113 >>> x.dtype 114 tf.float64 115 >>> y.dtype 116 tf.float32 117 118 If writing your own layer with multiple inputs, you should either explicitly 119 cast other tensors to `self.compute_dtype` in `call` or accept all tensors in 120 the first argument as a list. 121 122 The casting only occurs in TensorFlow 2. If 123 `tf.compat.v1.disable_v2_behavior()` has been called, you can enable the 124 casting behavior with `tf.compat.v1.keras.layers.enable_v2_dtype_behavior()`. 125 126 ### How a layer uses its policy's variable dtype 127 128 The default dtype of variables created by `tf.keras.layers.Layer.add_weight` 129 is the layer's policy's variable dtype. 130 131 If a layer's compute and variable dtypes differ, `add_weight` will wrap 132 floating-point variables with a special wrapper called an `AutoCastVariable`. 133 `AutoCastVariable` is identical to the original variable except it casts 134 itself to the layer's compute dtype when used within `Layer.call`. This means 135 if you are writing a layer, you do not have to explicitly cast the variables 136 to the layer's compute dtype. For example: 137 138 >>> class SimpleDense(tf.keras.layers.Layer): 139 ... 140 ... def build(self, input_shape): 141 ... # With mixed precision, self.kernel is a float32 AutoCastVariable 142 ... self.kernel = self.add_weight('kernel', (input_shape[-1], 10)) 143 ... 144 ... def call(self, inputs): 145 ... # With mixed precision, self.kernel will be casted to float16 146 ... return tf.linalg.matmul(inputs, self.kernel) 147 ... 148 >>> dtype_policy = tf.keras.mixed_precision.Policy('mixed_float16') 149 >>> layer = SimpleDense(dtype=dtype_policy) 150 >>> y = layer(tf.ones((10, 10))) 151 >>> y.dtype 152 tf.float16 153 >>> layer.kernel.dtype 154 tf.float32 155 156 A layer author can prevent a variable from being wrapped with an 157 `AutoCastVariable` by passing `experimental_autocast=False` to `add_weight`, 158 which is useful if the float32 value of the variable must be accessed within 159 the layer. 160 161 ### How to write a layer that supports mixed precision and float64. 162 163 For the most part, layers will automatically support mixed precision and 164 float64 without any additional work, due to the fact the base layer 165 automatically casts inputs, creates variables of the correct type, and in the 166 case of mixed precision, wraps variables with `AutoCastVariables`. 167 168 The primary case where you need extra work to support mixed precision or 169 float64 is when you create a new tensor, such as with `tf.ones` or 170 `tf.random.normal`, In such cases, you must create the tensor of the correct 171 dtype. For example, if you call `tf.random.normal`, you must pass the compute 172 dtype, which is the dtype the inputs have been casted to: 173 174 >>> class AddRandom(tf.keras.layers.Layer): 175 ... 176 ... def call(self, inputs): 177 ... # We must pass `dtype=inputs.dtype`, otherwise a TypeError may 178 ... # occur when adding `inputs` to `rand`. 179 ... rand = tf.random.normal(shape=inputs.shape, dtype=inputs.dtype) 180 ... return inputs + rand 181 182 >>> dtype_policy = tf.keras.mixed_precision.Policy('mixed_float16') 183 >>> layer = AddRandom(dtype=dtype_policy) 184 >>> y = layer(x) 185 >>> y.dtype 186 tf.float16 187 188 If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a 189 `TypeError` would have occurred. This is because the `tf.random.normal`'s 190 dtype defaults to `"float32"`, but the input dtype is float16. You cannot add 191 a float32 tensor with a float16 tensor. 192 """ 193 194 def __init__(self, name): 195 if isinstance(name, dtypes.DType): 196 raise TypeError("'name' must be a string, not a DType. " 197 "Instead, pass DType.name. Got: %s" % (name.name,)) 198 elif not isinstance(name, six.string_types): 199 raise TypeError("'name' must be a string, but got: %s" % (name,)) 200 self._name = name 201 self._compute_dtype, self._variable_dtype = self._parse_name(name) 202 if name in ('mixed_float16', 'mixed_bloat16'): 203 device_compatibility_check.log_device_compatibility_check(name) 204 205 def _parse_name(self, name): 206 """Parses a Policy name into a compute and variable dtype. 207 208 Args: 209 name: The name of the policy: 210 211 Returns: 212 The (compute_dtype, variable_dtype) pair. 213 """ 214 if name.endswith('_float32_vars'): 215 error_msg = ('Policies ending in \'_float32_vars\' have been removed ' 216 'from TensorFlow.') 217 if name in ('infer_float32_vars', 'infer_with_float32_vars'): 218 error_msg += (' Please use the \'mixed_float16\' or \'mixed_bfloat16\' ' 219 'policy instead.') 220 elif name == 'float16_with_float32_vars': 221 error_msg += (' Please use the \'mixed_float16\' policy instead.') 222 elif name == 'bfloat16_with_float32_vars': 223 error_msg += (' Please use the \'mixed_bfloat16\' policy instead.') 224 error_msg += ' Got policy name: \'%s\'' % name 225 raise ValueError(error_msg) 226 227 if name == 'mixed_float16': 228 return 'float16', 'float32' 229 elif name == 'mixed_bfloat16': 230 return 'bfloat16', 'float32' 231 elif name == '_infer': 232 # The "_infer" policy exists only for compatibility with TF 1, where 233 # "_infer" is the default. The behavior matches the behavior of TF 1's 234 # behavior before policies were introduced. With "_infer", the computation 235 # and variable dtype are inferred from the first input the first time the 236 # layer is called. Once the layer is called for the first time, the 237 # layer's policy will change to the dtype of the first input, and it will 238 # no longer have the "_infer" policy. 239 # 240 # The infer policy should be considered an implementation detail and may 241 # be removed in the future. 242 return None, None 243 244 try: 245 dtype = dtypes.as_dtype(name).name 246 except TypeError: 247 error = ("Cannot convert value %s to a mixed precision Policy. " 248 "Valid policies include 'mixed_float16', 'mixed_bfloat16', " 249 "and the name of any dtype such as 'float32'." % (name,)) 250 # six.raise_from suppresses the original TypeError from being raised 251 six.raise_from(ValueError(error), None) 252 return dtype, dtype 253 254 @property 255 def variable_dtype(self): 256 """The variable dtype of this policy. 257 258 This is the dtype layers will create their variables in, unless a layer 259 explicitly chooses a different dtype. If this is different than 260 `Policy.compute_dtype`, Layers will cast variables to the compute dtype to 261 avoid type errors. 262 263 Variable regularizers are run in the variable dtype, not the compute dtype. 264 265 Returns: 266 The variable dtype of this policy, as a string. 267 """ 268 return self._variable_dtype 269 270 @property 271 def compute_dtype(self): 272 """The compute dtype of this policy. 273 274 This is the dtype layers will do their computations in. Typically layers 275 output tensors with the compute dtype as well. 276 277 Note that even if the compute dtype is float16 or bfloat16, hardware devices 278 may not do individual adds, multiplies, and other fundamental operations in 279 float16 or bfloat16, but instead may do some of them in float32 for numeric 280 stability. The compute dtype is the dtype of the inputs and outputs of the 281 TensorFlow ops that the layer executes. Internally, many TensorFlow ops will 282 do certain internal calculations in float32 or some other device-internal 283 intermediate format with higher precision than float16/bfloat16, to increase 284 numeric stability. 285 286 For example, a `tf.keras.layers.Dense` layer, when run on a GPU with a 287 float16 compute dtype, will pass float16 inputs to `tf.linalg.matmul`. But, 288 `tf.linalg.matmul` will do use float32 intermediate math. The performance 289 benefit of float16 is still apparent, due to increased memory bandwidth and 290 the fact modern GPUs have specialized hardware for computing matmuls on 291 float16 inputs while still keeping intermediate computations in float32. 292 293 Returns: 294 The compute dtype of this policy, as a string. 295 """ 296 return self._compute_dtype 297 298 @property 299 def name(self): 300 """Returns the name of this policy.""" 301 return self._name 302 303 def __repr__(self): 304 return '<Policy "%s">' % self._name 305 306 def get_config(self): 307 return {'name': self.name} 308 309 @classmethod 310 def from_config(cls, config, custom_objects=None): 311 del custom_objects 312 if 'loss_scale' in config: 313 config = config.copy() 314 # Policy.get_config in TensorFlow 2.3 and below had a loss_scale. We 315 # silently drop it. 316 del config['loss_scale'] 317 return cls(**config) 318 319 320@keras_export('keras.mixed_precision.experimental.Policy', v1=[]) 321class PolicyV1(Policy): 322 """A deprecated dtype policy for a Keras layer. 323 324 Warning: This class is now deprecated and will be removed soon. Please use the 325 non-experimental class `tf.keras.mixed_precision.Policy` instead. 326 327 The difference between this class and the non-experimental class is that this 328 class has a `loss_scale` field and the non-experimental class does not. The 329 loss scale is only used by `tf.keras.Model.compile`, which automatically wraps 330 the optimizer with a `LossScaleOptimizer` if the optimizer is not already a 331 `LossScaleOptimizer`. For the non-experimental Policy class, `Model.compile` 332 instead wraps the optimizer with a `LossScaleOptimizer` if `Policy.name` is 333 "mixed_float16". 334 335 When deserializing objects with an experimental policy using functions like 336 `tf.keras.utils.deserialize_keras_object`, the policy will be deserialized as 337 the non-experimental `tf.keras.mixed_precision.Policy`, and the loss scale 338 will silently be dropped. This is so that SavedModels that are generated 339 with an experimental policy can be restored after the experimental policy is 340 removed. 341 """ 342 343 def __init__(self, name, loss_scale='auto'): 344 """Constructs the policy. 345 346 The `name` argument determines the compute and variable dtype, the default 347 loss scale, and has no additional effect on the Policy. The compute and 348 variable dtypes can only be specified through `name`, and cannot be 349 specified directly. 350 351 Args: 352 name: A string. Can be one of the following values: 353 * Any dtype name, such as 'float32' or 'float64'. Both the variable and 354 compute dtypes will be that dtype. 355 * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or 356 bfloat16, while the variable dtype is float32. With 'mixed_float16', 357 a dynamic loss scale is used. These policies are used for mixed 358 precision training. 359 loss_scale: A `tf.compat.v1.mixed_precision.LossScale`, an int (which 360 uses a `FixedLossScale`), the string "dynamic" (which uses a 361 `DynamicLossScale`), or None (which uses no loss scale). Defaults to 362 `"auto"`. In the `"auto"` case: 1) if `name` is `"mixed_float16"`, then 363 use `loss_scale="dynamic"`. 2) otherwise, do not use a loss scale. Only 364 `tf.keras.Model`s, not layers, use the loss scale, and it is only used 365 during `Model.fit`, `Model.train_on_batch`, and other similar methods. 366 """ 367 super(PolicyV1, self).__init__(name) 368 if loss_scale == 'auto': 369 loss_scale = 'dynamic' if name == 'mixed_float16' else None 370 self._using_default_loss_scale = True 371 else: 372 self._using_default_loss_scale = False 373 if loss_scale and self._compute_dtype not in (None, 'float16'): 374 tf_logging.warn('Creating a Policy with a loss scale is only useful for ' 375 'float16 policies. You passed loss_scale=%r for policy ' 376 '%s. Consider not passing any loss_scale instead.' % 377 (loss_scale, name)) 378 self._loss_scale = keras_loss_scale_module.get(loss_scale) 379 380 @property 381 def loss_scale(self): 382 """Returns the loss scale of this Policy. 383 384 Returns: 385 A `tf.compat.v1.mixed_precision.experimental.LossScale`, or None. 386 """ 387 return self._loss_scale 388 389 def __repr__(self): 390 return '<PolicyV1 "%s", loss_scale=%s>' % (self._name, self.loss_scale) 391 392 def get_config(self): 393 config = { 394 'name': self.name 395 } 396 if not self._using_default_loss_scale: 397 # We only include the loss scale if the default loss scale is not used. 398 # This allows us to change the loss scale config format without breaking 399 # users who use the default loss scale. 400 config['loss_scale'] = keras_loss_scale_module.serialize(self.loss_scale) 401 return config 402 403 @classmethod 404 def from_config(cls, config, custom_objects=None): 405 if 'loss_scale' in config and isinstance(config['loss_scale'], dict): 406 config = config.copy() 407 config['loss_scale'] = keras_loss_scale_module.deserialize( 408 config['loss_scale'], custom_objects=custom_objects) 409 return cls(**config) 410 411 412# The current global policy in effect. If None, it means the current value of 413# floatx should be used as the policy if the V2 dtype behavior is enabled, 414# or "_infer" otherwise. 415# TODO(reedwm): Make this thread local? 416_global_policy = None 417 418 419@keras_export('keras.mixed_precision.global_policy', 420 'keras.mixed_precision.experimental.global_policy', v1=[]) 421def global_policy(): 422 """Returns the global dtype policy. 423 424 The global policy is the default `tf.keras.mixed_precision.Policy` used for 425 layers, if no policy is passed to the layer constructor. If no policy has been 426 set with `keras.mixed_precision.set_global_policy`, this will return a policy 427 constructed from `tf.keras.backend.floatx()` (floatx defaults to float32). 428 429 >>> tf.keras.mixed_precision.global_policy() 430 <Policy "float32"> 431 >>> tf.keras.layers.Dense(10).dtype_policy # Defaults to the global policy 432 <Policy "float32"> 433 434 If TensorFlow 2 behavior has been disabled with 435 `tf.compat.v1.disable_v2_behavior()`, this will instead return a special 436 "_infer" policy which infers the dtype from the dtype of the first input the 437 first time the layer is called. This behavior matches the behavior that 438 existed in TensorFlow 1. 439 440 See `tf.keras.mixed_precision.Policy` for more information on policies. 441 442 Returns: 443 The global Policy. 444 """ 445 if _global_policy is None: 446 if base_layer_utils.v2_dtype_behavior_enabled(): 447 return Policy(backend.floatx()) 448 else: 449 return Policy('_infer') 450 return _global_policy 451 452 453def _check_if_mixed_precision_graph_rewrite_is_enabled(policy): 454 if mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled: 455 raise ValueError( 456 'The global dtype policy cannot be set to "{policy.name}", because the ' 457 'mixed precision graph rewrite has already been enabled.\n' 458 'At most, one of the following can be called:\n\n' 459 ' 1. tf.train.experimental.enable_mixed_precision_graph_rewrite() ' 460 '(You called this first)\n' 461 ' 2. tf.keras.mixed_precision.experimental.set_policy() with a mixed ' 462 'precision policy (You called this second)\n\n' 463 'You called both functions, which is an error, because both functions ' 464 'enable you to use mixed precision. If in doubt which function to use, ' 465 'use the second, as it supports Eager execution and is more ' 466 'customizable.'.format(policy=policy)) 467 468 469@keras_export('keras.mixed_precision.set_global_policy', 470 'keras.mixed_precision.experimental.set_policy', v1=[]) 471def set_policy(policy): 472 """Sets the global dtype policy. 473 474 The global policy is the default `tf.keras.mixed_precision.Policy` used for 475 layers, if no policy is passed to the layer constructor. 476 477 >>> tf.keras.mixed_precision.set_global_policy('mixed_float16') 478 >>> tf.keras.mixed_precision.global_policy() 479 <Policy "mixed_float16"> 480 >>> tf.keras.layers.Dense(10).dtype_policy 481 <Policy "mixed_float16"> 482 >>> # Global policy is not used if a policy is directly passed to constructor 483 >>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy 484 <Policy "float64"> 485 >>> tf.keras.mixed_precision.set_global_policy('float32') 486 487 If no global policy is set, layers will instead default to a Policy 488 constructed from `tf.keras.backend.floatx()`. 489 490 To use mixed precision, the global policy should be set to `'mixed_float16'` 491 or `'mixed_bfloat16'`, so that every layer uses a 16-bit compute dtype and 492 float32 variable dtype by default. 493 494 Only floating point policies can be set as the global policy, such as 495 `'float32'` and `'mixed_float16'`. Non-floating point policies such as 496 `'int32'` and `'complex64'` cannot be set as the global policy because most 497 layers do not support such policies. 498 499 See `tf.keras.mixed_precision.Policy` for more information. 500 501 Args: 502 policy: A Policy, or a string that will be converted to a Policy. Can also 503 be None, in which case the global policy will be constructed from 504 `tf.keras.backend.floatx()` 505 """ 506 global _global_policy 507 if not base_layer_utils.v2_dtype_behavior_enabled(): 508 raise ValueError('The global policy can only be set in TensorFlow 2 or if ' 509 'V2 dtype behavior has been set. To enable V2 dtype ' 510 'behavior, call ' 511 '"tf.compat.v1.keras.layers.enable_v2_dtype_behavior()"') 512 if policy is not None and not isinstance(policy, Policy): 513 policy = Policy(policy) 514 is_mixed_policy = (policy is not None and 515 policy.compute_dtype != policy.variable_dtype) 516 if is_mixed_policy: 517 _check_if_mixed_precision_graph_rewrite_is_enabled(policy) 518 if (policy is not None and policy.compute_dtype is not None and 519 not dtypes.as_dtype(policy.compute_dtype).is_floating): 520 raise ValueError('set_policy can only be used to set the global policy to ' 521 'floating-point policies, such as "float32" and ' 522 '"mixed_float16", but got policy: %s' 523 % (policy.name,)) 524 _global_policy = policy 525 mixed_precision_global_state.using_mixed_precision_policy = is_mixed_policy 526 527 528# TODO(reedwm): Make this thread local 529@contextlib.contextmanager 530def policy_scope(policy): 531 """A context manager that sets the global Policy under it. 532 533 Args: 534 policy: A Policy, or a string that will be converted to a Policy.. 535 536 Yields: 537 Nothing. 538 """ 539 old_policy = _global_policy 540 try: 541 set_policy(policy) 542 yield 543 finally: 544 set_policy(old_policy) 545 546 547def _is_convertible_to_dtype(dtype): 548 try: 549 dtypes.as_dtype(dtype) 550 return True 551 except TypeError: 552 return False 553 554 555def _policy_equivalent_to_dtype(policy): 556 """Returns True if the Policy is equivalent to a single dtype. 557 558 A policy is equivalent to a single dtype if the policy's compute and variable 559 dtypes are the same and the policy's type is Policy and not a subclass of 560 Policy (such as PolicyV1). 561 562 The "_infer" policy is considered equivalent to a single dtype. 563 564 Args: 565 policy: A Policy. 566 567 Returns: 568 True, if the policy is equivalent to a single dtype. 569 """ 570 # We use type() instead of isinstance because a subclass of Policy is never 571 # equivalent to a dtype. 572 return (type(policy) == Policy and # pylint: disable=unidiomatic-typecheck 573 list(policy.get_config().keys()) == ['name'] and 574 (policy.name == '_infer' or _is_convertible_to_dtype(policy.name))) 575 576 577def serialize(policy): 578 if _policy_equivalent_to_dtype(policy): 579 # We return either None or the policy name for compatibility with older 580 # versions of Keras. If the policy name is returned, it is a dtype string 581 # such as 'float32'. 582 return None if policy.name == '_infer' else policy.name 583 return generic_utils.serialize_keras_object(policy) 584 585 586def deserialize(config, custom_objects=None): 587 if isinstance(config, str) and _is_convertible_to_dtype(config): 588 return Policy(config) 589 if config is None: 590 return Policy('_infer') 591 module_objects = {'Policy': Policy, 'PolicyV1': Policy} 592 return generic_utils.deserialize_keras_object( 593 config, 594 module_objects=module_objects, 595 custom_objects=custom_objects, 596 printable_module_name='dtype policy') 597