• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
"""Registries that map primitives or keys to registered functions."""
17from __future__ import absolute_import
18from __future__ import division
19
20from collections import UserDict
21
22from mindspore.ops.primitive import Primitive
23
24
class Registry(UserDict):
    """Registry class for registering functions for grad and vm_impl on Primitive.

    An entry is stored either under a string key or under ``id()`` of a
    Primitive subclass.
    """

    def register(self, prim):
        """Return a decorator that registers a function under ``prim``.

        Args:
            prim (Union[str, type]): A string key, or a subclass of Primitive.
                A subclass is stored under ``id(prim)``.

        Returns:
            function, a decorator that stores the decorated function and
            returns it unchanged. Nothing is stored when ``prim`` is a class
            that is not a subclass of Primitive.
        """
        def deco(fn):
            """Store ``fn`` under the key derived from ``prim`` and return it."""
            if isinstance(prim, str):
                self[prim] = fn
            elif issubclass(prim, Primitive):
                # Classes are keyed by the identity of the class object itself.
                self[id(prim)] = fn
            return fn
        return deco

    def get(self, prim_obj, default=None):
        """Get the function registered for a primitive.

        Args:
            prim_obj (Union[str, Primitive]): A string key or a Primitive instance.
            default: Value returned when no entry matches. Default: None, so the
                method can be called like the standard mapping ``get(key)``.

        Returns:
            The registered function, or ``default`` when nothing matches.
        """
        if isinstance(prim_obj, str):
            return self[prim_obj] if prim_obj in self else default
        if isinstance(prim_obj, Primitive):
            # Prefer the entry registered for the instance's class; only when
            # that is absent fall back to an entry registered under its name.
            key = id(prim_obj.__class__)
            if key not in self:
                key = prim_obj.name
            if key in self:
                return self[key]
        return default
53
54
class PyFuncRegistry(UserDict):
    """Registry of Python functions that are looked up by key."""

    def register(self, key, value):
        """Store ``value`` under ``key``, replacing any existing entry."""
        self[key] = value

    def get(self, key):
        """Return the value registered under ``key``.

        Raises:
            ValueError: If ``key`` has not been registered.
        """
        if key in self:
            return self[key]
        raise ValueError(f"Python function with key{key} not registered.")
63
64
class OpaquePredicateRegistry(PyFuncRegistry):
    """Registry of opaque predicate functions used for dynamic obfuscation.

    In addition to the key-to-function mapping, ``func_names`` records every
    key passed to ``register``, in call order.
    """

    def __init__(self):
        super().__init__()
        # One entry is appended per register() call, so re-registering a key
        # adds it again.
        self.func_names = []

    def register(self, key, value):
        """Store ``value`` under ``key`` and append ``key`` to ``func_names``."""
        super().register(key, value)
        self.func_names.append(key)
74