1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python utilities required by Keras.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import binascii 21import codecs 22import marshal 23import os 24import re 25import sys 26import time 27import types as python_types 28 29import numpy as np 30import six 31 32from tensorflow.python.util import nest 33from tensorflow.python.util import tf_contextlib 34from tensorflow.python.util import tf_decorator 35from tensorflow.python.util import tf_inspect 36from tensorflow.python.util.tf_export import keras_export 37 38_GLOBAL_CUSTOM_OBJECTS = {} 39_GLOBAL_CUSTOM_NAMES = {} 40 41# Flag that determines whether to skip the NotImplementedError when calling 42# get_config in custom models and layers. This is only enabled when saving to 43# SavedModel, when the config isn't required. 44_SKIP_FAILED_SERIALIZATION = False 45# If a layer does not have a defined config, then the returned config will be a 46# dictionary with the below key. 47LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config' 48 49 50@keras_export('keras.utils.CustomObjectScope') 51class CustomObjectScope(object): 52 """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. 53 54 Code within a `with` statement will be able to access custom objects 55 by name. Changes to global custom objects persist 56 within the enclosing `with` statement. At end of the `with` statement, 57 global custom objects are reverted to state 58 at beginning of the `with` statement. 59 60 Example: 61 62 Consider a custom object `MyObject` (e.g. a class): 63 64 ```python 65 with CustomObjectScope({'MyObject':MyObject}): 66 layer = Dense(..., kernel_regularizer='MyObject') 67 # save, load, etc. will recognize custom object by name 68 ``` 69 """ 70 71 def __init__(self, *args): 72 self.custom_objects = args 73 self.backup = None 74 75 def __enter__(self): 76 self.backup = _GLOBAL_CUSTOM_OBJECTS.copy() 77 for objects in self.custom_objects: 78 _GLOBAL_CUSTOM_OBJECTS.update(objects) 79 return self 80 81 def __exit__(self, *args, **kwargs): 82 _GLOBAL_CUSTOM_OBJECTS.clear() 83 _GLOBAL_CUSTOM_OBJECTS.update(self.backup) 84 85 86@keras_export('keras.utils.custom_object_scope') 87def custom_object_scope(*args): 88 """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. 89 90 Convenience wrapper for `CustomObjectScope`. 91 Code within a `with` statement will be able to access custom objects 92 by name. Changes to global custom objects persist 93 within the enclosing `with` statement. At end of the `with` statement, 94 global custom objects are reverted to state 95 at beginning of the `with` statement. 96 97 Example: 98 99 Consider a custom object `MyObject` 100 101 ```python 102 with custom_object_scope({'MyObject':MyObject}): 103 layer = Dense(..., kernel_regularizer='MyObject') 104 # save, load, etc. will recognize custom object by name 105 ``` 106 107 Arguments: 108 *args: Variable length list of dictionaries of name, class pairs to add to 109 custom objects. 110 111 Returns: 112 Object of type `CustomObjectScope`. 113 """ 114 return CustomObjectScope(*args) 115 116 117@keras_export('keras.utils.get_custom_objects') 118def get_custom_objects(): 119 """Retrieves a live reference to the global dictionary of custom objects. 120 121 Updating and clearing custom objects using `custom_object_scope` 122 is preferred, but `get_custom_objects` can 123 be used to directly access `_GLOBAL_CUSTOM_OBJECTS`. 124 125 Example: 126 127 ```python 128 get_custom_objects().clear() 129 get_custom_objects()['MyObject'] = MyObject 130 ``` 131 132 Returns: 133 Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`). 134 """ 135 return _GLOBAL_CUSTOM_OBJECTS 136 137 138def serialize_keras_class_and_config(cls_name, cls_config): 139 """Returns the serialization of the class with the given config.""" 140 return {'class_name': cls_name, 'config': cls_config} 141 142 143@keras_export('keras.utils.register_keras_serializable') 144def register_keras_serializable(package='Custom', name=None): 145 """Registers an object with the Keras serialization framework. 146 147 This decorator injects the decorated class or function into the Keras custom 148 object dictionary, so that it can be serialized and deserialized without 149 needing an entry in the user-provided custom object dict. It also injects a 150 function that Keras will call to get the object's serializable string key. 151 152 Note that to be serialized and deserialized, classes must implement the 153 `get_config()` method. Functions do not have this requirement. 154 155 The object will be registered under the key 'package>name' where `name`, 156 defaults to the object name if not passed. 157 158 Arguments: 159 package: The package that this class belongs to. 160 name: The name to serialize this class under in this package. If None, the 161 class's name will be used. 162 163 Returns: 164 A decorator that registers the decorated class with the passed names. 165 """ 166 167 def decorator(arg): 168 """Registers a class with the Keras serialization framework.""" 169 class_name = name if name is not None else arg.__name__ 170 registered_name = package + '>' + class_name 171 172 if tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'): 173 raise ValueError( 174 'Cannot register a class that does not have a get_config() method.') 175 176 if registered_name in _GLOBAL_CUSTOM_OBJECTS: 177 raise ValueError( 178 '%s has already been registered to %s' % 179 (registered_name, _GLOBAL_CUSTOM_OBJECTS[registered_name])) 180 181 if arg in _GLOBAL_CUSTOM_NAMES: 182 raise ValueError('%s has already been registered to %s' % 183 (arg, _GLOBAL_CUSTOM_NAMES[arg])) 184 _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg 185 _GLOBAL_CUSTOM_NAMES[arg] = registered_name 186 187 return arg 188 189 return decorator 190 191 192@keras_export('keras.utils.get_registered_name') 193def get_registered_name(obj): 194 """Returns the name registered to an object within the Keras framework. 195 196 This function is part of the Keras serialization and deserialization 197 framework. It maps objects to the string names associated with those objects 198 for serialization/deserialization. 199 200 Args: 201 obj: The object to look up. 202 203 Returns: 204 The name associated with the object, or the default Python name if the 205 object is not registered. 206 """ 207 if obj in _GLOBAL_CUSTOM_NAMES: 208 return _GLOBAL_CUSTOM_NAMES[obj] 209 else: 210 return obj.__name__ 211 212 213@tf_contextlib.contextmanager 214def skip_failed_serialization(): 215 global _SKIP_FAILED_SERIALIZATION 216 prev = _SKIP_FAILED_SERIALIZATION 217 try: 218 _SKIP_FAILED_SERIALIZATION = True 219 yield 220 finally: 221 _SKIP_FAILED_SERIALIZATION = prev 222 223 224@keras_export('keras.utils.get_registered_object') 225def get_registered_object(name, custom_objects=None, module_objects=None): 226 """Returns the class associated with `name` if it is registered with Keras. 227 228 This function is part of the Keras serialization and deserialization 229 framework. It maps strings to the objects associated with them for 230 serialization/deserialization. 231 232 Example: 233 ``` 234 def from_config(cls, config, custom_objects=None): 235 if 'my_custom_object_name' in config: 236 config['hidden_cls'] = tf.keras.utils.get_registered_object( 237 config['my_custom_object_name'], custom_objects=custom_objects) 238 ``` 239 240 Args: 241 name: The name to look up. 242 custom_objects: A dictionary of custom objects to look the name up in. 243 Generally, custom_objects is provided by the user. 244 module_objects: A dictionary of custom objects to look the name up in. 245 Generally, module_objects is provided by midlevel library implementers. 246 247 Returns: 248 An instantiable class associated with 'name', or None if no such class 249 exists. 250 """ 251 if name in _GLOBAL_CUSTOM_OBJECTS: 252 return _GLOBAL_CUSTOM_OBJECTS[name] 253 elif custom_objects and name in custom_objects: 254 return custom_objects[name] 255 elif module_objects and name in module_objects: 256 return module_objects[name] 257 return None 258 259 260@keras_export('keras.utils.serialize_keras_object') 261def serialize_keras_object(instance): 262 """Serialize Keras object into JSON.""" 263 _, instance = tf_decorator.unwrap(instance) 264 if instance is None: 265 return None 266 267 if hasattr(instance, 'get_config'): 268 name = get_registered_name(instance.__class__) 269 try: 270 config = instance.get_config() 271 except NotImplementedError as e: 272 if _SKIP_FAILED_SERIALIZATION: 273 return serialize_keras_class_and_config( 274 name, {LAYER_UNDEFINED_CONFIG_KEY: True}) 275 raise e 276 serialization_config = {} 277 for key, item in config.items(): 278 if isinstance(item, six.string_types): 279 serialization_config[key] = item 280 continue 281 282 # Any object of a different type needs to be converted to string or dict 283 # for serialization (e.g. custom functions, custom classes) 284 try: 285 serialized_item = serialize_keras_object(item) 286 if isinstance(serialized_item, dict) and not isinstance(item, dict): 287 serialized_item['__passive_serialization__'] = True 288 serialization_config[key] = serialized_item 289 except ValueError: 290 serialization_config[key] = item 291 292 name = get_registered_name(instance.__class__) 293 return serialize_keras_class_and_config(name, serialization_config) 294 if hasattr(instance, '__name__'): 295 return get_registered_name(instance) 296 raise ValueError('Cannot serialize', instance) 297 298 299def get_custom_objects_by_name(item, custom_objects=None): 300 """Returns the item if it is in either local or global custom objects.""" 301 if item in _GLOBAL_CUSTOM_OBJECTS: 302 return _GLOBAL_CUSTOM_OBJECTS[item] 303 elif custom_objects and item in custom_objects: 304 return custom_objects[item] 305 return None 306 307 308def class_and_config_for_serialized_keras_object( 309 config, 310 module_objects=None, 311 custom_objects=None, 312 printable_module_name='object'): 313 """Returns the class name and config for a serialized keras object.""" 314 if (not isinstance(config, dict) or 'class_name' not in config or 315 'config' not in config): 316 raise ValueError('Improper config format: ' + str(config)) 317 318 class_name = config['class_name'] 319 cls = get_registered_object(class_name, custom_objects, module_objects) 320 if cls is None: 321 raise ValueError('Unknown ' + printable_module_name + ': ' + class_name) 322 323 cls_config = config['config'] 324 deserialized_objects = {} 325 for key, item in cls_config.items(): 326 if isinstance(item, dict) and '__passive_serialization__' in item: 327 deserialized_objects[key] = deserialize_keras_object( 328 item, 329 module_objects=module_objects, 330 custom_objects=custom_objects, 331 printable_module_name='config_item') 332 # TODO(momernick): Should this also have 'module_objects'? 333 elif (isinstance(item, six.string_types) and 334 tf_inspect.isfunction(get_registered_object(item, custom_objects))): 335 # Handle custom functions here. When saving functions, we only save the 336 # function's name as a string. If we find a matching string in the custom 337 # objects during deserialization, we convert the string back to the 338 # original function. 339 # Note that a potential issue is that a string field could have a naming 340 # conflict with a custom function name, but this should be a rare case. 341 # This issue does not occur if a string field has a naming conflict with 342 # a custom object, since the config of an object will always be a dict. 343 deserialized_objects[key] = get_registered_object(item, custom_objects) 344 for key, item in deserialized_objects.items(): 345 cls_config[key] = deserialized_objects[key] 346 347 return (cls, cls_config) 348 349 350@keras_export('keras.utils.deserialize_keras_object') 351def deserialize_keras_object(identifier, 352 module_objects=None, 353 custom_objects=None, 354 printable_module_name='object'): 355 if identifier is None: 356 return None 357 358 if isinstance(identifier, dict): 359 # In this case we are dealing with a Keras config dictionary. 360 config = identifier 361 (cls, cls_config) = class_and_config_for_serialized_keras_object( 362 config, module_objects, custom_objects, printable_module_name) 363 364 if hasattr(cls, 'from_config'): 365 arg_spec = tf_inspect.getfullargspec(cls.from_config) 366 custom_objects = custom_objects or {} 367 368 if 'custom_objects' in arg_spec.args: 369 return cls.from_config( 370 cls_config, 371 custom_objects=dict( 372 list(_GLOBAL_CUSTOM_OBJECTS.items()) + 373 list(custom_objects.items()))) 374 with CustomObjectScope(custom_objects): 375 return cls.from_config(cls_config) 376 else: 377 # Then `cls` may be a function returning a class. 378 # in this case by convention `config` holds 379 # the kwargs of the function. 380 custom_objects = custom_objects or {} 381 with CustomObjectScope(custom_objects): 382 return cls(**cls_config) 383 elif isinstance(identifier, six.string_types): 384 object_name = identifier 385 if custom_objects and object_name in custom_objects: 386 obj = custom_objects.get(object_name) 387 elif object_name in _GLOBAL_CUSTOM_OBJECTS: 388 obj = _GLOBAL_CUSTOM_OBJECTS[object_name] 389 else: 390 obj = module_objects.get(object_name) 391 if obj is None: 392 raise ValueError('Unknown ' + printable_module_name + ':' + object_name) 393 # Classes passed by name are instantiated with no args, functions are 394 # returned as-is. 395 if tf_inspect.isclass(obj): 396 return obj() 397 return obj 398 elif tf_inspect.isfunction(identifier): 399 # If a function has already been deserialized, return as is. 400 return identifier 401 else: 402 raise ValueError('Could not interpret serialized %s: %s' % 403 (printable_module_name, identifier)) 404 405 406def func_dump(func): 407 """Serializes a user defined function. 408 409 Arguments: 410 func: the function to serialize. 411 412 Returns: 413 A tuple `(code, defaults, closure)`. 414 """ 415 if os.name == 'nt': 416 raw_code = marshal.dumps(func.__code__).replace(b'\\', b'/') 417 code = codecs.encode(raw_code, 'base64').decode('ascii') 418 else: 419 raw_code = marshal.dumps(func.__code__) 420 code = codecs.encode(raw_code, 'base64').decode('ascii') 421 defaults = func.__defaults__ 422 if func.__closure__: 423 closure = tuple(c.cell_contents for c in func.__closure__) 424 else: 425 closure = None 426 return code, defaults, closure 427 428 429def func_load(code, defaults=None, closure=None, globs=None): 430 """Deserializes a user defined function. 431 432 Arguments: 433 code: bytecode of the function. 434 defaults: defaults of the function. 435 closure: closure of the function. 436 globs: dictionary of global objects. 437 438 Returns: 439 A function object. 440 """ 441 if isinstance(code, (tuple, list)): # unpack previous dump 442 code, defaults, closure = code 443 if isinstance(defaults, list): 444 defaults = tuple(defaults) 445 446 def ensure_value_to_cell(value): 447 """Ensures that a value is converted to a python cell object. 448 449 Arguments: 450 value: Any value that needs to be casted to the cell type 451 452 Returns: 453 A value wrapped as a cell object (see function "func_load") 454 """ 455 456 def dummy_fn(): 457 # pylint: disable=pointless-statement 458 value # just access it so it gets captured in .__closure__ 459 460 cell_value = dummy_fn.__closure__[0] 461 if not isinstance(value, type(cell_value)): 462 return cell_value 463 return value 464 465 if closure is not None: 466 closure = tuple(ensure_value_to_cell(_) for _ in closure) 467 try: 468 raw_code = codecs.decode(code.encode('ascii'), 'base64') 469 except (UnicodeEncodeError, binascii.Error): 470 raw_code = code.encode('raw_unicode_escape') 471 code = marshal.loads(raw_code) 472 if globs is None: 473 globs = globals() 474 return python_types.FunctionType( 475 code, globs, name=code.co_name, argdefs=defaults, closure=closure) 476 477 478def has_arg(fn, name, accept_all=False): 479 """Checks if a callable accepts a given keyword argument. 480 481 Arguments: 482 fn: Callable to inspect. 483 name: Check if `fn` can be called with `name` as a keyword argument. 484 accept_all: What to return if there is no parameter called `name` but the 485 function accepts a `**kwargs` argument. 486 487 Returns: 488 bool, whether `fn` accepts a `name` keyword argument. 489 """ 490 arg_spec = tf_inspect.getfullargspec(fn) 491 if accept_all and arg_spec.varkw is not None: 492 return True 493 return name in arg_spec.args 494 495 496@keras_export('keras.utils.Progbar') 497class Progbar(object): 498 """Displays a progress bar. 499 500 Arguments: 501 target: Total number of steps expected, None if unknown. 502 width: Progress bar width on screen. 503 verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 504 stateful_metrics: Iterable of string names of metrics that should *not* be 505 averaged over time. Metrics in this list will be displayed as-is. All 506 others will be averaged by the progbar before display. 507 interval: Minimum visual progress update interval (in seconds). 508 unit_name: Display name for step counts (usually "step" or "sample"). 509 """ 510 511 def __init__(self, 512 target, 513 width=30, 514 verbose=1, 515 interval=0.05, 516 stateful_metrics=None, 517 unit_name='step'): 518 self.target = target 519 self.width = width 520 self.verbose = verbose 521 self.interval = interval 522 self.unit_name = unit_name 523 if stateful_metrics: 524 self.stateful_metrics = set(stateful_metrics) 525 else: 526 self.stateful_metrics = set() 527 528 self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 529 sys.stdout.isatty()) or 530 'ipykernel' in sys.modules or 531 'posix' in sys.modules or 532 'PYCHARM_HOSTED' in os.environ) 533 self._total_width = 0 534 self._seen_so_far = 0 535 # We use a dict + list to avoid garbage collection 536 # issues found in OrderedDict 537 self._values = {} 538 self._values_order = [] 539 self._start = time.time() 540 self._last_update = 0 541 542 def update(self, current, values=None): 543 """Updates the progress bar. 544 545 Arguments: 546 current: Index of current step. 547 values: List of tuples: `(name, value_for_last_step)`. If `name` is in 548 `stateful_metrics`, `value_for_last_step` will be displayed as-is. 549 Else, an average of the metric over time will be displayed. 550 """ 551 values = values or [] 552 for k, v in values: 553 if k not in self._values_order: 554 self._values_order.append(k) 555 if k not in self.stateful_metrics: 556 # In the case that progress bar doesn't have a target value in the first 557 # epoch, both on_batch_end and on_epoch_end will be called, which will 558 # cause 'current' and 'self._seen_so_far' to have the same value. Force 559 # the minimal value to 1 here, otherwise stateful_metric will be 0s. 560 value_base = max(current - self._seen_so_far, 1) 561 if k not in self._values: 562 self._values[k] = [v * value_base, value_base] 563 else: 564 self._values[k][0] += v * value_base 565 self._values[k][1] += value_base 566 else: 567 # Stateful metrics output a numeric value. This representation 568 # means "take an average from a single value" but keeps the 569 # numeric formatting. 570 self._values[k] = [v, 1] 571 self._seen_so_far = current 572 573 now = time.time() 574 info = ' - %.0fs' % (now - self._start) 575 if self.verbose == 1: 576 if (now - self._last_update < self.interval and 577 self.target is not None and current < self.target): 578 return 579 580 prev_total_width = self._total_width 581 if self._dynamic_display: 582 sys.stdout.write('\b' * prev_total_width) 583 sys.stdout.write('\r') 584 else: 585 sys.stdout.write('\n') 586 587 if self.target is not None: 588 numdigits = int(np.log10(self.target)) + 1 589 bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target) 590 prog = float(current) / self.target 591 prog_width = int(self.width * prog) 592 if prog_width > 0: 593 bar += ('=' * (prog_width - 1)) 594 if current < self.target: 595 bar += '>' 596 else: 597 bar += '=' 598 bar += ('.' * (self.width - prog_width)) 599 bar += ']' 600 else: 601 bar = '%7d/Unknown' % current 602 603 self._total_width = len(bar) 604 sys.stdout.write(bar) 605 606 if current: 607 time_per_unit = (now - self._start) / current 608 else: 609 time_per_unit = 0 610 if self.target is not None and current < self.target: 611 eta = time_per_unit * (self.target - current) 612 if eta > 3600: 613 eta_format = '%d:%02d:%02d' % (eta // 3600, 614 (eta % 3600) // 60, eta % 60) 615 elif eta > 60: 616 eta_format = '%d:%02d' % (eta // 60, eta % 60) 617 else: 618 eta_format = '%ds' % eta 619 620 info = ' - ETA: %s' % eta_format 621 else: 622 if time_per_unit >= 1 or time_per_unit == 0: 623 info += ' %.0fs/%s' % (time_per_unit, self.unit_name) 624 elif time_per_unit >= 1e-3: 625 info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name) 626 else: 627 info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name) 628 629 for k in self._values_order: 630 info += ' - %s:' % k 631 if isinstance(self._values[k], list): 632 avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 633 if abs(avg) > 1e-3: 634 info += ' %.4f' % avg 635 else: 636 info += ' %.4e' % avg 637 else: 638 info += ' %s' % self._values[k] 639 640 self._total_width += len(info) 641 if prev_total_width > self._total_width: 642 info += (' ' * (prev_total_width - self._total_width)) 643 644 if self.target is not None and current >= self.target: 645 info += '\n' 646 647 sys.stdout.write(info) 648 sys.stdout.flush() 649 650 elif self.verbose == 2: 651 if self.target is not None and current >= self.target: 652 numdigits = int(np.log10(self.target)) + 1 653 count = ('%' + str(numdigits) + 'd/%d') % (current, self.target) 654 info = count + info 655 for k in self._values_order: 656 info += ' - %s:' % k 657 avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 658 if avg > 1e-3: 659 info += ' %.4f' % avg 660 else: 661 info += ' %.4e' % avg 662 info += '\n' 663 664 sys.stdout.write(info) 665 sys.stdout.flush() 666 667 self._last_update = now 668 669 def add(self, n, values=None): 670 self.update(self._seen_so_far + n, values) 671 672 673def make_batches(size, batch_size): 674 """Returns a list of batch indices (tuples of indices). 675 676 Arguments: 677 size: Integer, total size of the data to slice into batches. 678 batch_size: Integer, batch size. 679 680 Returns: 681 A list of tuples of array indices. 682 """ 683 num_batches = int(np.ceil(size / float(batch_size))) 684 return [(i * batch_size, min(size, (i + 1) * batch_size)) 685 for i in range(0, num_batches)] 686 687 688def slice_arrays(arrays, start=None, stop=None): 689 """Slice an array or list of arrays. 690 691 This takes an array-like, or a list of 692 array-likes, and outputs: 693 - arrays[start:stop] if `arrays` is an array-like 694 - [x[start:stop] for x in arrays] if `arrays` is a list 695 696 Can also work on list/array of indices: `slice_arrays(x, indices)` 697 698 Arguments: 699 arrays: Single array or list of arrays. 700 start: can be an integer index (start index) or a list/array of indices 701 stop: integer (stop index); should be None if `start` was a list. 702 703 Returns: 704 A slice of the array(s). 705 706 Raises: 707 ValueError: If the value of start is a list and stop is not None. 708 """ 709 if arrays is None: 710 return [None] 711 if isinstance(start, list) and stop is not None: 712 raise ValueError('The stop argument has to be None if the value of start ' 713 'is a list.') 714 elif isinstance(arrays, list): 715 if hasattr(start, '__len__'): 716 # hdf5 datasets only support list objects as indices 717 if hasattr(start, 'shape'): 718 start = start.tolist() 719 return [None if x is None else x[start] for x in arrays] 720 return [ 721 None if x is None else 722 None if not hasattr(x, '__getitem__') else x[start:stop] for x in arrays 723 ] 724 else: 725 if hasattr(start, '__len__'): 726 if hasattr(start, 'shape'): 727 start = start.tolist() 728 return arrays[start] 729 if hasattr(start, '__getitem__'): 730 return arrays[start:stop] 731 return [None] 732 733 734def to_list(x): 735 """Normalizes a list/tensor into a list. 736 737 If a tensor is passed, we return 738 a list of size 1 containing the tensor. 739 740 Arguments: 741 x: target object to be normalized. 742 743 Returns: 744 A list. 745 """ 746 if isinstance(x, list): 747 return x 748 return [x] 749 750 751def object_list_uid(object_list): 752 """Creates a single string from object ids.""" 753 object_list = nest.flatten(object_list) 754 return ', '.join(str(abs(id(x))) for x in object_list) 755 756 757def to_snake_case(name): 758 intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name) 759 insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower() 760 # If the class is private the name starts with "_" which is not secure 761 # for creating scopes. We prefix the name with "private" in this case. 762 if insecure[0] != '_': 763 return insecure 764 return 'private' + insecure 765 766 767def is_all_none(structure): 768 iterable = nest.flatten(structure) 769 # We cannot use Python's `any` because the iterable may return Tensors. 770 for element in iterable: 771 if element is not None: 772 return False 773 return True 774 775 776def check_for_unexpected_keys(name, input_dict, expected_values): 777 unknown = set(input_dict.keys()).difference(expected_values) 778 if unknown: 779 raise ValueError('Unknown entries in {} dictionary: {}. Only expected ' 780 'following keys: {}'.format(name, list(unknown), 781 expected_values)) 782 783 784def validate_kwargs(kwargs, 785 allowed_kwargs, 786 error_message='Keyword argument not understood:'): 787 """Checks that all keyword arguments are in the set of allowed keys.""" 788 for kwarg in kwargs: 789 if kwarg not in allowed_kwargs: 790 raise TypeError(error_message, kwarg) 791