• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python utilities required by Keras."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import binascii
21import codecs
22import marshal
23import os
24import re
25import sys
26import time
27import types as python_types
28
29import numpy as np
30import six
31
32from tensorflow.python.util import nest
33from tensorflow.python.util import tf_contextlib
34from tensorflow.python.util import tf_decorator
35from tensorflow.python.util import tf_inspect
36from tensorflow.python.util.tf_export import keras_export
37
# Registry of custom objects, keyed by the string name they are looked up by
# during (de)serialization.
_GLOBAL_CUSTOM_OBJECTS = dict()
# Reverse registry: maps a registered object back to its registered name.
_GLOBAL_CUSTOM_NAMES = dict()

# When True, `serialize_keras_object` returns a placeholder config instead of
# re-raising the `NotImplementedError` from a custom `get_config`. This is only
# enabled when saving to SavedModel, where the config isn't required.
_SKIP_FAILED_SERIALIZATION = False
# Single key of the placeholder config dictionary emitted for a layer that has
# no defined config.
LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
48
49
@keras_export('keras.utils.CustomObjectScope')
class CustomObjectScope(object):
  """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.

  Code within a `with` statement will be able to access custom objects
  by name. Changes to global custom objects persist
  within the enclosing `with` statement. At end of the `with` statement,
  global custom objects are reverted to state
  at beginning of the `with` statement.

  Example:

  Consider a custom object `MyObject` (e.g. a class):

  ```python
      with CustomObjectScope({'MyObject':MyObject}):
          layer = Dense(..., kernel_regularizer='MyObject')
          # save, load, etc. will recognize custom object by name
  ```

  Arguments:
      *args: Variable length list of dictionaries of name, object pairs that
        are merged into the global custom objects while the scope is active.
  """

  def __init__(self, *args):
    self.custom_objects = args
    self.backup = None

  def __enter__(self):
    """Snapshots the global custom objects, then merges in `custom_objects`."""
    self.backup = _GLOBAL_CUSTOM_OBJECTS.copy()
    try:
      for objects in self.custom_objects:
        _GLOBAL_CUSTOM_OBJECTS.update(objects)
    except BaseException:
      # `__exit__` is not called when `__enter__` raises, so undo any partial
      # update here; otherwise changes would escape the scope.
      self._restore()
      raise
    return self

  def __exit__(self, *args, **kwargs):
    """Reverts the global custom objects to the snapshot taken on entry."""
    self._restore()

  def _restore(self):
    """Replaces the global custom objects, in place, with the snapshot."""
    _GLOBAL_CUSTOM_OBJECTS.clear()
    _GLOBAL_CUSTOM_OBJECTS.update(self.backup)
84
85
@keras_export('keras.utils.custom_object_scope')
def custom_object_scope(*args):
  """Returns a `CustomObjectScope` built from `args`.

  Convenience wrapper for `CustomObjectScope`: inside the resulting `with`
  block, the given custom objects can be referenced by name, and any change to
  the global custom objects made inside the block is undone when it exits.

  Example:

  Consider a custom object `MyObject`

  ```python
      with custom_object_scope({'MyObject':MyObject}):
          layer = Dense(..., kernel_regularizer='MyObject')
          # save, load, etc. will recognize custom object by name
  ```

  Arguments:
      *args: Variable length list of dictionaries of name, class pairs to add to
        custom objects.

  Returns:
      Object of type `CustomObjectScope`.
  """
  scope = CustomObjectScope(*args)
  return scope
115
116
@keras_export('keras.utils.get_custom_objects')
def get_custom_objects():
  """Returns the module-level custom object registry itself (not a copy).

  Because the result is the live `_GLOBAL_CUSTOM_OBJECTS` dictionary, mutating
  it changes what later lookups see. `custom_object_scope` is the preferred way
  to make temporary changes; this function gives direct access.

  Example:

  ```python
      get_custom_objects().clear()
      get_custom_objects()['MyObject'] = MyObject
  ```

  Returns:
      Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
  """
  registry = _GLOBAL_CUSTOM_OBJECTS
  return registry
