• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_CACHE_H
17 #define MINDSPORE_PI_JIT_CACHE_H
18 
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "pybind11/pybind11.h"
#include "include/common/utils/python_adapter.h"
#include "pipeline/jit/pi/graph_guard/guard.h"
#include "pipeline/jit/pi/graph_guard/perf.h"
27 
28 namespace mindspore {
29 namespace pijit {
30 using NativeFunc = std::function<PyObject *(PyObject *, PyObject *)>;
31 using ReleaseFunc = std::function<void()>;
32 class OptFunc {
33  public:
34   OptFunc(NativeFunc cFunc, ReleaseFunc rFunc);
35   virtual ~OptFunc();
36   NativeFunc GetFunc();
37 
38  protected:
39   NativeFunc cFunc_;
40   ReleaseFunc rFunc_;
41 };
42 using OptFuncPtr = std::shared_ptr<OptFunc>;
43 
44 /// \brief OptOption is the compilation option for the code
45 class OptOption : public std::enable_shared_from_this<OptOption> {
46  public:
47   /// \brief no support for default construction and you can extend the option class to support more feature
48   OptOption() = delete;
49   virtual ~OptOption() = default;
50   /// \brief support create option by PyCodeObject
51   static std::shared_ptr<OptOption> CreateOptionByCode(PyCodeObject *code);
52   static std::shared_ptr<OptOption> CreateOptionByPoint(void *ptr);
53   bool operator==(const OptOption &obj) const;
54 
55  protected:
56   explicit OptOption(PyCodeObject *code);
57   explicit OptOption(void *ptr);
58   void *target_;
59 };
60 using OptOptionPtr = std::shared_ptr<OptOption>;
61 
62 /// \brief optimized code with native function graph and guard based on the compilation option
63 class OptCode : public std::enable_shared_from_this<OptCode> {
64  public:
65   OptCode();
66   virtual ~OptCode();
67   virtual void SetGuard(OptGuardPtr guard);
68   virtual OptGuardPtr GetGuard();
69   virtual void SetOption(OptOptionPtr option);
70   virtual OptOptionPtr GetOption();
71   virtual OptPerfPtr GetPerf(OptPerf::PerfKind kind);
72 
73   void SetPythonCode(const py::object &code);
74   PyCodeObject *GetPythonCode() const;
75   void SetNativeFunc(const std::string &phase, NativeFunc cFunc, ReleaseFunc rFunc);
76   NativeFunc GetNativeFunc() const;
77   std::string GetPhase() const;
78   void Copy(std::shared_ptr<OptCode> dst);
79   void Inc();
80   uint64_t Count();
81 
82  protected:
83   std::string phase_;
84   OptFuncPtr compiled_func_;
85   py::object compiled_code_;
86   OptGuardPtr guard_;
87   OptOptionPtr option_;
88   OptPerfPtr graph_perf_;
89   OptPerfPtr pynative_perf_;
90   uint64_t call_count_;
91 };
92 using OptCodePtr = std::shared_ptr<OptCode>;
93 using OptCodeSet = std::vector<OptCodePtr>;
94 
95 using OptCodeFilterFunc = std::function<bool(OptCodePtr)>;
96 /// \brief hub for optimized code based on compilation option
97 class OptCodeHub : public std::enable_shared_from_this<OptCodeHub> {
98  public:
99   OptCodeHub() = default;
100   virtual ~OptCodeHub() = default;
101   virtual OptCodePtr AddOptTarget(OptOptionPtr option);
102   virtual OptCodeSet GetOptTarget(OptOptionPtr option);
103   virtual void UpdateOptTarget(OptOptionPtr option, OptCodePtr code);
104   virtual void DelOptTarget(OptOptionPtr option, OptCodePtr code);
105   virtual void DelOptTarget(OptCodePtr code);
106   virtual std::vector<OptCodeSet> GetAllOptTarget();
107   static void Register(std::string key, OptCodePtr code);
108   static OptCodePtr Filter(std::string key, OptCodeFilterFunc filter);
109 
110  protected:
111   std::map<OptOptionPtr, OptCodeSet> codeMap_;
112 };
113 
114 using OptCodeHubPtr = std::shared_ptr<OptCodeHub>;
115 }  // namespace pijit
116 }  // namespace mindspore
117 
118 #endif  // MINDSPORE_PI_JIT_CACHE_H
119