1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/jit/pi/graph_guard/cache.h"
17 #include <algorithm>
18 #include "pipeline/jit/ps/pipeline.h"
19
20 namespace mindspore {
21 namespace pijit {
OptFunc(NativeFunc cFunc,ReleaseFunc rFunc)22 OptFunc::OptFunc(NativeFunc cFunc, ReleaseFunc rFunc) : cFunc_(cFunc), rFunc_(rFunc) {}
23
~OptFunc()24 OptFunc::~OptFunc() {
25 if (rFunc_ != nullptr) {
26 rFunc_();
27 }
28 }
29
GetFunc()30 NativeFunc OptFunc::GetFunc() { return cFunc_; }
31
operator ==(const OptOption & obj) const32 bool OptOption::operator==(const OptOption &obj) const {
33 if (obj.target_ == this->target_) {
34 return true;
35 } else {
36 return false;
37 }
38 }
39
OptOption(PyCodeObject * code)40 OptOption::OptOption(PyCodeObject *code) { target_ = code; }
41
OptOption(void * ptr)42 OptOption::OptOption(void *ptr) { target_ = ptr; }
43
CreateOptionByCode(PyCodeObject * code)44 std::shared_ptr<OptOption> OptOption::CreateOptionByCode(PyCodeObject *code) {
45 OptOptionPtr ret(new OptOption(code));
46 return ret;
47 }
48
CreateOptionByPoint(void * ptr)49 std::shared_ptr<OptOption> OptOption::CreateOptionByPoint(void *ptr) {
50 OptOptionPtr ret(new OptOption(ptr));
51 return ret;
52 }
53
OptCode()54 OptCode::OptCode() : phase_(""), compiled_code_(), call_count_(0) {
55 guard_ = std::make_shared<OptGuard>();
56 graph_perf_ = std::make_shared<OptPerf>();
57 pynative_perf_ = std::make_shared<OptPerf>();
58 compiled_func_ = nullptr;
59 }
60
~OptCode()61 OptCode::~OptCode() {}
62
SetNativeFunc(const std::string & phase,NativeFunc cFunc,ReleaseFunc rFunc)63 void OptCode::SetNativeFunc(const std::string &phase, NativeFunc cFunc, ReleaseFunc rFunc) {
64 phase_ = phase;
65 compiled_func_ = std::make_shared<OptFunc>(cFunc, rFunc);
66 }
67
GetNativeFunc() const68 NativeFunc OptCode::GetNativeFunc() const {
69 if (compiled_func_ != nullptr) {
70 return compiled_func_->GetFunc();
71 } else {
72 return nullptr;
73 }
74 }
75
GetPhase() const76 std::string OptCode::GetPhase() const { return phase_; }
77
// Stores a Python code object as this record's compiled code.
// Requires `code` to be a non-null PyCodeObject whose reference count is exactly
// 1, i.e. the caller's py::object is its only holder. Otherwise the check macro
// fails with "code handler must be only one".
// After assignment, compiled_code_ holds an additional reference.
void OptCode::SetPythonCode(const py::object &code) {
  MS_EXCEPTION_IF_CHECK_FAIL(code.ptr() != nullptr && PyCode_Check(code.ptr()) && Py_REFCNT(code.ptr()) == 1,
                             "code handler must be only one");
  compiled_code_ = code;
}
83
GetPythonCode() const84 PyCodeObject *OptCode::GetPythonCode() const { return reinterpret_cast<PyCodeObject *>(compiled_code_.ptr()); }
85
SetGuard(OptGuardPtr guard)86 void OptCode::SetGuard(OptGuardPtr guard) { guard_ = guard; }
87
GetGuard()88 OptGuardPtr OptCode::GetGuard() { return guard_; }
89
SetOption(OptOptionPtr option)90 void OptCode::SetOption(OptOptionPtr option) { option_ = option; }
91
GetOption()92 OptOptionPtr OptCode::GetOption() { return option_; }
93
GetPerf(OptPerf::PerfKind kind)94 OptPerfPtr OptCode::GetPerf(OptPerf::PerfKind kind) {
95 switch (kind) {
96 case OptPerf::PerfKind::kPerfGraph:
97 return graph_perf_;
98 case OptPerf::PerfKind::kPerfPyNative:
99 return pynative_perf_;
100 default:
101 return nullptr;
102 }
103 }
104
Copy(OptCodePtr dst)105 void OptCode::Copy(OptCodePtr dst) {
106 dst->graph_perf_ = graph_perf_;
107 dst->pynative_perf_ = pynative_perf_;
108 dst->phase_ = phase_;
109 dst->compiled_func_ = compiled_func_;
110 }
111
Inc()112 void OptCode::Inc() { call_count_++; }
113
Count()114 uint64_t OptCode::Count() { return call_count_; }
115
AddOptTarget(OptOptionPtr option)116 OptCodePtr OptCodeHub::AddOptTarget(OptOptionPtr option) {
117 OptCodePtr ret;
118 for (auto &item : codeMap_) {
119 if (*(item.first.get()) == *(option.get())) {
120 ret = std::make_shared<OptCode>();
121 item.second.push_back(ret);
122 return ret;
123 }
124 }
125 ret = std::make_shared<OptCode>();
126 codeMap_[option].push_back(ret);
127 ret->SetOption(option);
128 return ret;
129 }
130
GetOptTarget(OptOptionPtr option)131 OptCodeSet OptCodeHub::GetOptTarget(OptOptionPtr option) {
132 for (auto &item : codeMap_) {
133 if (*(item.first.get()) == *(option.get())) {
134 return item.second;
135 }
136 }
137 return {};
138 }
139
UpdateOptTarget(OptOptionPtr option,OptCodePtr code)140 void OptCodeHub::UpdateOptTarget(OptOptionPtr option, OptCodePtr code) {
141 for (auto &item : codeMap_) {
142 if (*(item.first.get()) == *(option.get())) {
143 auto it = std::find(item.second.begin(), item.second.end(), code);
144 if (it != item.second.end()) {
145 item.second.erase(it);
146 item.second.push_back(code);
147 }
148 break;
149 }
150 }
151 }
152
DelOptTarget(OptOptionPtr option,OptCodePtr code)153 void OptCodeHub::DelOptTarget(OptOptionPtr option, OptCodePtr code) {
154 for (auto &item : codeMap_) {
155 if (*(item.first.get()) == *(option.get())) {
156 auto it = std::find(item.second.begin(), item.second.end(), code);
157 if (it != item.second.end()) {
158 item.second.erase(it);
159 }
160 if (item.second.size() == 0) {
161 codeMap_.erase(item.first);
162 }
163 break;
164 }
165 }
166 }
167
DelOptTarget(OptCodePtr code)168 void OptCodeHub::DelOptTarget(OptCodePtr code) {
169 for (auto &item : codeMap_) {
170 auto it = std::find(item.second.begin(), item.second.end(), code);
171 if (it != item.second.end()) {
172 item.second.erase(it);
173 if (item.second.size() == 0) {
174 codeMap_.erase(item.first);
175 }
176 break;
177 }
178 }
179 }
180
GetAllOptTarget()181 std::vector<OptCodeSet> OptCodeHub::GetAllOptTarget() {
182 std::vector<OptCodeSet> ret;
183 std::transform(codeMap_.begin(), codeMap_.end(), std::back_inserter(ret),
184 [](const auto &item) { return item.second; });
185 return ret;
186 }
187
// Non-owning handle to an OptCode: the registry below never keeps an OptCode alive.
using OptCodeWPtr = std::weak_ptr<OptCode>;
// Registration-ordered list of weak handles for one key.
using OptCodeWSet = std::vector<OptCodeWPtr>;
// File-local registry used by OptCodeHub::Register / OptCodeHub::Filter, keyed by
// a caller-chosen string. Expired handles are pruned lazily inside Filter.
// No locking is visible here; callers presumably serialize access (TODO confirm).
static std::map<std::string, OptCodeWSet> code_set;
191
Register(std::string key,OptCodePtr code)192 void OptCodeHub::Register(std::string key, OptCodePtr code) { code_set[key].emplace_back(code); }
Filter(std::string key,OptCodeFilterFunc filter)193 OptCodePtr OptCodeHub::Filter(std::string key, OptCodeFilterFunc filter) {
194 if (code_set.find(key) != code_set.end()) {
195 OptCodeWSet &codes = code_set[key];
196 for (size_t idx = 0; idx < codes.size();) {
197 OptCodePtr ptr = codes[idx].lock();
198 if (ptr != nullptr) {
199 if (filter(ptr)) {
200 return ptr;
201 }
202 idx++;
203 } else {
204 codes.erase(codes.begin() + idx);
205 }
206 }
207 }
208 return nullptr;
209 }
210 } // namespace pijit
211 } // namespace mindspore
212