1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python utilities required by Keras.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import binascii 21import codecs 22import marshal 23import os 24import re 25import sys 26import time 27import types as python_types 28 29import numpy as np 30import six 31 32from tensorflow.python.util import nest 33from tensorflow.python.util import tf_decorator 34from tensorflow.python.util import tf_inspect 35from tensorflow.python.util.tf_export import keras_export 36 37_GLOBAL_CUSTOM_OBJECTS = {} 38 39 40@keras_export('keras.utils.CustomObjectScope') 41class CustomObjectScope(object): 42 """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. 43 44 Code within a `with` statement will be able to access custom objects 45 by name. Changes to global custom objects persist 46 within the enclosing `with` statement. At end of the `with` statement, 47 global custom objects are reverted to state 48 at beginning of the `with` statement. 49 50 Example: 51 52 Consider a custom object `MyObject` (e.g. a class): 53 54 ```python 55 with CustomObjectScope({'MyObject':MyObject}): 56 layer = Dense(..., kernel_regularizer='MyObject') 57 # save, load, etc. will recognize custom object by name 58 ``` 59 """ 60 61 def __init__(self, *args): 62 self.custom_objects = args 63 self.backup = None 64 65 def __enter__(self): 66 self.backup = _GLOBAL_CUSTOM_OBJECTS.copy() 67 for objects in self.custom_objects: 68 _GLOBAL_CUSTOM_OBJECTS.update(objects) 69 return self 70 71 def __exit__(self, *args, **kwargs): 72 _GLOBAL_CUSTOM_OBJECTS.clear() 73 _GLOBAL_CUSTOM_OBJECTS.update(self.backup) 74 75 76@keras_export('keras.utils.custom_object_scope') 77def custom_object_scope(*args): 78 """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape. 79 80 Convenience wrapper for `CustomObjectScope`. 81 Code within a `with` statement will be able to access custom objects 82 by name. Changes to global custom objects persist 83 within the enclosing `with` statement. At end of the `with` statement, 84 global custom objects are reverted to state 85 at beginning of the `with` statement. 86 87 Example: 88 89 Consider a custom object `MyObject` 90 91 ```python 92 with custom_object_scope({'MyObject':MyObject}): 93 layer = Dense(..., kernel_regularizer='MyObject') 94 # save, load, etc. will recognize custom object by name 95 ``` 96 97 Arguments: 98 *args: Variable length list of dictionaries of name, 99 class pairs to add to custom objects. 100 101 Returns: 102 Object of type `CustomObjectScope`. 103 """ 104 return CustomObjectScope(*args) 105 106 107@keras_export('keras.utils.get_custom_objects') 108def get_custom_objects(): 109 """Retrieves a live reference to the global dictionary of custom objects. 110 111 Updating and clearing custom objects using `custom_object_scope` 112 is preferred, but `get_custom_objects` can 113 be used to directly access `_GLOBAL_CUSTOM_OBJECTS`. 114 115 Example: 116 117 ```python 118 get_custom_objects().clear() 119 get_custom_objects()['MyObject'] = MyObject 120 ``` 121 122 Returns: 123 Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`). 124 """ 125 return _GLOBAL_CUSTOM_OBJECTS 126 127 128def serialize_keras_class_and_config(cls_name, cls_config): 129 """Returns the serialization of the class with the given config.""" 130 return {'class_name': cls_name, 'config': cls_config} 131 132 133@keras_export('keras.utils.serialize_keras_object') 134def serialize_keras_object(instance): 135 _, instance = tf_decorator.unwrap(instance) 136 if instance is None: 137 return None 138 if hasattr(instance, 'get_config'): 139 return serialize_keras_class_and_config(instance.__class__.__name__, 140 instance.get_config()) 141 if hasattr(instance, '__name__'): 142 return instance.__name__ 143 else: 144 raise ValueError('Cannot serialize', instance) 145 146 147def class_and_config_for_serialized_keras_object( 148 config, 149 module_objects=None, 150 custom_objects=None, 151 printable_module_name='object'): 152 """Returns the class name and config for a serialized keras object.""" 153 if (not isinstance(config, dict) or 'class_name' not in config or 154 'config' not in config): 155 raise ValueError('Improper config format: ' + str(config)) 156 157 class_name = config['class_name'] 158 if custom_objects and class_name in custom_objects: 159 cls = custom_objects[class_name] 160 elif class_name in _GLOBAL_CUSTOM_OBJECTS: 161 cls = _GLOBAL_CUSTOM_OBJECTS[class_name] 162 else: 163 module_objects = module_objects or {} 164 cls = module_objects.get(class_name) 165 if cls is None: 166 raise ValueError('Unknown ' + printable_module_name + ': ' + class_name) 167 return (cls, config['config']) 168 169 170@keras_export('keras.utils.deserialize_keras_object') 171def deserialize_keras_object(identifier, 172 module_objects=None, 173 custom_objects=None, 174 printable_module_name='object'): 175 if identifier is None: 176 return None 177 if isinstance(identifier, dict): 178 # In this case we are dealing with a Keras config dictionary. 179 config = identifier 180 (cls, cls_config) = class_and_config_for_serialized_keras_object( 181 config, module_objects, custom_objects, printable_module_name) 182 183 if hasattr(cls, 'from_config'): 184 arg_spec = tf_inspect.getfullargspec(cls.from_config) 185 custom_objects = custom_objects or {} 186 187 if 'custom_objects' in arg_spec.args: 188 return cls.from_config( 189 cls_config, 190 custom_objects=dict( 191 list(_GLOBAL_CUSTOM_OBJECTS.items()) + 192 list(custom_objects.items()))) 193 with CustomObjectScope(custom_objects): 194 return cls.from_config(cls_config) 195 else: 196 # Then `cls` may be a function returning a class. 197 # in this case by convention `config` holds 198 # the kwargs of the function. 199 custom_objects = custom_objects or {} 200 with CustomObjectScope(custom_objects): 201 return cls(**cls_config) 202 elif isinstance(identifier, six.string_types): 203 function_name = identifier 204 if custom_objects and function_name in custom_objects: 205 fn = custom_objects.get(function_name) 206 elif function_name in _GLOBAL_CUSTOM_OBJECTS: 207 fn = _GLOBAL_CUSTOM_OBJECTS[function_name] 208 else: 209 fn = module_objects.get(function_name) 210 if fn is None: 211 raise ValueError('Unknown ' + printable_module_name + ':' + 212 function_name) 213 return fn 214 else: 215 raise ValueError('Could not interpret serialized ' + printable_module_name + 216 ': ' + identifier) 217 218 219def func_dump(func): 220 """Serializes a user defined function. 221 222 Arguments: 223 func: the function to serialize. 224 225 Returns: 226 A tuple `(code, defaults, closure)`. 227 """ 228 if os.name == 'nt': 229 raw_code = marshal.dumps(func.__code__).replace(b'\\', b'/') 230 code = codecs.encode(raw_code, 'base64').decode('ascii') 231 else: 232 raw_code = marshal.dumps(func.__code__) 233 code = codecs.encode(raw_code, 'base64').decode('ascii') 234 defaults = func.__defaults__ 235 if func.__closure__: 236 closure = tuple(c.cell_contents for c in func.__closure__) 237 else: 238 closure = None 239 return code, defaults, closure 240 241 242def func_load(code, defaults=None, closure=None, globs=None): 243 """Deserializes a user defined function. 244 245 Arguments: 246 code: bytecode of the function. 247 defaults: defaults of the function. 248 closure: closure of the function. 249 globs: dictionary of global objects. 250 251 Returns: 252 A function object. 253 """ 254 if isinstance(code, (tuple, list)): # unpack previous dump 255 code, defaults, closure = code 256 if isinstance(defaults, list): 257 defaults = tuple(defaults) 258 259 def ensure_value_to_cell(value): 260 """Ensures that a value is converted to a python cell object. 261 262 Arguments: 263 value: Any value that needs to be casted to the cell type 264 265 Returns: 266 A value wrapped as a cell object (see function "func_load") 267 """ 268 def dummy_fn(): 269 # pylint: disable=pointless-statement 270 value # just access it so it gets captured in .__closure__ 271 272 cell_value = dummy_fn.__closure__[0] 273 if not isinstance(value, type(cell_value)): 274 return cell_value 275 else: 276 return value 277 278 if closure is not None: 279 closure = tuple(ensure_value_to_cell(_) for _ in closure) 280 try: 281 raw_code = codecs.decode(code.encode('ascii'), 'base64') 282 except (UnicodeEncodeError, binascii.Error): 283 raw_code = code.encode('raw_unicode_escape') 284 code = marshal.loads(raw_code) 285 if globs is None: 286 globs = globals() 287 return python_types.FunctionType( 288 code, globs, name=code.co_name, argdefs=defaults, closure=closure) 289 290 291def has_arg(fn, name, accept_all=False): 292 """Checks if a callable accepts a given keyword argument. 293 294 Arguments: 295 fn: Callable to inspect. 296 name: Check if `fn` can be called with `name` as a keyword argument. 297 accept_all: What to return if there is no parameter called `name` 298 but the function accepts a `**kwargs` argument. 299 300 Returns: 301 bool, whether `fn` accepts a `name` keyword argument. 302 """ 303 arg_spec = tf_inspect.getfullargspec(fn) 304 if accept_all and arg_spec.varkw is not None: 305 return True 306 return name in arg_spec.args 307 308 309@keras_export('keras.utils.Progbar') 310class Progbar(object): 311 """Displays a progress bar. 312 313 Arguments: 314 target: Total number of steps expected, None if unknown. 315 width: Progress bar width on screen. 316 verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 317 stateful_metrics: Iterable of string names of metrics that 318 should *not* be averaged over time. Metrics in this list 319 will be displayed as-is. All others will be averaged 320 by the progbar before display. 321 interval: Minimum visual progress update interval (in seconds). 322 unit_name: Display name for step counts (usually "step" or "sample"). 323 """ 324 325 def __init__(self, target, width=30, verbose=1, interval=0.05, 326 stateful_metrics=None, unit_name='step'): 327 self.target = target 328 self.width = width 329 self.verbose = verbose 330 self.interval = interval 331 self.unit_name = unit_name 332 if stateful_metrics: 333 self.stateful_metrics = set(stateful_metrics) 334 else: 335 self.stateful_metrics = set() 336 337 self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 338 sys.stdout.isatty()) or 339 'ipykernel' in sys.modules or 340 'posix' in sys.modules) 341 self._total_width = 0 342 self._seen_so_far = 0 343 # We use a dict + list to avoid garbage collection 344 # issues found in OrderedDict 345 self._values = {} 346 self._values_order = [] 347 self._start = time.time() 348 self._last_update = 0 349 350 def update(self, current, values=None): 351 """Updates the progress bar. 352 353 Arguments: 354 current: Index of current step. 355 values: List of tuples: 356 `(name, value_for_last_step)`. 357 If `name` is in `stateful_metrics`, 358 `value_for_last_step` will be displayed as-is. 359 Else, an average of the metric over time will be displayed. 360 """ 361 values = values or [] 362 for k, v in values: 363 if k not in self._values_order: 364 self._values_order.append(k) 365 if k not in self.stateful_metrics: 366 if k not in self._values: 367 self._values[k] = [v * (current - self._seen_so_far), 368 current - self._seen_so_far] 369 else: 370 self._values[k][0] += v * (current - self._seen_so_far) 371 self._values[k][1] += (current - self._seen_so_far) 372 else: 373 # Stateful metrics output a numeric value. This representation 374 # means "take an average from a single value" but keeps the 375 # numeric formatting. 376 self._values[k] = [v, 1] 377 self._seen_so_far = current 378 379 now = time.time() 380 info = ' - %.0fs' % (now - self._start) 381 if self.verbose == 1: 382 if (now - self._last_update < self.interval and 383 self.target is not None and current < self.target): 384 return 385 386 prev_total_width = self._total_width 387 if self._dynamic_display: 388 sys.stdout.write('\b' * prev_total_width) 389 sys.stdout.write('\r') 390 else: 391 sys.stdout.write('\n') 392 393 if self.target is not None: 394 numdigits = int(np.log10(self.target)) + 1 395 bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target) 396 prog = float(current) / self.target 397 prog_width = int(self.width * prog) 398 if prog_width > 0: 399 bar += ('=' * (prog_width - 1)) 400 if current < self.target: 401 bar += '>' 402 else: 403 bar += '=' 404 bar += ('.' * (self.width - prog_width)) 405 bar += ']' 406 else: 407 bar = '%7d/Unknown' % current 408 409 self._total_width = len(bar) 410 sys.stdout.write(bar) 411 412 if current: 413 time_per_unit = (now - self._start) / current 414 else: 415 time_per_unit = 0 416 if self.target is not None and current < self.target: 417 eta = time_per_unit * (self.target - current) 418 if eta > 3600: 419 eta_format = '%d:%02d:%02d' % (eta // 3600, 420 (eta % 3600) // 60, 421 eta % 60) 422 elif eta > 60: 423 eta_format = '%d:%02d' % (eta // 60, eta % 60) 424 else: 425 eta_format = '%ds' % eta 426 427 info = ' - ETA: %s' % eta_format 428 else: 429 if time_per_unit >= 1 or time_per_unit == 0: 430 info += ' %.0fs/%s' % (time_per_unit, self.unit_name) 431 elif time_per_unit >= 1e-3: 432 info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name) 433 else: 434 info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name) 435 436 for k in self._values_order: 437 info += ' - %s:' % k 438 if isinstance(self._values[k], list): 439 avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 440 if abs(avg) > 1e-3: 441 info += ' %.4f' % avg 442 else: 443 info += ' %.4e' % avg 444 else: 445 info += ' %s' % self._values[k] 446 447 self._total_width += len(info) 448 if prev_total_width > self._total_width: 449 info += (' ' * (prev_total_width - self._total_width)) 450 451 if self.target is not None and current >= self.target: 452 info += '\n' 453 454 sys.stdout.write(info) 455 sys.stdout.flush() 456 457 elif self.verbose == 2: 458 if self.target is not None and current >= self.target: 459 numdigits = int(np.log10(self.target)) + 1 460 count = ('%' + str(numdigits) + 'd/%d') % (current, self.target) 461 info = count + info 462 for k in self._values_order: 463 info += ' - %s:' % k 464 avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 465 if avg > 1e-3: 466 info += ' %.4f' % avg 467 else: 468 info += ' %.4e' % avg 469 info += '\n' 470 471 sys.stdout.write(info) 472 sys.stdout.flush() 473 474 self._last_update = now 475 476 def add(self, n, values=None): 477 self.update(self._seen_so_far + n, values) 478 479 480def make_batches(size, batch_size): 481 """Returns a list of batch indices (tuples of indices). 482 483 Arguments: 484 size: Integer, total size of the data to slice into batches. 485 batch_size: Integer, batch size. 486 487 Returns: 488 A list of tuples of array indices. 489 """ 490 num_batches = int(np.ceil(size / float(batch_size))) 491 return [(i * batch_size, min(size, (i + 1) * batch_size)) 492 for i in range(0, num_batches)] 493 494 495def slice_arrays(arrays, start=None, stop=None): 496 """Slice an array or list of arrays. 497 498 This takes an array-like, or a list of 499 array-likes, and outputs: 500 - arrays[start:stop] if `arrays` is an array-like 501 - [x[start:stop] for x in arrays] if `arrays` is a list 502 503 Can also work on list/array of indices: `slice_arrays(x, indices)` 504 505 Arguments: 506 arrays: Single array or list of arrays. 507 start: can be an integer index (start index) 508 or a list/array of indices 509 stop: integer (stop index); should be None if 510 `start` was a list. 511 512 Returns: 513 A slice of the array(s). 514 515 Raises: 516 ValueError: If the value of start is a list and stop is not None. 517 """ 518 if arrays is None: 519 return [None] 520 if isinstance(start, list) and stop is not None: 521 raise ValueError('The stop argument has to be None if the value of start ' 522 'is a list.') 523 elif isinstance(arrays, list): 524 if hasattr(start, '__len__'): 525 # hdf5 datasets only support list objects as indices 526 if hasattr(start, 'shape'): 527 start = start.tolist() 528 return [None if x is None else x[start] for x in arrays] 529 else: 530 return [None if x is None else x[start:stop] for x in arrays] 531 else: 532 if hasattr(start, '__len__'): 533 if hasattr(start, 'shape'): 534 start = start.tolist() 535 return arrays[start] 536 elif hasattr(start, '__getitem__'): 537 return arrays[start:stop] 538 else: 539 return [None] 540 541 542def to_list(x): 543 """Normalizes a list/tensor into a list. 544 545 If a tensor is passed, we return 546 a list of size 1 containing the tensor. 547 548 Arguments: 549 x: target object to be normalized. 550 551 Returns: 552 A list. 553 """ 554 if isinstance(x, list): 555 return x 556 return [x] 557 558 559def object_list_uid(object_list): 560 """Creates a single string from object ids.""" 561 object_list = nest.flatten(object_list) 562 return ', '.join([str(abs(id(x))) for x in object_list]) 563 564 565def to_snake_case(name): 566 intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name) 567 insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower() 568 # If the class is private the name starts with "_" which is not secure 569 # for creating scopes. We prefix the name with "private" in this case. 570 if insecure[0] != '_': 571 return insecure 572 return 'private' + insecure 573 574 575def is_all_none(structure): 576 iterable = nest.flatten(structure) 577 # We cannot use Python's `any` because the iterable may return Tensors. 578 for element in iterable: 579 if element is not None: 580 return False 581 return True 582