1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_PYBOOST_GRAD_FUNCTIONS_H 18 #define MINDSPORE_PYBOOST_GRAD_FUNCTIONS_H 19 20 #include <map> 21 #include <string> 22 #include <vector> 23 #include "kernel/pyboost/op_runner.h" 24 #include "runtime/pynative/op_runner.h" 25 #include "runtime/pynative/op_function/func_object.h" 26 #include "backend/graph_compiler/backend.h" 27 28 namespace mindspore::runtime { 29 using Func = std::function<void(OpRunnerInfo *, VectorRef *)>; 30 31 class PyBoostOpExecute { 32 public: 33 static COMMON_EXPORT PyBoostOpExecute &GetInstance(); 34 35 // Register pyboost grad op function Register(const std::string & key,Func func)36 void Register(const std::string &key, Func func) { grad_op_func_map_[key] = func; } 37 38 // Check grad op have already registered 39 bool COMMON_EXPORT IsPyBoostOpRegistered(const std::string &op_name); 40 41 // Unified op run entry for pynative grad 42 void COMMON_EXPORT Execute(OpRunnerInfo *op_runner_info, VectorRef *op_outputs); 43 44 // Api for outside call 45 void COMMON_EXPORT RunPyBoostCall(OpRunnerInfo *op_runner_info, VectorRef *op_outputs); 46 47 // Clear backend for fork process. ClearBackend()48 void ClearBackend() { backend_ = nullptr; } 49 50 private: 51 // Run op by single op graph 52 void RunOpDeprecated(OpRunnerInfo *op_runner_info, VectorRef *op_outputs); 53 54 // RunOp in VM 55 void RunOpInVm(OpRunnerInfo *op_runner_info, VectorRef *op_outputs); 56 57 // Get backend 58 void GetMindRtBackend(const string &cur_device_target); 59 60 compile::MindRTBackendPtr backend_; 61 std::map<std::string, FuncObject> grad_op_func_map_; 62 }; 63 64 class PyBoostGradOpRegistrar { 65 public: PyBoostGradOpRegistrar(const std::string & name,const Func & func)66 PyBoostGradOpRegistrar(const std::string &name, const Func &func) { 67 PyBoostOpExecute::GetInstance().Register(name, func); 68 } 69 ~PyBoostGradOpRegistrar() = default; 70 }; 71 72 #define MS_REG_PYBOOST_GRAD_OP(NAME, FUNC) static const PyBoostGradOpRegistrar g_##NAME##_pyboost(#NAME, FUNC); 73 } // namespace mindspore::runtime 74 #endif // MINDSPORE_PYBOOST_GRAD_FUNCTIONS_H 75