136
137
def serialize_keras_class_and_config(cls_name, cls_config):
  """Builds the `{'class_name': ..., 'config': ...}` dict for a Keras object.

  Arguments:
    cls_name: The name the class is serialized under.
    cls_config: The class config; stored as-is, without copying.

  Returns:
    A new dict with exactly the `'class_name'` and `'config'` entries.
  """
  return dict(class_name=cls_name, config=cls_config)
141
142
@keras_export('keras.utils.register_keras_serializable')
def register_keras_serializable(package='Custom', name=None):
  """Registers an object with the Keras serialization framework.

  This decorator injects the decorated class or function into the Keras custom
  object dictionary, so that it can be serialized and deserialized without
  needing an entry in the user-provided custom object dict. It also records the
  reverse mapping Keras uses to get the object's serializable string key.

  Note that to be serialized and deserialized, classes must implement the
  `get_config()` method. Functions do not have this requirement.

  The object will be registered under the key 'package>name', where `name`
  defaults to the object's `__name__` if not passed.

  Arguments:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class's name will be used.

  Returns:
    A decorator that registers the decorated class with the passed names. The
    decorator raises `ValueError` for a class without `get_config()`, or when
    the key or the object is already registered.
  """

  def decorator(arg):
    """Registers `arg` under 'package>name' and returns it unchanged."""
    short_name = arg.__name__ if name is None else name
    full_key = package + '>' + short_name

    is_class = tf_inspect.isclass(arg)
    if is_class and not hasattr(arg, 'get_config'):
      raise ValueError(
          'Cannot register a class that does not have a get_config() method.')

    # Refuse to silently overwrite either direction of the mapping.
    if full_key in _GLOBAL_CUSTOM_OBJECTS:
      previous = _GLOBAL_CUSTOM_OBJECTS[full_key]
      raise ValueError('%s has already been registered to %s' %
                       (full_key, previous))
    if arg in _GLOBAL_CUSTOM_NAMES:
      previous = _GLOBAL_CUSTOM_NAMES[arg]
      raise ValueError('%s has already been registered to %s' %
                       (arg, previous))

    _GLOBAL_CUSTOM_OBJECTS[full_key] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = full_key
    return arg

  return decorator
190
191
@keras_export('keras.utils.get_registered_name')
def get_registered_name(obj):
  """Looks up the string name an object was registered under.

  This function is part of the Keras serialization and deserialization
  framework: it maps objects to the string names used for them when
  serializing/deserializing.

  Args:
    obj: The object to look up.

  Returns:
    The registered name for `obj` if it was registered, otherwise
    `obj.__name__`.
  """
  if obj not in _GLOBAL_CUSTOM_NAMES:
    return obj.__name__
  return _GLOBAL_CUSTOM_NAMES[obj]
211
212
@tf_contextlib.contextmanager
def skip_failed_serialization():
  """Context manager that sets `_SKIP_FAILED_SERIALIZATION` to True.

  While active, `serialize_keras_object` returns a placeholder config (keyed
  by `LAYER_UNDEFINED_CONFIG_KEY`) instead of re-raising `NotImplementedError`
  from `get_config()`. The previous flag value is restored on exit, even if the
  body raises, so nested uses behave correctly.

  Yields:
    None.
  """
  global _SKIP_FAILED_SERIALIZATION
  saved_flag = _SKIP_FAILED_SERIALIZATION
  _SKIP_FAILED_SERIALIZATION = True
  try:
    yield
  finally:
    _SKIP_FAILED_SERIALIZATION = saved_flag
222
223
@keras_export('keras.utils.get_registered_object')
def get_registered_object(name, custom_objects=None, module_objects=None):
  """Resolves `name` to the object registered for it, if any.

  This function is part of the Keras serialization and deserialization
  framework. It maps strings to the objects associated with them for
  serialization/deserialization. Lookups are tried in order: the global
  registry, then `custom_objects`, then `module_objects`.

  Example:
  ```
  def from_config(cls, config, custom_objects=None):
    if 'my_custom_object_name' in config:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          config['my_custom_object_name'], custom_objects=custom_objects)
  ```

  Args:
    name: The name to look up.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, custom_objects is provided by the user.
    module_objects: A dictionary of custom objects to look the name up in.
      Generally, module_objects is provided by midlevel library implementers.

  Returns:
    The object associated with `name` (typically an instantiable class), or
    None if it is not found in any of the lookups.
  """
  if name in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[name]
  for lookup in (custom_objects, module_objects):
    if lookup and name in lookup:
      return lookup[name]
  return None
