1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/ops_frontend_func_impl.h"
18 #include "utils/log_adapter.h"
19
20 namespace mindspore::ops {
GetOpsFrontendFuncImplMapPtr()21 OpsFrontendFuncImplMap *GetOpsFrontendFuncImplMapPtr() {
22 static OpsFrontendFuncImplMap ops_frontend_func_impl_map;
23 return &ops_frontend_func_impl_map;
24 }
25
GetOpFrontendFuncImplPtr(const std::string & name)26 OpFrontendFuncImplPtr GetOpFrontendFuncImplPtr(const std::string &name) {
27 auto iter = GetOpsFrontendFuncImplMapPtr()->find(name);
28 if (iter == GetOpsFrontendFuncImplMapPtr()->end()) {
29 return nullptr;
30 }
31
32 return iter->second.get_func_impl();
33 }
34
RegFrontendFuncImplHelper(const std::string & name,const OpFrontendFuncImplPtr & func_impl)35 RegFrontendFuncImplHelper::RegFrontendFuncImplHelper(const std::string &name, const OpFrontendFuncImplPtr &func_impl) {
36 const FrontendFuncImplHolder holder{func_impl};
37 (void)GetOpsFrontendFuncImplMapPtr()->emplace(name, holder);
38 }
39
GetInstance()40 InferValueCallback &InferValueCallback::GetInstance() {
41 static InferValueCallback instance{};
42 return instance;
43 }
44
RegImpl(const std::string & impl_type,const InferValueFunc & func)45 void InferValueCallback::RegImpl(const std::string &impl_type, const InferValueFunc &func) {
46 if (impl_type == "python_impl") {
47 if (python_impl_) {
48 MS_LOG(ERROR) << "InferValueImpl for python_impl is already registered!";
49 }
50 python_impl_ = func;
51 } else if (impl_type == "cpu_kernel_impl") {
52 if (kernel_impl_) {
53 MS_LOG(ERROR) << "InferValueImpl for cpu_kernel_impl is already registered!";
54 }
55 kernel_impl_ = func;
56 } else {
57 MS_LOG(ERROR) << "Unsupported InferValue implement type " << impl_type << "!";
58 }
59 }
60
CallPyInferValue(const std::string & op_name,const AbstractBasePtrList & input_args)61 ValuePtr InferValueCallback::CallPyInferValue(const std::string &op_name, const AbstractBasePtrList &input_args) {
62 if (python_impl_) {
63 return python_impl_(op_name, input_args);
64 }
65 return nullptr;
66 }
CallKernelInferValue(const std::string & op_name,const AbstractBasePtrList & input_args)67 ValuePtr InferValueCallback::CallKernelInferValue(const std::string &op_name, const AbstractBasePtrList &input_args) {
68 if (kernel_impl_) {
69 return kernel_impl_(op_name, input_args);
70 }
71 return nullptr;
72 }
73
InferValueImplRegister(const std::string & impl_type,const InferValueFunc & fn)74 InferValueImplRegister::InferValueImplRegister(const std::string &impl_type, const InferValueFunc &fn) {
75 InferValueCallback::GetInstance().RegImpl(impl_type, fn);
76 }
77 } // namespace mindspore::ops
78