1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python utilities required by Keras.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import binascii 21import codecs 22import importlib 23import marshal 24import os 25import re 26import sys 27import threading 28import time 29import types as python_types 30import weakref 31 32import numpy as np 33import six 34from tensorflow.python.keras.utils import tf_contextlib 35from tensorflow.python.keras.utils import tf_inspect 36from tensorflow.python.util import nest 37from tensorflow.python.util import tf_decorator 38from tensorflow.python.util.tf_export import keras_export 39 40_GLOBAL_CUSTOM_OBJECTS = {} 41_GLOBAL_CUSTOM_NAMES = {} 42 43# Flag that determines whether to skip the NotImplementedError when calling 44# get_config in custom models and layers. This is only enabled when saving to 45# SavedModel, when the config isn't required. 46_SKIP_FAILED_SERIALIZATION = False 47# If a layer does not have a defined config, then the returned config will be a 48# dictionary with the below key. 49_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config' 50 51 52@keras_export('keras.utils.custom_object_scope', # pylint: disable=g-classes-have-attributes 53 'keras.utils.CustomObjectScope') 54class CustomObjectScope(object): 55 """Exposes custom classes/functions to Keras deserialization internals. 56 57 Under a scope `with custom_object_scope(objects_dict)`, Keras methods such 58 as `tf.keras.models.load_model` or `tf.keras.models.model_from_config` 59 will be able to deserialize any custom object referenced by a 60 saved config (e.g. a custom layer or metric). 61 62 Example: 63 64 Consider a custom regularizer `my_regularizer`: 65 66 ```python 67 layer = Dense(3, kernel_regularizer=my_regularizer) 68 config = layer.get_config() # Config contains a reference to `my_regularizer` 69 ... 70 # Later: 71 with custom_object_scope({'my_regularizer': my_regularizer}): 72 layer = Dense.from_config(config) 73 ``` 74 75 Args: 76 *args: Dictionary or dictionaries of `{name: object}` pairs. 77 """ 78 79 def __init__(self, *args): 80 self.custom_objects = args 81 self.backup = None 82 83 def __enter__(self): 84 self.backup = _GLOBAL_CUSTOM_OBJECTS.copy() 85 for objects in self.custom_objects: 86 _GLOBAL_CUSTOM_OBJECTS.update(objects) 87 return self 88 89 def __exit__(self, *args, **kwargs): 90 _GLOBAL_CUSTOM_OBJECTS.clear() 91 _GLOBAL_CUSTOM_OBJECTS.update(self.backup) 92 93 94@keras_export('keras.utils.get_custom_objects') 95def get_custom_objects(): 96 """Retrieves a live reference to the global dictionary of custom objects. 97 98 Updating and clearing custom objects using `custom_object_scope` 99 is preferred, but `get_custom_objects` can 100 be used to directly access the current collection of custom objects. 101 102 Example: 103 104 ```python 105 get_custom_objects().clear() 106 get_custom_objects()['MyObject'] = MyObject 107 ``` 108 109 Returns: 110 Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`). 111 """ 112 return _GLOBAL_CUSTOM_OBJECTS 113 114 115# Store a unique, per-object ID for shared objects. 116# 117# We store a unique ID for each object so that we may, at loading time, 118# re-create the network properly. Without this ID, we would have no way of 119# determining whether a config is a description of a new object that 120# should be created or is merely a reference to an already-created object. 121SHARED_OBJECT_KEY = 'shared_object_id' 122 123 124SHARED_OBJECT_DISABLED = threading.local() 125SHARED_OBJECT_LOADING = threading.local() 126SHARED_OBJECT_SAVING = threading.local() 127 128 129# Attributes on the threadlocal variable must be set per-thread, thus we 130# cannot initialize these globally. Instead, we have accessor functions with 131# default values. 132def _shared_object_disabled(): 133 """Get whether shared object handling is disabled in a threadsafe manner.""" 134 return getattr(SHARED_OBJECT_DISABLED, 'disabled', False) 135 136 137def _shared_object_loading_scope(): 138 """Get the current shared object saving scope in a threadsafe manner.""" 139 return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope()) 140 141 142def _shared_object_saving_scope(): 143 """Get the current shared object saving scope in a threadsafe manner.""" 144 return getattr(SHARED_OBJECT_SAVING, 'scope', None) 145 146 147class DisableSharedObjectScope(object): 148 """A context manager for disabling handling of shared objects. 149 150 Disables shared object handling for both saving and loading. 151 152 Created primarily for use with `clone_model`, which does extra surgery that 153 is incompatible with shared objects. 154 """ 155 156 def __enter__(self): 157 SHARED_OBJECT_DISABLED.disabled = True 158 self._orig_loading_scope = _shared_object_loading_scope() 159 self._orig_saving_scope = _shared_object_saving_scope() 160 161 def __exit__(self, *args, **kwargs): 162 SHARED_OBJECT_DISABLED.disabled = False 163 SHARED_OBJECT_LOADING.scope = self._orig_loading_scope 164 SHARED_OBJECT_SAVING.scope = self._orig_saving_scope 165 166 167class NoopLoadingScope(object): 168 """The default shared object loading scope. It does nothing. 169 170 Created to simplify serialization code that doesn't care about shared objects 171 (e.g. when serializing a single object). 172 """ 173 174 def get(self, unused_object_id): 175 return None 176 177 def set(self, object_id, obj): 178 pass 179 180 181class SharedObjectLoadingScope(object): 182 """A context manager for keeping track of loaded objects. 183 184 During the deserialization process, we may come across objects that are 185 shared across multiple layers. In order to accurately restore the network 186 structure to its original state, `SharedObjectLoadingScope` allows us to 187 re-use shared objects rather than cloning them. 188 """ 189 190 def __enter__(self): 191 if _shared_object_disabled(): 192 return NoopLoadingScope() 193 194 global SHARED_OBJECT_LOADING 195 SHARED_OBJECT_LOADING.scope = self 196 self._obj_ids_to_obj = {} 197 return self 198 199 def get(self, object_id): 200 """Given a shared object ID, returns a previously instantiated object. 201 202 Args: 203 object_id: shared object ID to use when attempting to find already-loaded 204 object. 205 206 Returns: 207 The object, if we've seen this ID before. Else, `None`. 208 """ 209 # Explicitly check for `None` internally to make external calling code a 210 # bit cleaner. 211 if object_id is None: 212 return 213 return self._obj_ids_to_obj.get(object_id) 214 215 def set(self, object_id, obj): 216 """Stores an instantiated object for future lookup and sharing.""" 217 if object_id is None: 218 return 219 self._obj_ids_to_obj[object_id] = obj 220 221 def __exit__(self, *args, **kwargs): 222 global SHARED_OBJECT_LOADING 223 SHARED_OBJECT_LOADING.scope = NoopLoadingScope() 224 225 226class SharedObjectConfig(dict): 227 """A configuration container that keeps track of references. 228 229 `SharedObjectConfig` will automatically attach a shared object ID to any 230 configs which are referenced more than once, allowing for proper shared 231 object reconstruction at load time. 232 233 In most cases, it would be more proper to subclass something like 234 `collections.UserDict` or `collections.Mapping` rather than `dict` directly. 235 Unfortunately, python's json encoder does not support `Mapping`s. This is 236 important functionality to retain, since we are dealing with serialization. 237 238 We should be safe to subclass `dict` here, since we aren't actually 239 overriding any core methods, only augmenting with a new one for reference 240 counting. 241 """ 242 243 def __init__(self, base_config, object_id, **kwargs): 244 self.ref_count = 1 245 self.object_id = object_id 246 super(SharedObjectConfig, self).__init__(base_config, **kwargs) 247 248 def increment_ref_count(self): 249 # As soon as we've seen the object more than once, we want to attach the 250 # shared object ID. This allows us to only attach the shared object ID when 251 # it's strictly necessary, making backwards compatibility breakage less 252 # likely. 253 if self.ref_count == 1: 254 self[SHARED_OBJECT_KEY] = self.object_id 255 self.ref_count += 1 256 257 258class SharedObjectSavingScope(object): 259 """Keeps track of shared object configs when serializing.""" 260 261 def __enter__(self): 262 if _shared_object_disabled(): 263 return None 264 265 global SHARED_OBJECT_SAVING 266 267 # Serialization can happen at a number of layers for a number of reasons. 268 # We may end up with a case where we're opening a saving scope within 269 # another saving scope. In that case, we'd like to use the outermost scope 270 # available and ignore inner scopes, since there is not (yet) a reasonable 271 # use case for having these nested and distinct. 272 if _shared_object_saving_scope() is not None: 273 self._passthrough = True 274 return _shared_object_saving_scope() 275 else: 276 self._passthrough = False 277 278 SHARED_OBJECT_SAVING.scope = self 279 self._shared_objects_config = weakref.WeakKeyDictionary() 280 self._next_id = 0 281 return self 282 283 def get_config(self, obj): 284 """Gets a `SharedObjectConfig` if one has already been seen for `obj`. 285 286 Args: 287 obj: The object for which to retrieve the `SharedObjectConfig`. 288 289 Returns: 290 The SharedObjectConfig for a given object, if already seen. Else, 291 `None`. 292 """ 293 try: 294 shared_object_config = self._shared_objects_config[obj] 295 except (TypeError, KeyError): 296 # If the object is unhashable (e.g. a subclass of `AbstractBaseClass` 297 # that has not overridden `__hash__`), a `TypeError` will be thrown. 298 # We'll just continue on without shared object support. 299 return None 300 shared_object_config.increment_ref_count() 301 return shared_object_config 302 303 def create_config(self, base_config, obj): 304 """Create a new SharedObjectConfig for a given object.""" 305 shared_object_config = SharedObjectConfig(base_config, self._next_id) 306 self._next_id += 1 307 try: 308 self._shared_objects_config[obj] = shared_object_config 309 except TypeError: 310 # If the object is unhashable (e.g. a subclass of `AbstractBaseClass` 311 # that has not overridden `__hash__`), a `TypeError` will be thrown. 312 # We'll just continue on without shared object support. 313 pass 314 return shared_object_config 315 316 def __exit__(self, *args, **kwargs): 317 if not getattr(self, '_passthrough', False): 318 global SHARED_OBJECT_SAVING 319 SHARED_OBJECT_SAVING.scope = None 320 321 322def serialize_keras_class_and_config( 323 cls_name, cls_config, obj=None, shared_object_id=None): 324 """Returns the serialization of the class with the given config.""" 325 base_config = {'class_name': cls_name, 'config': cls_config} 326 327 # We call `serialize_keras_class_and_config` for some branches of the load 328 # path. In that case, we may already have a shared object ID we'd like to 329 # retain. 330 if shared_object_id is not None: 331 base_config[SHARED_OBJECT_KEY] = shared_object_id 332 333 # If we have an active `SharedObjectSavingScope`, check whether we've already 334 # serialized this config. If so, just use that config. This will store an 335 # extra ID field in the config, allowing us to re-create the shared object 336 # relationship at load time. 337 if _shared_object_saving_scope() is not None and obj is not None: 338 shared_object_config = _shared_object_saving_scope().get_config(obj) 339 if shared_object_config is None: 340 return _shared_object_saving_scope().create_config(base_config, obj) 341 return shared_object_config 342 343 return base_config 344 345 346@keras_export('keras.utils.register_keras_serializable') 347def register_keras_serializable(package='Custom', name=None): 348 """Registers an object with the Keras serialization framework. 349 350 This decorator injects the decorated class or function into the Keras custom 351 object dictionary, so that it can be serialized and deserialized without 352 needing an entry in the user-provided custom object dict. It also injects a 353 function that Keras will call to get the object's serializable string key. 354 355 Note that to be serialized and deserialized, classes must implement the 356 `get_config()` method. Functions do not have this requirement. 357 358 The object will be registered under the key 'package>name' where `name`, 359 defaults to the object name if not passed. 360 361 Args: 362 package: The package that this class belongs to. 363 name: The name to serialize this class under in this package. If None, the 364 class' name will be used. 365 366 Returns: 367 A decorator that registers the decorated class with the passed names. 368 """ 369 370 def decorator(arg): 371 """Registers a class with the Keras serialization framework.""" 372 class_name = name if name is not None else arg.__name__ 373 registered_name = package + '>' + class_name 374 375 if tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'): 376 raise ValueError( 377 'Cannot register a class that does not have a get_config() method.') 378 379 if registered_name in _GLOBAL_CUSTOM_OBJECTS: 380 raise ValueError( 381 '%s has already been registered to %s' % 382 (registered_name, _GLOBAL_CUSTOM_OBJECTS[registered_name])) 383 384 if arg in _GLOBAL_CUSTOM_NAMES: 385 raise ValueError('%s has already been registered to %s' % 386 (arg, _GLOBAL_CUSTOM_NAMES[arg])) 387 _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg 388 _GLOBAL_CUSTOM_NAMES[arg] = registered_name 389 390 return arg 391 392 return decorator 393 394 395@keras_export('keras.utils.get_registered_name') 396def get_registered_name(obj): 397 """Returns the name registered to an object within the Keras framework. 398 399 This function is part of the Keras serialization and deserialization 400 framework. It maps objects to the string names associated with those objects 401 for serialization/deserialization. 402 403 Args: 404 obj: The object to look up. 405 406 Returns: 407 The name associated with the object, or the default Python name if the 408 object is not registered. 409 """ 410 if obj in _GLOBAL_CUSTOM_NAMES: 411 return _GLOBAL_CUSTOM_NAMES[obj] 412 else: 413 return obj.__name__ 414 415 416@tf_contextlib.contextmanager 417def skip_failed_serialization(): 418 global _SKIP_FAILED_SERIALIZATION 419 prev = _SKIP_FAILED_SERIALIZATION 420 try: 421 _SKIP_FAILED_SERIALIZATION = True 422 yield 423 finally: 424 _SKIP_FAILED_SERIALIZATION = prev 425 426 427@keras_export('keras.utils.get_registered_object') 428def get_registered_object(name, custom_objects=None, module_objects=None): 429 """Returns the class associated with `name` if it is registered with Keras. 430 431 This function is part of the Keras serialization and deserialization 432 framework. It maps strings to the objects associated with them for 433 serialization/deserialization. 434 435 Example: 436 ``` 437 def from_config(cls, config, custom_objects=None): 438 if 'my_custom_object_name' in config: 439 config['hidden_cls'] = tf.keras.utils.get_registered_object( 440 config['my_custom_object_name'], custom_objects=custom_objects) 441 ``` 442 443 Args: 444 name: The name to look up. 445 custom_objects: A dictionary of custom objects to look the name up in. 446 Generally, custom_objects is provided by the user. 447 module_objects: A dictionary of custom objects to look the name up in. 448 Generally, module_objects is provided by midlevel library implementers. 449 450 Returns: 451 An instantiable class associated with 'name', or None if no such class 452 exists. 453 """ 454 if name in _GLOBAL_CUSTOM_OBJECTS: 455 return _GLOBAL_CUSTOM_OBJECTS[name] 456 elif custom_objects and name in custom_objects: 457 return custom_objects[name] 458 elif module_objects and name in module_objects: 459 return module_objects[name] 460 return None 461 462 463@keras_export('keras.utils.serialize_keras_object') 464def serialize_keras_object(instance): 465 """Serialize a Keras object into a JSON-compatible representation. 466 467 Calls to `serialize_keras_object` while underneath the 468 `SharedObjectSavingScope` context manager will cause any objects re-used 469 across multiple layers to be saved with a special shared object ID. This 470 allows the network to be re-created properly during deserialization. 471 472 Args: 473 instance: The object to serialize. 474 475 Returns: 476 A dict-like, JSON-compatible representation of the object's config. 477 """ 478 _, instance = tf_decorator.unwrap(instance) 479 if instance is None: 480 return None 481 482 if hasattr(instance, 'get_config'): 483 name = get_registered_name(instance.__class__) 484 try: 485 config = instance.get_config() 486 except NotImplementedError as e: 487 if _SKIP_FAILED_SERIALIZATION: 488 return serialize_keras_class_and_config( 489 name, {_LAYER_UNDEFINED_CONFIG_KEY: True}) 490 raise e 491 serialization_config = {} 492 for key, item in config.items(): 493 if isinstance(item, six.string_types): 494 serialization_config[key] = item 495 continue 496 497 # Any object of a different type needs to be converted to string or dict 498 # for serialization (e.g. custom functions, custom classes) 499 try: 500 serialized_item = serialize_keras_object(item) 501 if isinstance(serialized_item, dict) and not isinstance(item, dict): 502 serialized_item['__passive_serialization__'] = True 503 serialization_config[key] = serialized_item 504 except ValueError: 505 serialization_config[key] = item 506 507 name = get_registered_name(instance.__class__) 508 return serialize_keras_class_and_config( 509 name, serialization_config, instance) 510 if hasattr(instance, '__name__'): 511 return get_registered_name(instance) 512 raise ValueError('Cannot serialize', instance) 513 514 515def get_custom_objects_by_name(item, custom_objects=None): 516 """Returns the item if it is in either local or global custom objects.""" 517 if item in _GLOBAL_CUSTOM_OBJECTS: 518 return _GLOBAL_CUSTOM_OBJECTS[item] 519 elif custom_objects and item in custom_objects: 520 return custom_objects[item] 521 return None 522 523 524def class_and_config_for_serialized_keras_object( 525 config, 526 module_objects=None, 527 custom_objects=None, 528 printable_module_name='object'): 529 """Returns the class name and config for a serialized keras object.""" 530 if (not isinstance(config, dict) 531 or 'class_name' not in config 532 or 'config' not in config): 533 raise ValueError('Improper config format: ' + str(config)) 534 535 class_name = config['class_name'] 536 cls = get_registered_object(class_name, custom_objects, module_objects) 537 if cls is None: 538 raise ValueError( 539 'Unknown {}: {}. Please ensure this object is ' 540 'passed to the `custom_objects` argument. See ' 541 'https://www.tensorflow.org/guide/keras/save_and_serialize' 542 '#registering_the_custom_object for details.' 543 .format(printable_module_name, class_name)) 544 545 cls_config = config['config'] 546 # Check if `cls_config` is a list. If it is a list, return the class and the 547 # associated class configs for recursively deserialization. This case will 548 # happen on the old version of sequential model (e.g. `keras_version` == 549 # "2.0.6"), which is serialized in a different structure, for example 550 # "{'class_name': 'Sequential', 551 # 'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}". 552 if isinstance(cls_config, list): 553 return (cls, cls_config) 554 555 deserialized_objects = {} 556 for key, item in cls_config.items(): 557 if isinstance(item, dict) and '__passive_serialization__' in item: 558 deserialized_objects[key] = deserialize_keras_object( 559 item, 560 module_objects=module_objects, 561 custom_objects=custom_objects, 562 printable_module_name='config_item') 563 # TODO(momernick): Should this also have 'module_objects'? 564 elif (isinstance(item, six.string_types) and 565 tf_inspect.isfunction(get_registered_object(item, custom_objects))): 566 # Handle custom functions here. When saving functions, we only save the 567 # function's name as a string. If we find a matching string in the custom 568 # objects during deserialization, we convert the string back to the 569 # original function. 570 # Note that a potential issue is that a string field could have a naming 571 # conflict with a custom function name, but this should be a rare case. 572 # This issue does not occur if a string field has a naming conflict with 573 # a custom object, since the config of an object will always be a dict. 574 deserialized_objects[key] = get_registered_object(item, custom_objects) 575 for key, item in deserialized_objects.items(): 576 cls_config[key] = deserialized_objects[key] 577 578 return (cls, cls_config) 579 580 581@keras_export('keras.utils.deserialize_keras_object') 582def deserialize_keras_object(identifier, 583 module_objects=None, 584 custom_objects=None, 585 printable_module_name='object'): 586 """Turns the serialized form of a Keras object back into an actual object. 587 588 Calls to `deserialize_keras_object` while underneath the 589 `SharedObjectLoadingScope` context manager will cause any already-seen shared 590 objects to be returned as-is rather than creating a new object. 591 592 Args: 593 identifier: the serialized form of the object. 594 module_objects: A dictionary of custom objects to look the name up in. 595 Generally, module_objects is provided by midlevel library implementers. 596 custom_objects: A dictionary of custom objects to look the name up in. 597 Generally, custom_objects is provided by the user. 598 printable_module_name: A human-readable string representing the type of the 599 object. Printed in case of exception. 600 601 Returns: 602 The deserialized object. 603 """ 604 if identifier is None: 605 return None 606 607 if isinstance(identifier, dict): 608 # In this case we are dealing with a Keras config dictionary. 609 config = identifier 610 (cls, cls_config) = class_and_config_for_serialized_keras_object( 611 config, module_objects, custom_objects, printable_module_name) 612 613 # If this object has already been loaded (i.e. it's shared between multiple 614 # objects), return the already-loaded object. 615 shared_object_id = config.get(SHARED_OBJECT_KEY) 616 shared_object = _shared_object_loading_scope().get(shared_object_id) # pylint: disable=assignment-from-none 617 if shared_object is not None: 618 return shared_object 619 620 if hasattr(cls, 'from_config'): 621 arg_spec = tf_inspect.getfullargspec(cls.from_config) 622 custom_objects = custom_objects or {} 623 624 if 'custom_objects' in arg_spec.args: 625 deserialized_obj = cls.from_config( 626 cls_config, 627 custom_objects=dict( 628 list(_GLOBAL_CUSTOM_OBJECTS.items()) + 629 list(custom_objects.items()))) 630 else: 631 with CustomObjectScope(custom_objects): 632 deserialized_obj = cls.from_config(cls_config) 633 else: 634 # Then `cls` may be a function returning a class. 635 # in this case by convention `config` holds 636 # the kwargs of the function. 637 custom_objects = custom_objects or {} 638 with CustomObjectScope(custom_objects): 639 deserialized_obj = cls(**cls_config) 640 641 # Add object to shared objects, in case we find it referenced again. 642 _shared_object_loading_scope().set(shared_object_id, deserialized_obj) 643 644 return deserialized_obj 645 646 elif isinstance(identifier, six.string_types): 647 object_name = identifier 648 if custom_objects and object_name in custom_objects: 649 obj = custom_objects.get(object_name) 650 elif object_name in _GLOBAL_CUSTOM_OBJECTS: 651 obj = _GLOBAL_CUSTOM_OBJECTS[object_name] 652 else: 653 obj = module_objects.get(object_name) 654 if obj is None: 655 raise ValueError( 656 'Unknown {}: {}. Please ensure this object is ' 657 'passed to the `custom_objects` argument. See ' 658 'https://www.tensorflow.org/guide/keras/save_and_serialize' 659 '#registering_the_custom_object for details.' 660 .format(printable_module_name, object_name)) 661 662 # Classes passed by name are instantiated with no args, functions are 663 # returned as-is. 664 if tf_inspect.isclass(obj): 665 return obj() 666 return obj 667 elif tf_inspect.isfunction(identifier): 668 # If a function has already been deserialized, return as is. 669 return identifier 670 else: 671 raise ValueError('Could not interpret serialized %s: %s' % 672 (printable_module_name, identifier)) 673 674 675def func_dump(func): 676 """Serializes a user defined function. 677 678 Args: 679 func: the function to serialize. 680 681 Returns: 682 A tuple `(code, defaults, closure)`. 683 """ 684 if os.name == 'nt': 685 raw_code = marshal.dumps(func.__code__).replace(b'\\', b'/') 686 code = codecs.encode(raw_code, 'base64').decode('ascii') 687 else: 688 raw_code = marshal.dumps(func.__code__) 689 code = codecs.encode(raw_code, 'base64').decode('ascii') 690 defaults = func.__defaults__ 691 if func.__closure__: 692 closure = tuple(c.cell_contents for c in func.__closure__) 693 else: 694 closure = None 695 return code, defaults, closure 696 697 698def func_load(code, defaults=None, closure=None, globs=None): 699 """Deserializes a user defined function. 700 701 Args: 702 code: bytecode of the function. 703 defaults: defaults of the function. 704 closure: closure of the function. 705 globs: dictionary of global objects. 706 707 Returns: 708 A function object. 709 """ 710 if isinstance(code, (tuple, list)): # unpack previous dump 711 code, defaults, closure = code 712 if isinstance(defaults, list): 713 defaults = tuple(defaults) 714 715 def ensure_value_to_cell(value): 716 """Ensures that a value is converted to a python cell object. 717 718 Args: 719 value: Any value that needs to be casted to the cell type 720 721 Returns: 722 A value wrapped as a cell object (see function "func_load") 723 """ 724 725 def dummy_fn(): 726 # pylint: disable=pointless-statement 727 value # just access it so it gets captured in .__closure__ 728 729 cell_value = dummy_fn.__closure__[0] 730 if not isinstance(value, type(cell_value)): 731 return cell_value 732 return value 733 734 if closure is not None: 735 closure = tuple(ensure_value_to_cell(_) for _ in closure) 736 try: 737 raw_code = codecs.decode(code.encode('ascii'), 'base64') 738 except (UnicodeEncodeError, binascii.Error): 739 raw_code = code.encode('raw_unicode_escape') 740 code = marshal.loads(raw_code) 741 if globs is None: 742 globs = globals() 743 return python_types.FunctionType( 744 code, globs, name=code.co_name, argdefs=defaults, closure=closure) 745 746 747def has_arg(fn, name, accept_all=False): 748 """Checks if a callable accepts a given keyword argument. 749 750 Args: 751 fn: Callable to inspect. 752 name: Check if `fn` can be called with `name` as a keyword argument. 753 accept_all: What to return if there is no parameter called `name` but the 754 function accepts a `**kwargs` argument. 755 756 Returns: 757 bool, whether `fn` accepts a `name` keyword argument. 758 """ 759 arg_spec = tf_inspect.getfullargspec(fn) 760 if accept_all and arg_spec.varkw is not None: 761 return True 762 return name in arg_spec.args or name in arg_spec.kwonlyargs 763 764 765@keras_export('keras.utils.Progbar') 766class Progbar(object): 767 """Displays a progress bar. 768 769 Args: 770 target: Total number of steps expected, None if unknown. 771 width: Progress bar width on screen. 772 verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 773 stateful_metrics: Iterable of string names of metrics that should *not* be 774 averaged over time. Metrics in this list will be displayed as-is. All 775 others will be averaged by the progbar before display. 776 interval: Minimum visual progress update interval (in seconds). 777 unit_name: Display name for step counts (usually "step" or "sample"). 778 """ 779 780 def __init__(self, 781 target, 782 width=30, 783 verbose=1, 784 interval=0.05, 785 stateful_metrics=None, 786 unit_name='step'): 787 self.target = target 788 self.width = width 789 self.verbose = verbose 790 self.interval = interval 791 self.unit_name = unit_name 792 if stateful_metrics: 793 self.stateful_metrics = set(stateful_metrics) 794 else: 795 self.stateful_metrics = set() 796 797 self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 798 sys.stdout.isatty()) or 799 'ipykernel' in sys.modules or 800 'posix' in sys.modules or 801 'PYCHARM_HOSTED' in os.environ) 802 self._total_width = 0 803 self._seen_so_far = 0 804 # We use a dict + list to avoid garbage collection 805 # issues found in OrderedDict 806 self._values = {} 807 self._values_order = [] 808 self._start = time.time() 809 self._last_update = 0 810 811 self._time_after_first_step = None 812 813 def update(self, current, values=None, finalize=None): 814 """Updates the progress bar. 815 816 Args: 817 current: Index of current step. 818 values: List of tuples: `(name, value_for_last_step)`. If `name` is in 819 `stateful_metrics`, `value_for_last_step` will be displayed as-is. 820 Else, an average of the metric over time will be displayed. 821 finalize: Whether this is the last update for the progress bar. If 822 `None`, defaults to `current >= self.target`. 823 """ 824 if finalize is None: 825 if self.target is None: 826 finalize = False 827 else: 828 finalize = current >= self.target 829 830 values = values or [] 831 for k, v in values: 832 if k not in self._values_order: 833 self._values_order.append(k) 834 if k not in self.stateful_metrics: 835 # In the case that progress bar doesn't have a target value in the first 836 # epoch, both on_batch_end and on_epoch_end will be called, which will 837 # cause 'current' and 'self._seen_so_far' to have the same value. Force 838 # the minimal value to 1 here, otherwise stateful_metric will be 0s. 839 value_base = max(current - self._seen_so_far, 1) 840 if k not in self._values: 841 self._values[k] = [v * value_base, value_base] 842 else: 843 self._values[k][0] += v * value_base 844 self._values[k][1] += value_base 845 else: 846 # Stateful metrics output a numeric value. This representation 847 # means "take an average from a single value" but keeps the 848 # numeric formatting. 849 self._values[k] = [v, 1] 850 self._seen_so_far = current 851 852 now = time.time() 853 info = ' - %.0fs' % (now - self._start) 854 if self.verbose == 1: 855 if now - self._last_update < self.interval and not finalize: 856 return 857 858 prev_total_width = self._total_width 859 if self._dynamic_display: 860 sys.stdout.write('\b' * prev_total_width) 861 sys.stdout.write('\r') 862 else: 863 sys.stdout.write('\n') 864 865 if self.target is not None: 866 numdigits = int(np.log10(self.target)) + 1 867 bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target) 868 prog = float(current) / self.target 869 prog_width = int(self.width * prog) 870 if prog_width > 0: 871 bar += ('=' * (prog_width - 1)) 872 if current < self.target: 873 bar += '>' 874 else: 875 bar += '=' 876 bar += ('.' * (self.width - prog_width)) 877 bar += ']' 878 else: 879 bar = '%7d/Unknown' % current 880 881 self._total_width = len(bar) 882 sys.stdout.write(bar) 883 884 time_per_unit = self._estimate_step_duration(current, now) 885 886 if self.target is None or finalize: 887 if time_per_unit >= 1 or time_per_unit == 0: 888 info += ' %.0fs/%s' % (time_per_unit, self.unit_name) 889 elif time_per_unit >= 1e-3: 890 info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name) 891 else: 892 info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name) 893 else: 894 eta = time_per_unit * (self.target - current) 895 if eta > 3600: 896 eta_format = '%d:%02d:%02d' % (eta // 3600, 897 (eta % 3600) // 60, eta % 60) 898 elif eta > 60: 899 eta_format = '%d:%02d' % (eta // 60, eta % 60) 900 else: 901 eta_format = '%ds' % eta 902 903 info = ' - ETA: %s' % eta_format 904 905 for k in self._values_order: 906 info += ' - %s:' % k 907 if isinstance(self._values[k], list): 908 avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 909 if abs(avg) > 1e-3: 910 info += ' %.4f' % avg 911 else: 912 info += ' %.4e' % avg 913 else: 914 info += ' %s' % self._values[k] 915 916 self._total_width += len(info) 917 if prev_total_width > self._total_width: 918 info += (' ' * (prev_total_width - self._total_width)) 919 920 if finalize: 921 info += '\n' 922 923 sys.stdout.write(info) 924 sys.stdout.flush() 925 926 elif self.verbose == 2: 927 if finalize: 928 numdigits = int(np.log10(self.target)) + 1 929 count = ('%' + str(numdigits) + 'd/%d') % (current, self.target) 930 info = count + info 931 for k in self._values_order: 932 info += ' - %s:' % k 933 avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 934 if avg > 1e-3: 935 info += ' %.4f' % avg 936 else: 937 info += ' %.4e' % avg 938 info += '\n' 939 940 sys.stdout.write(info) 941 sys.stdout.flush() 942 943 self._last_update = now 944 945 def add(self, n, values=None): 946 self.update(self._seen_so_far + n, values) 947 948 def _estimate_step_duration(self, current, now): 949 """Estimate the duration of a single step. 950 951 Given the step number `current` and the corresponding time `now` 952 this function returns an estimate for how long a single step 953 takes. If this is called before one step has been completed 954 (i.e. `current == 0`) then zero is given as an estimate. The duration 955 estimate ignores the duration of the (assumed to be non-representative) 956 first step for estimates when more steps are available (i.e. `current>1`). 957 Args: 958 current: Index of current step. 959 now: The current time. 960 Returns: Estimate of the duration of a single step. 961 """ 962 if current: 963 # there are a few special scenarios here: 964 # 1) somebody is calling the progress bar without ever supplying step 1 965 # 2) somebody is calling the progress bar and supplies step one mulitple 966 # times, e.g. as part of a finalizing call 967 # in these cases, we just fall back to the simple calculation 968 if self._time_after_first_step is not None and current > 1: 969 time_per_unit = (now - self._time_after_first_step) / (current - 1) 970 else: 971 time_per_unit = (now - self._start) / current 972 973 if current == 1: 974 self._time_after_first_step = now 975 return time_per_unit 976 else: 977 return 0 978 979 980def make_batches(size, batch_size): 981 """Returns a list of batch indices (tuples of indices). 982 983 Args: 984 size: Integer, total size of the data to slice into batches. 985 batch_size: Integer, batch size. 986 987 Returns: 988 A list of tuples of array indices. 989 """ 990 num_batches = int(np.ceil(size / float(batch_size))) 991 return [(i * batch_size, min(size, (i + 1) * batch_size)) 992 for i in range(0, num_batches)] 993 994 995def slice_arrays(arrays, start=None, stop=None): 996 """Slice an array or list of arrays. 997 998 This takes an array-like, or a list of 999 array-likes, and outputs: 1000 - arrays[start:stop] if `arrays` is an array-like 1001 - [x[start:stop] for x in arrays] if `arrays` is a list 1002 1003 Can also work on list/array of indices: `slice_arrays(x, indices)` 1004 1005 Args: 1006 arrays: Single array or list of arrays. 1007 start: can be an integer index (start index) or a list/array of indices 1008 stop: integer (stop index); should be None if `start` was a list. 1009 1010 Returns: 1011 A slice of the array(s). 1012 1013 Raises: 1014 ValueError: If the value of start is a list and stop is not None. 1015 """ 1016 if arrays is None: 1017 return [None] 1018 if isinstance(start, list) and stop is not None: 1019 raise ValueError('The stop argument has to be None if the value of start ' 1020 'is a list.') 1021 elif isinstance(arrays, list): 1022 if hasattr(start, '__len__'): 1023 # hdf5 datasets only support list objects as indices 1024 if hasattr(start, 'shape'): 1025 start = start.tolist() 1026 return [None if x is None else x[start] for x in arrays] 1027 return [ 1028 None if x is None else 1029 None if not hasattr(x, '__getitem__') else x[start:stop] for x in arrays 1030 ] 1031 else: 1032 if hasattr(start, '__len__'): 1033 if hasattr(start, 'shape'): 1034 start = start.tolist() 1035 return arrays[start] 1036 if hasattr(start, '__getitem__'): 1037 return arrays[start:stop] 1038 return [None] 1039 1040 1041def to_list(x): 1042 """Normalizes a list/tensor into a list. 1043 1044 If a tensor is passed, we return 1045 a list of size 1 containing the tensor. 1046 1047 Args: 1048 x: target object to be normalized. 1049 1050 Returns: 1051 A list. 1052 """ 1053 if isinstance(x, list): 1054 return x 1055 return [x] 1056 1057 1058def to_snake_case(name): 1059 intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name) 1060 insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower() 1061 # If the class is private the name starts with "_" which is not secure 1062 # for creating scopes. We prefix the name with "private" in this case. 1063 if insecure[0] != '_': 1064 return insecure 1065 return 'private' + insecure 1066 1067 1068def is_all_none(structure): 1069 iterable = nest.flatten(structure) 1070 # We cannot use Python's `any` because the iterable may return Tensors. 1071 for element in iterable: 1072 if element is not None: 1073 return False 1074 return True 1075 1076 1077def check_for_unexpected_keys(name, input_dict, expected_values): 1078 unknown = set(input_dict.keys()).difference(expected_values) 1079 if unknown: 1080 raise ValueError('Unknown entries in {} dictionary: {}. Only expected ' 1081 'following keys: {}'.format(name, list(unknown), 1082 expected_values)) 1083 1084 1085def validate_kwargs(kwargs, 1086 allowed_kwargs, 1087 error_message='Keyword argument not understood:'): 1088 """Checks that all keyword arguments are in the set of allowed keys.""" 1089 for kwarg in kwargs: 1090 if kwarg not in allowed_kwargs: 1091 raise TypeError(error_message, kwarg) 1092 1093 1094def validate_config(config): 1095 """Determines whether config appears to be a valid layer config.""" 1096 return isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config 1097 1098 1099def default(method): 1100 """Decorates a method to detect overrides in subclasses.""" 1101 method._is_default = True # pylint: disable=protected-access 1102 return method 1103 1104 1105def is_default(method): 1106 """Check if a method is decorated with the `default` wrapper.""" 1107 return getattr(method, '_is_default', False) 1108 1109 1110def populate_dict_with_module_objects(target_dict, modules, obj_filter): 1111 for module in modules: 1112 for name in dir(module): 1113 obj = getattr(module, name) 1114 if obj_filter(obj): 1115 target_dict[name] = obj 1116 1117 1118class LazyLoader(python_types.ModuleType): 1119 """Lazily import a module, mainly to avoid pulling in large dependencies.""" 1120 1121 def __init__(self, local_name, parent_module_globals, name): 1122 self._local_name = local_name 1123 self._parent_module_globals = parent_module_globals 1124 super(LazyLoader, self).__init__(name) 1125 1126 def _load(self): 1127 """Load the module and insert it into the parent's globals.""" 1128 # Import the target module and insert it into the parent's namespace 1129 module = importlib.import_module(self.__name__) 1130 self._parent_module_globals[self._local_name] = module 1131 # Update this object's dict so that if someone keeps a reference to the 1132 # LazyLoader, lookups are efficient (__getattr__ is only called on lookups 1133 # that fail). 1134 self.__dict__.update(module.__dict__) 1135 return module 1136 1137 def __getattr__(self, item): 1138 module = self._load() 1139 return getattr(module, item) 1140 1141 1142# Aliases 1143 1144custom_object_scope = CustomObjectScope # pylint: disable=invalid-name 1145