258
259
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
  """Serializes a Keras object into a (typically JSON-compatible) structure.

  Arguments:
    instance: The object to serialize. `tf_decorator` wrappers are unwrapped
      first.

  Returns:
    - `None` if `instance` is (or unwraps to) `None`.
    - For objects with `get_config()`: a `{'class_name': ..., 'config': ...}`
      dict, with nested non-string config values serialized recursively where
      possible.
    - For other objects with `__name__` (e.g. functions): their registered
      name, or `__name__` if unregistered.

  Raises:
    ValueError: if `instance` has neither `get_config` nor `__name__`.
    NotImplementedError: if `get_config()` is not implemented and
      `skip_failed_serialization` is not active.
  """
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  if hasattr(instance, 'get_config'):
    name = get_registered_name(instance.__class__)
    try:
      config = instance.get_config()
    except NotImplementedError:
      if _SKIP_FAILED_SERIALIZATION:
        return serialize_keras_class_and_config(
            name, {LAYER_UNDEFINED_CONFIG_KEY: True})
      raise
    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, six.string_types):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or dict
      # for serialization (e.g. custom functions, custom classes)
      try:
        serialized_item = serialize_keras_object(item)
        if isinstance(serialized_item, dict) and not isinstance(item, dict):
          # Flag nested objects so deserialization knows to rebuild them.
          serialized_item['__passive_serialization__'] = True
        serialization_config[key] = serialized_item
      except ValueError:
        # Not serializable by Keras (e.g. plain numbers/lists): keep as-is.
        serialization_config[key] = item

    return serialize_keras_class_and_config(name, serialization_config)
  if hasattr(instance, '__name__'):
    return get_registered_name(instance)
  raise ValueError('Cannot serialize', instance)
297
298
def get_custom_objects_by_name(item, custom_objects=None):
  """Looks `item` up in the global registry, then in `custom_objects`.

  Arguments:
    item: The name to look up.
    custom_objects: Optional dict of local custom objects.

  Returns:
    The matching object, or None if neither registry contains `item`.
  """
  for registry in (_GLOBAL_CUSTOM_OBJECTS, custom_objects or {}):
    if item in registry:
      return registry[item]
  return None
