• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/ops_frontend_func_impl.h"
18 #include "utils/log_adapter.h"
19 
20 namespace mindspore::ops {
GetOpsFrontendFuncImplMapPtr()21 OpsFrontendFuncImplMap *GetOpsFrontendFuncImplMapPtr() {
22   static OpsFrontendFuncImplMap ops_frontend_func_impl_map;
23   return &ops_frontend_func_impl_map;
24 }
25 
GetOpFrontendFuncImplPtr(const std::string & name)26 OpFrontendFuncImplPtr GetOpFrontendFuncImplPtr(const std::string &name) {
27   auto iter = GetOpsFrontendFuncImplMapPtr()->find(name);
28   if (iter == GetOpsFrontendFuncImplMapPtr()->end()) {
29     return nullptr;
30   }
31 
32   return iter->second.get_func_impl();
33 }
34 
RegFrontendFuncImplHelper(const std::string & name,const OpFrontendFuncImplPtr & func_impl)35 RegFrontendFuncImplHelper::RegFrontendFuncImplHelper(const std::string &name, const OpFrontendFuncImplPtr &func_impl) {
36   const FrontendFuncImplHolder holder{func_impl};
37   (void)GetOpsFrontendFuncImplMapPtr()->emplace(name, holder);
38 }
39 
GetInstance()40 InferValueCallback &InferValueCallback::GetInstance() {
41   static InferValueCallback instance{};
42   return instance;
43 }
44 
RegImpl(const std::string & impl_type,const InferValueFunc & func)45 void InferValueCallback::RegImpl(const std::string &impl_type, const InferValueFunc &func) {
46   if (impl_type == "python_impl") {
47     if (python_impl_) {
48       MS_LOG(ERROR) << "InferValueImpl for python_impl is already registered!";
49     }
50     python_impl_ = func;
51   } else if (impl_type == "cpu_kernel_impl") {
52     if (kernel_impl_) {
53       MS_LOG(ERROR) << "InferValueImpl for cpu_kernel_impl is already registered!";
54     }
55     kernel_impl_ = func;
56   } else {
57     MS_LOG(ERROR) << "Unsupported InferValue implement type " << impl_type << "!";
58   }
59 }
60 
CallPyInferValue(const std::string & op_name,const AbstractBasePtrList & input_args)61 ValuePtr InferValueCallback::CallPyInferValue(const std::string &op_name, const AbstractBasePtrList &input_args) {
62   if (python_impl_) {
63     return python_impl_(op_name, input_args);
64   }
65   return nullptr;
66 }
CallKernelInferValue(const std::string & op_name,const AbstractBasePtrList & input_args)67 ValuePtr InferValueCallback::CallKernelInferValue(const std::string &op_name, const AbstractBasePtrList &input_args) {
68   if (kernel_impl_) {
69     return kernel_impl_(op_name, input_args);
70   }
71   return nullptr;
72 }
73 
InferValueImplRegister(const std::string & impl_type,const InferValueFunc & fn)74 InferValueImplRegister::InferValueImplRegister(const std::string &impl_type, const InferValueFunc &fn) {
75   InferValueCallback::GetInstance().RegImpl(impl_type, fn);
76 }
77 }  //  namespace mindspore::ops
78