# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Registries that map primitives (or plain keys) to Python functions."""

from collections import UserDict
from .primitive import Primitive


class Registry(UserDict):
    """Registry of functions (e.g. grad and vm_impl) keyed by primitive.

    Keys are either a primitive name (``str``) or the ``id()`` of a
    ``Primitive`` subclass, as stored by :meth:`register`.
    """

    def register(self, prim):
        """Return a decorator that registers the decorated function under ``prim``.

        Args:
            prim (Union[str, type]): A primitive name, or a subclass of ``Primitive``.

        Returns:
            Callable: A decorator that stores its argument in the registry and
            returns it unchanged.

        Raises:
            TypeError: (when the decorator is applied) if ``prim`` is neither a
                ``str`` nor a subclass of ``Primitive``.
        """
        def deco(fn):
            """Store ``fn`` under the key derived from ``prim`` and return it."""
            if isinstance(prim, str):
                self[prim] = fn
            elif isinstance(prim, type) and issubclass(prim, Primitive):
                # Keyed by the class's id so get() can find it from an instance's __class__.
                self[id(prim)] = fn
            else:
                # Reject unsupported keys instead of silently skipping the registration.
                raise TypeError(
                    f"Registry.register expects a str or a subclass of Primitive, but got {prim!r}.")
            return fn
        return deco

    def get(self, prim_obj, default=None):
        """Look up the function registered for ``prim_obj``.

        Args:
            prim_obj (Union[str, Primitive]): A primitive name or a primitive instance.
            default: Value returned when no registration matches. Defaults to None,
                matching the ``Mapping.get`` signature.

        Returns:
            The registered function, or ``default`` if none matches. For a
            ``Primitive`` instance, a registration for its exact class takes
            precedence over one for its ``name``; other input types always
            yield ``default``.
        """
        fn = default
        if isinstance(prim_obj, str) and prim_obj in self:
            fn = self[prim_obj]
        elif isinstance(prim_obj, Primitive):
            key = id(prim_obj.__class__)
            if key in self:
                fn = self[key]
            else:
                # Fall back to a registration made with the primitive's name string.
                key = prim_obj.name
                if key in self:
                    fn = self[key]
        return fn


class PyFuncRegistry(UserDict):
    """Registry mapping arbitrary keys to Python functions; missing lookups raise."""

    def register(self, key, value):
        """Store ``value`` under ``key``, replacing any existing entry."""
        self[key] = value

    def get(self, key):
        """Return the function registered under ``key``.

        Raises:
            ValueError: If ``key`` has not been registered.
        """
        if key not in self:
            raise ValueError(f"Python function with key {key} not registered.")
        return self[key]