# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Registry mechanism for "registering" classes/functions for general use.

This is typically used with a decorator that calls Register for adding
a class or function to a registry.
"""

import traceback

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat


# Registry mechanism below is based on mapreduce.python.mrpython.Register.
_LOCATION_TAG = "location"
_TYPE_TAG = "type"


class Registry(object):
  """Provides a registry for saving objects."""

  __slots__ = ["_name", "_registry"]

  def __init__(self, name):
    """Creates a new registry.

    Args:
      name: Name of this registry; used in log and error messages.
    """
    self._name = name
    # Maps registry key -> {_TYPE_TAG: candidate,
    #                       _LOCATION_TAG: traceback.FrameSummary}.
    self._registry = {}

  def register(self, candidate, name=None):
    """Registers a Python object "candidate" for the given "name".

    The source location of the registration is recorded so that a later
    duplicate registration can report where the first one happened.

    Args:
      candidate: The candidate object to add to the registry.
      name: An optional string specifying the registry key for the candidate.
        If None (or otherwise falsy), candidate.__name__ will be used.
        NOTE(review): unlike `lookup`, the key is not passed through
        `compat.as_str` here.

    Raises:
      KeyError: If same name is used twice.
    """
    if not name:
      name = candidate.__name__
    if name in self._registry:
      frame = self._registry[name][_LOCATION_TAG]
      raise KeyError(
          "Registering two %s with name '%s'! "
          "(Previous registration was in %s %s:%d)" %
          (self._name, name, frame.name, frame.filename, frame.lineno))

    logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
    # traceback.extract_stack() returns frames oldest-first, and limit=3 keeps
    # the three most recent ones: [user_function, Register(), this_function].
    # The user function is therefore index 0, not the last entry (which is
    # always this method). With fewer frames, index 0 is still the outermost
    # frame available.
    stack = traceback.extract_stack(limit=3)
    if stack:
      location_tag = stack[0]
    else:
      # Placeholder exposing the same attributes (name/filename/lineno) that
      # the duplicate-registration error message reads.
      location_tag = traceback.FrameSummary(
          "UNKNOWN", 0, "UNKNOWN", lookup_line=False)
    self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: location_tag}

  def list(self):
    """Lists registered items.

    Returns:
      A list of names of registered objects, in insertion order. The list is
      a snapshot: later registrations do not modify it.
    """
    # `list` here is the builtin; the method name does not shadow it inside
    # the function body.
    return list(self._registry)

  def lookup(self, name):
    """Looks up "name".

    Args:
      name: a string specifying the registry key for the candidate. It is
        normalized with `compat.as_str` before the lookup.

    Returns:
      Registered object if found

    Raises:
      LookupError: if "name" has not been registered.
    """
    name = compat.as_str(name)
    if name in self._registry:
      return self._registry[name][_TYPE_TAG]
    else:
      raise LookupError(
          "%s registry has no entry for: %s" % (self._name, name))