• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
"""Registry that relates primitives to their registered functions."""
17
18from collections import UserDict
19from .primitive import Primitive
20
21
class Registry(UserDict):
    """Registry class for registering functions for grad and vm_impl on Primitive.

    Entries are stored under one of two kinds of key:

    * a ``str`` -- used as the key unchanged;
    * ``id(cls)`` of a ``Primitive`` subclass -- when a class is registered.
    """

    def register(self, prim):
        """Return a decorator that registers a function for `prim`.

        Args:
            prim (Union[str, type]): A string key, or a subclass of
                ``Primitive``. A class is stored under ``id(prim)``.

        Returns:
            Callable, a decorator that stores the decorated function in this
            registry and returns that function unchanged.

        Raises:
            TypeError: Raised by ``issubclass`` when the decorator is applied
                and `prim` is neither a string nor a class.
        """
        def deco(fn):
            """Store `fn` under the key derived from `prim` and return it."""
            if isinstance(prim, str):
                self[prim] = fn
            elif issubclass(prim, Primitive):
                # Classes are keyed by identity, matching the lookup in `get`.
                self[id(prim)] = fn
            # Any other class is deliberately left unregistered (no error).
            return fn
        return deco

    def get(self, prim_obj, default=None):
        """Get the function registered for a primitive.

        Args:
            prim_obj (Union[str, Primitive]): A string key, or a ``Primitive``
                instance. For an instance, the id of its class is tried first
                and ``prim_obj.name`` second.
            default (Any): Value returned when nothing is registered.
                Default: None, as for the standard mapping ``get``.

        Returns:
            The registered function, or `default` when no entry matches or
            `prim_obj` is neither a string nor a ``Primitive``.
        """
        if isinstance(prim_obj, str):
            return self[prim_obj] if prim_obj in self else default

        if isinstance(prim_obj, Primitive):
            class_key = id(prim_obj.__class__)
            if class_key in self:
                return self[class_key]
            # Fall back to an entry registered under the primitive's name.
            name_key = prim_obj.name
            if name_key in self:
                return self[name_key]

        return default
50
51
class PyFuncRegistry(UserDict):
    """Registry of Python functions that raises on lookup of an unknown key."""

    def register(self, key, value):
        """Store `value` under `key`, replacing any existing entry.

        Args:
            key (Hashable): Key identifying the entry.
            value (Any): Object to store, typically a Python function.
        """
        self[key] = value

    def get(self, key):
        """Return the entry registered under `key`.

        Unlike the standard mapping ``get``, this takes no default and does
        not return None for a missing key.

        Args:
            key (Hashable): Key passed to `register` earlier.

        Returns:
            The object registered under `key`.

        Raises:
            ValueError: If nothing is registered under `key`.
        """
        if key not in self:
            raise ValueError(f"Python function with key {key} not registered.")
        return self[key]
60