• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""Register pyfunc for py_func_cpu_kernel"""
17
18from __future__ import absolute_import
19from mindspore.ops._register_for_op import PyFuncRegistry
20
21
class CustomPyFuncRegistry:
    """
    Process-wide lookup table for custom pyfunc callables.

    Each entry maps a function id (key) to its pyfunc (value). Storage is
    delegated to an underlying ``PyFuncRegistry`` object.
    """

    def __init__(self):
        # Backing store; every register/get call is forwarded to it.
        self._func_dict = PyFuncRegistry()

    @classmethod
    def instance(cls):
        """
        Return the shared CustomPyFuncRegistry, creating it on first use.

        The object is cached on the class attribute ``_instance``, so later
        calls hand back the same registry.

        Returns:
            An instance of CustomPyFuncRegistry.
        """
        try:
            return CustomPyFuncRegistry._instance
        except AttributeError:
            # First call: nothing cached yet, so build and remember one.
            CustomPyFuncRegistry._instance = CustomPyFuncRegistry()
            return CustomPyFuncRegistry._instance

    def register(self, fn_id, fn):
        """Store pyfunc `fn` under the id `fn_id` in the backing registry."""
        store = self._func_dict
        store.register(fn_id, fn)

    def get(self, fn_id):
        """Look up `fn_id` in the backing registry and return its result."""
        store = self._func_dict
        return store.get(fn_id)
51
52
def add_pyfunc(fn_id, fn):
    """Register pyfunc `fn` under `fn_id` in the shared CustomPyFuncRegistry."""
    registry = CustomPyFuncRegistry.instance()
    registry.register(fn_id, fn)
55
56
def get_pyfunc(fn_id):
    """
    Fetch the pyfunc stored under `fn_id` from the shared CustomPyFuncRegistry.

    The result is whatever the registry's ``get`` returns for that id.
    """
    registry = CustomPyFuncRegistry.instance()
    return registry.get(fn_id)
59