• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
"""Registry mechanism for "registering" classes/functions for general use.

This is typically used with a decorator that calls `Registry.register` to add
a class or function to a registry.
"""
21
22import traceback
23
24from tensorflow.python.platform import tf_logging as logging
25from tensorflow.python.util import compat
26
27
28# Registry mechanism below is based on mapreduce.python.mrpython.Register.
29_LOCATION_TAG = "location"
30_TYPE_TAG = "type"
31
32
class Registry(object):
  """Provides a registry for saving objects.

  Each entry maps a name to the registered object plus the stack frame that
  was captured at registration time. The frame is used to report where an
  earlier registration happened when the same name is registered twice.
  """

  __slots__ = ["_name", "_registry"]

  def __init__(self, name):
    """Creates a new registry.

    Args:
      name: Name of this registry, used in log and error messages.
    """
    self._name = name
    # Maps key -> {_TYPE_TAG: candidate, _LOCATION_TAG: traceback frame}.
    self._registry = {}

  def register(self, candidate, name=None):
    """Registers a Python object "candidate" for the given "name".

    Args:
      candidate: The candidate object to add to the registry.
      name: An optional string specifying the registry key for the candidate.
            If None (or empty), candidate.__name__ will be used. A bytes key
            is converted with `compat.as_str`, matching `lookup`.
    Raises:
      KeyError: If same name is used twice.
    """
    if not name:
      name = candidate.__name__
    if isinstance(name, bytes):
      # `lookup` always normalizes its key with compat.as_str; store bytes
      # keys in the same form so they can actually be looked up later.
      name = compat.as_str(name)
    if name in self._registry:
      frame = self._registry[name][_LOCATION_TAG]
      # %s rather than %d for the line number: FrameSummary.lineno may be
      # None, and %s renders an int exactly as %d would.
      raise KeyError(
          "Registering two %s with name '%s'! "
          "(Previous registration was in %s %s:%s)" %
          (self._name, name, frame.name, frame.filename, frame.lineno))

    logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
    # Counting outward from here the call chain is
    # [this_function, Register(), user_function, ...], but
    # traceback.extract_stack returns frames OLDEST first. With limit=3 the
    # result is [user_function, Register(), this_function], so the user
    # function is entry 0 (entry 2 would always be this method). When the
    # stack is shallower, entry 0 is still the outermost available frame.
    stack = traceback.extract_stack(limit=3)
    if stack:
      location_tag = stack[0]
    else:
      # Defensive fallback. Keep the FrameSummary attribute interface
      # (.name/.filename/.lineno) that the duplicate-name error above reads.
      location_tag = traceback.FrameSummary(
          "UNKNOWN", 0, "UNKNOWN", lookup_line=False)
    self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: location_tag}

  def list(self):
    """Lists registered items.

    Returns:
      A list of names of registered objects, in registration order. The list
      is a snapshot: later registrations do not modify it.
    """
    # Inside a method body, `list` resolves to the builtin, not this method.
    # Materialize the keys: dict.keys() is a live view, which does not match
    # the documented list return type and breaks if a caller registers new
    # entries while iterating over it.
    return list(self._registry)

  def lookup(self, name):
    """Looks up "name".

    Args:
      name: a string (str or bytes) specifying the registry key for the
        candidate. It is normalized with `compat.as_str` before lookup.
    Returns:
      Registered object if found
    Raises:
      LookupError: if "name" has not been registered.
    """
    name = compat.as_str(name)
    try:
      entry = self._registry[name]
    except KeyError:
      # `from None`: the internal KeyError is an implementation detail.
      raise LookupError(
          "%s registry has no entry for: %s" % (self._name, name)) from None
    return entry[_TYPE_TAG]
97