306
307
def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and config for a serialized keras object.

  Arguments:
    config: A dict with `'class_name'` and `'config'` entries.
    module_objects: Optional dict of objects used to resolve names.
    custom_objects: Optional dict of user-provided custom objects.
    printable_module_name: Human-readable kind of object, used in error
      messages.

  Returns:
    A `(cls, cls_config)` tuple. `cls_config` is `config['config']` itself,
    updated in place: nested entries flagged with `__passive_serialization__`
    are deserialized, and strings naming registered functions are replaced
    by those functions.

  Raises:
    ValueError: if `config` is malformed or `class_name` cannot be resolved.
  """
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = get_registered_object(class_name, custom_objects, module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  cls_config = config['config']
  deserialized_objects = {}
  for key, item in cls_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      deserialized_objects[key] = deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, six.string_types):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      # TODO(momernick): Should this also have 'module_objects'?
      registered = get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(registered):
        deserialized_objects[key] = registered
  # Write the replacements back only after the iteration above is finished.
  cls_config.update(deserialized_objects)

  return (cls, cls_config)
348
349
@keras_export('keras.utils.deserialize_keras_object')
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
  """Turns a serialized Keras identifier back into a Python object.

  Arguments:
    identifier: `None`, a config dict (`{'class_name': ..., 'config': ...}`),
      a string name, or an already-deserialized function.
    module_objects: Optional dict of objects used to resolve names.
    custom_objects: Optional dict of user-provided custom objects. They are
      used for name lookups and are made visible through `CustomObjectScope`
      (or passed to `from_config`) while the object is constructed.
    printable_module_name: Human-readable kind of object, used in error
      messages.

  Returns:
    The deserialized object, or `None` for a `None` identifier. A class
    referenced by a string name is instantiated with no arguments; any other
    object found by name is returned as-is.

  Raises:
    ValueError: if the identifier cannot be interpreted or a name is unknown.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)
    custom_objects = custom_objects or {}

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      if 'custom_objects' in arg_spec.args:
        # Merge global and local custom objects; local entries win on clashes.
        all_custom_objects = dict(_GLOBAL_CUSTOM_OBJECTS)
        all_custom_objects.update(custom_objects)
        return cls.from_config(cls_config, custom_objects=all_custom_objects)
      with CustomObjectScope(custom_objects):
        return cls.from_config(cls_config)
    # Then `cls` may be a function returning a class.
    # In this case by convention `config` holds
    # the kwargs of the function.
    with CustomObjectScope(custom_objects):
      return cls(**cls_config)

  if isinstance(identifier, six.string_types):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    elif object_name in _GLOBAL_CUSTOM_OBJECTS:
      obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    else:
      # `module_objects` is optional; treat a missing dict as empty so an
      # unknown name reports a ValueError instead of AttributeError on None.
      obj = (module_objects or {}).get(object_name)
      if obj is None:
        raise ValueError(
            'Unknown ' + printable_module_name + ': ' + object_name)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj

  if tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier

  raise ValueError('Could not interpret serialized %s: %s' %
                   (printable_module_name, identifier))
404
405
def func_dump(func):
  """Serializes a user-defined function into plain-data parts.

  Arguments:
      func: the function to serialize.

  Returns:
      A tuple `(code, defaults, closure)`:
        - `code`: `marshal`-dumped `func.__code__`, base64-encoded ASCII text;
        - `defaults`: `func.__defaults__`;
        - `closure`: tuple of the values held in `func.__closure__`, or None
          when the function has no closure.
  """
  marshalled = marshal.dumps(func.__code__)
  if os.name == 'nt':
    # NOTE(review): this rewrites every backslash byte in the marshalled
    # bytecode, not only those inside path strings — confirm it is intended.
    marshalled = marshalled.replace(b'\\', b'/')
  code = codecs.encode(marshalled, 'base64').decode('ascii')
  cells = func.__closure__
  closure = tuple(cell.cell_contents for cell in cells) if cells else None
  return code, func.__defaults__, closure
427
428
def func_load(code, defaults=None, closure=None, globs=None):
  """Deserializes a user defined function.

  Security: never call this on untrusted data. `marshal` is not secure
  against maliciously constructed input, and the rebuilt function runs
  arbitrary bytecode when called.

  Arguments:
      code: bytecode of the function, as base64 text (the `code` part of
        `func_dump`'s result), or a whole `(code, defaults, closure)`
        tuple/list as returned by `func_dump`.
      defaults: defaults of the function, as a tuple or a list.
      closure: closure values of the function.
      globs: dictionary of global objects; defaults to this module's globals.

  Returns:
      A function object.
  """
  if isinstance(code, (tuple, list)):  # unpack previous dump
    code, defaults, closure = code
  # Serialized dumps (e.g. after a JSON round-trip) may carry `defaults` as a
  # list, but `FunctionType` only accepts a tuple or None for `argdefs`.
  if isinstance(defaults, list):
    defaults = tuple(defaults)

  def ensure_value_to_cell(value):
    """Ensures that a value is converted to a python cell object.

    Arguments:
        value: Any value that needs to be casted to the cell type

    Returns:
        A value wrapped as a cell object (see function "func_load")
    """

    def dummy_fn():
      # pylint: disable=pointless-statement
      value  # just access it so it gets captured in .__closure__

    cell_value = dummy_fn.__closure__[0]
    if not isinstance(value, type(cell_value)):
      return cell_value
    return value

  if closure is not None:
    closure = tuple(ensure_value_to_cell(_) for _ in closure)
  try:
    raw_code = codecs.decode(code.encode('ascii'), 'base64')
  except (UnicodeEncodeError, binascii.Error):
    # Not base64 text: fall back to treating it as a raw byte string.
    raw_code = code.encode('raw_unicode_escape')
  code = marshal.loads(raw_code)
  if globs is None:
    globs = globals()
  return python_types.FunctionType(
      code, globs, name=code.co_name, argdefs=defaults, closure=closure)
