• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_PYBOOST_GRAD_FUNCTIONS_H
18 #define MINDSPORE_PYBOOST_GRAD_FUNCTIONS_H
19 
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "kernel/pyboost/op_runner.h"
#include "runtime/pynative/op_runner.h"
#include "runtime/pynative/op_function/func_object.h"
#include "backend/graph_compiler/backend.h"
27 
28 namespace mindspore::runtime {
29 using Func = std::function<void(OpRunnerInfo *, VectorRef *)>;
30 
31 class PyBoostOpExecute {
32  public:
33   static COMMON_EXPORT PyBoostOpExecute &GetInstance();
34 
35   // Register pyboost grad op function
Register(const std::string & key,Func func)36   void Register(const std::string &key, Func func) { grad_op_func_map_[key] = func; }
37 
38   // Check grad op have already registered
39   bool COMMON_EXPORT IsPyBoostOpRegistered(const std::string &op_name);
40 
41   // Unified op run entry for pynative grad
42   void COMMON_EXPORT Execute(OpRunnerInfo *op_runner_info, VectorRef *op_outputs);
43 
44   // Api for outside call
45   void COMMON_EXPORT RunPyBoostCall(OpRunnerInfo *op_runner_info, VectorRef *op_outputs);
46 
47   // Clear backend for fork process.
ClearBackend()48   void ClearBackend() { backend_ = nullptr; }
49 
50  private:
51   // Run op by single op graph
52   void RunOpDeprecated(OpRunnerInfo *op_runner_info, VectorRef *op_outputs);
53 
54   // RunOp in VM
55   void RunOpInVm(OpRunnerInfo *op_runner_info, VectorRef *op_outputs);
56 
57   // Get backend
58   void GetMindRtBackend(const string &cur_device_target);
59 
60   compile::MindRTBackendPtr backend_;
61   std::map<std::string, FuncObject> grad_op_func_map_;
62 };
63 
64 class PyBoostGradOpRegistrar {
65  public:
PyBoostGradOpRegistrar(const std::string & name,const Func & func)66   PyBoostGradOpRegistrar(const std::string &name, const Func &func) {
67     PyBoostOpExecute::GetInstance().Register(name, func);
68   }
69   ~PyBoostGradOpRegistrar() = default;
70 };
71 
// Declares a static PyBoostGradOpRegistrar named g_<NAME>_pyboost that registers FUNC under
// the string "NAME" when the object is initialized.
// Macros are not scoped by namespaces, so the registrar type is fully qualified; this lets
// the macro expand correctly outside namespace mindspore::runtime as well.
#define MS_REG_PYBOOST_GRAD_OP(NAME, FUNC) \
  static const ::mindspore::runtime::PyBoostGradOpRegistrar g_##NAME##_pyboost(#NAME, FUNC);
73 }  // namespace mindspore::runtime
74 #endif  // MINDSPORE_PYBOOST_GRAD_FUNCTIONS_H
75