476
477
def has_arg(fn, name, accept_all=False):
  """Checks if a callable accepts a given keyword argument.

  Both positional-or-keyword parameters and keyword-only parameters (those
  declared after `*` or `*args`) count as accepting `name`.

  Arguments:
      fn: Callable to inspect.
      name: Check if `fn` can be called with `name` as a keyword argument.
      accept_all: What to return if there is no parameter called `name` but the
        function accepts a `**kwargs` argument.

  Returns:
      bool, whether `fn` accepts a `name` keyword argument.
  """
  arg_spec = tf_inspect.getfullargspec(fn)
  if accept_all and arg_spec.varkw is not None:
    return True
  # `getfullargspec` reports keyword-only parameters separately in
  # `kwonlyargs`; they are valid keyword arguments too.
  return name in arg_spec.args or name in (arg_spec.kwonlyargs or ())
494
495
@keras_export('keras.utils.Progbar')
class Progbar(object):
  """Displays a progress bar.

  Arguments:
      target: Total number of steps expected, None if unknown.
      width: Progress bar width on screen.
      verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
      stateful_metrics: Iterable of string names of metrics that should *not* be
        averaged over time. Metrics in this list will be displayed as-is. All
        others will be averaged by the progbar before display.
      interval: Minimum visual progress update interval (in seconds).
      unit_name: Display name for step counts (usually "step" or "sample").
  """

  def __init__(self,
               target,
               width=30,
               verbose=1,
               interval=0.05,
               stateful_metrics=None,
               unit_name='step'):
    self.target = target
    self.width = width
    self.verbose = verbose
    self.interval = interval
    self.unit_name = unit_name
    if stateful_metrics:
      self.stateful_metrics = set(stateful_metrics)
    else:
      self.stateful_metrics = set()

    # When True, each update redraws the current line in place (backspaces
    # and a carriage return); otherwise each update starts a new line.
    # NOTE(review): `'posix' in sys.modules` holds on every POSIX platform, so
    # this is effectively always True there, even when stdout is redirected
    # to a file — confirm that is intended.
    self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                              sys.stdout.isatty()) or
                             'ipykernel' in sys.modules or
                             'posix' in sys.modules or
                             'PYCHARM_HOSTED' in os.environ)
    self._total_width = 0
    self._seen_so_far = 0
    # We use a dict + list to avoid garbage collection
    # issues found in OrderedDict
    self._values = {}
    self._values_order = []
    self._start = time.time()
    self._last_update = 0

  @staticmethod
  def _format_average(avg):
    """Formats an averaged metric value as ' %.4f', or ' %.4e' when tiny.

    The magnitude (`abs`) decides, so negative metrics are formatted the same
    way as positive ones; shared by both verbose modes for consistency.
    """
    if abs(avg) > 1e-3:
      return ' %.4f' % avg
    return ' %.4e' % avg

  def update(self, current, values=None):
    """Updates the progress bar.

    Arguments:
        current: Index of current step.
        values: List of tuples: `(name, value_for_last_step)`. If `name` is in
          `stateful_metrics`, `value_for_last_step` will be displayed as-is.
          Else, an average of the metric over time will be displayed.
    """
    values = values or []
    for k, v in values:
      if k not in self._values_order:
        self._values_order.append(k)
      if k not in self.stateful_metrics:
        # In the case that progress bar doesn't have a target value in the first
        # epoch, both on_batch_end and on_epoch_end will be called, which will
        # cause 'current' and 'self._seen_so_far' to have the same value. Force
        # the minimal value to 1 here, otherwise stateful_metric will be 0s.
        value_base = max(current - self._seen_so_far, 1)
        if k not in self._values:
          self._values[k] = [v * value_base, value_base]
        else:
          self._values[k][0] += v * value_base
          self._values[k][1] += value_base
      else:
        # Stateful metrics output a numeric value. This representation
        # means "take an average from a single value" but keeps the
        # numeric formatting.
        self._values[k] = [v, 1]
    self._seen_so_far = current

    now = time.time()
    info = ' - %.0fs' % (now - self._start)
    if self.verbose == 1:
      # Throttle redraws to at most one per `interval`, except for the final
      # update (or when the target is unknown).
      if (now - self._last_update < self.interval and
          self.target is not None and current < self.target):
        return

      prev_total_width = self._total_width
      if self._dynamic_display:
        sys.stdout.write('\b' * prev_total_width)
        sys.stdout.write('\r')
      else:
        sys.stdout.write('\n')

      if self.target is not None:
        numdigits = int(np.log10(self.target)) + 1
        bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
        prog = float(current) / self.target
        prog_width = int(self.width * prog)
        if prog_width > 0:
          bar += ('=' * (prog_width - 1))
          if current < self.target:
            bar += '>'
          else:
            bar += '='
        bar += ('.' * (self.width - prog_width))
        bar += ']'
      else:
        bar = '%7d/Unknown' % current

      self._total_width = len(bar)
      sys.stdout.write(bar)

      if current:
        time_per_unit = (now - self._start) / current
      else:
        time_per_unit = 0
      if self.target is not None and current < self.target:
        eta = time_per_unit * (self.target - current)
        if eta > 3600:
          eta_format = '%d:%02d:%02d' % (eta // 3600,
                                         (eta % 3600) // 60, eta % 60)
        elif eta > 60:
          eta_format = '%d:%02d' % (eta // 60, eta % 60)
        else:
          eta_format = '%ds' % eta

        info = ' - ETA: %s' % eta_format
      else:
        if time_per_unit >= 1 or time_per_unit == 0:
          info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
        elif time_per_unit >= 1e-3:
          info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
        else:
          info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)

      for k in self._values_order:
        info += ' - %s:' % k
        if isinstance(self._values[k], list):
          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
          info += self._format_average(avg)
        else:
          info += ' %s' % self._values[k]

      self._total_width += len(info)
      # Pad with spaces so leftovers of a longer previous line are erased.
      if prev_total_width > self._total_width:
        info += (' ' * (prev_total_width - self._total_width))

      if self.target is not None and current >= self.target:
        info += '\n'

      sys.stdout.write(info)
      sys.stdout.flush()

    elif self.verbose == 2:
      # Semi-verbose mode prints a single summary line once the target is hit.
      if self.target is not None and current >= self.target:
        numdigits = int(np.log10(self.target)) + 1
        count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
        info = count + info
        for k in self._values_order:
          info += ' - %s:' % k
          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
          info += self._format_average(avg)
        info += '\n'

        sys.stdout.write(info)
        sys.stdout.flush()

    self._last_update = now

  def add(self, n, values=None):
    """Advances the bar by `n` steps; `values` is as for `update`."""
    self.update(self._seen_so_far + n, values)
671
672
def make_batches(size, batch_size):
  """Splits `range(size)` into consecutive `(start, end)` index pairs.

  Each pair spans `batch_size` indices, except possibly the last, which is
  truncated at `size`.

  Arguments:
      size: Integer, total size of the data to slice into batches.
      batch_size: Integer, batch size.

  Returns:
      A list of `(start, end)` tuples, `end` exclusive.
  """
  batch_count = int(np.ceil(size / float(batch_size)))
  batches = []
  for batch_index in range(batch_count):
    begin = batch_index * batch_size
    batches.append((begin, min(size, begin + batch_size)))
  return batches
686
687
def slice_arrays(arrays, start=None, stop=None):
  """Slice an array or list of arrays.

  This takes an array-like, or a list of
  array-likes, and outputs:
      - arrays[start:stop] if `arrays` is an array-like
      - [x[start:stop] for x in arrays] if `arrays` is a list

  Can also work on list/array of indices: `slice_arrays(x, indices)`

  Arguments:
      arrays: Single array or list of arrays.
      start: can be an integer index (start index) or a list/array of indices
      stop: integer (stop index); should be None if `start` was a list.

  Returns:
      A slice of the array(s). `[None]` is returned when `arrays` is None, or
      when it is a single object that does not support slicing. In a list of
      arrays, `None` entries stay `None`. So do entries without
      `__getitem__` when slicing with `start`/`stop`.

  Raises:
      ValueError: If the value of start is a list and stop is not None.
  """
  if arrays is None:
    return [None]
  if isinstance(start, list) and stop is not None:
    raise ValueError('The stop argument has to be None if the value of start '
                     'is a list.')
  elif isinstance(arrays, list):
    if hasattr(start, '__len__'):
      # hdf5 datasets only support list objects as indices
      if hasattr(start, 'shape'):
        start = start.tolist()
      return [None if x is None else x[start] for x in arrays]
    return [
        None if x is None else
        None if not hasattr(x, '__getitem__') else x[start:stop] for x in arrays
    ]
  else:
    if hasattr(start, '__len__'):
      if hasattr(start, 'shape'):
        start = start.tolist()
      return arrays[start]
    # Check that the *array* supports slicing. (`start` is an int or None
    # here, so it never has `__getitem__`.) This mirrors the per-element
    # check in the list branch above.
    if hasattr(arrays, '__getitem__'):
      return arrays[start:stop]
    return [None]
732
733
def to_list(x):
  """Normalizes a list/tensor into a list.

  A list is returned unchanged (the same object); anything else, such as a
  single tensor, is wrapped in a new one-element list.

  Arguments:
      x: target object to be normalized.

  Returns:
      A list.
  """
  return x if isinstance(x, list) else [x]
749
750
def object_list_uid(object_list):
  """Creates a single string from object ids.

  Arguments:
    object_list: An object or nested structure of objects; it is flattened
      with `nest.flatten` first.

  Returns:
    The absolute `id()` of each flattened object, joined with ', '.
  """
  uids = [str(abs(id(obj))) for obj in nest.flatten(object_list)]
  return ', '.join(uids)
755
756
def to_snake_case(name):
  """Converts a CamelCase name to snake_case.

  For example 'MyLayer' becomes 'my_layer' and 'Conv2D' becomes 'conv2d'.
  A result that starts with an underscore (e.g. from a private class name)
  is prefixed with 'private' so it is safe to use as a scope name.

  Arguments:
    name: The name to convert.

  Returns:
    The snake_case version of `name`; an empty string for an empty `name`.
  """
  intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
  insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
  # If the class is private the name starts with "_" which is not secure
  # for creating scopes. We prefix the name with "private" in this case.
  # `startswith` (unlike indexing `insecure[0]`) also handles an empty name.
  if insecure.startswith('_'):
    return 'private' + insecure
  return insecure
765
766
def is_all_none(structure):
  """Returns True if every leaf of the nested `structure` is None.

  Leaves are compared with `is None`, so values such as Tensors are never
  coerced to Python booleans. An empty structure yields True.
  """
  return all(leaf is None for leaf in nest.flatten(structure))
774
775
def check_for_unexpected_keys(name, input_dict, expected_values):
  """Raises if `input_dict` has keys that are not in `expected_values`.

  Arguments:
    name: Name of the dictionary, used in the error message.
    input_dict: The dictionary whose keys are checked.
    expected_values: Iterable of allowed keys.

  Raises:
    ValueError: listing the unexpected keys and the allowed ones.
  """
  unexpected = set(input_dict.keys()).difference(expected_values)
  if not unexpected:
    return
  raise ValueError('Unknown entries in {} dictionary: {}. Only expected '
                   'following keys: {}'.format(name, list(unexpected),
                                               expected_values))
782
783
def validate_kwargs(kwargs,
                    allowed_kwargs,
                    error_message='Keyword argument not understood:'):
  """Checks that all keyword arguments are in the set of allowed keys.

  Raises:
    TypeError: for the first key of `kwargs` (in iteration order) that is not
      in `allowed_kwargs`, with `(error_message, key)` as its arguments.
  """
  for key in kwargs:
    if key in allowed_kwargs:
      continue
    raise TypeError(error_message, key)
791