• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/jit/pi/graph_guard/cache.h"
17 #include <algorithm>
18 #include "pipeline/jit/ps/pipeline.h"
19 
20 namespace mindspore {
21 namespace pijit {
OptFunc(NativeFunc cFunc,ReleaseFunc rFunc)22 OptFunc::OptFunc(NativeFunc cFunc, ReleaseFunc rFunc) : cFunc_(cFunc), rFunc_(rFunc) {}
23 
~OptFunc()24 OptFunc::~OptFunc() {
25   if (rFunc_ != nullptr) {
26     rFunc_();
27   }
28 }
29 
GetFunc()30 NativeFunc OptFunc::GetFunc() { return cFunc_; }
31 
operator ==(const OptOption & obj) const32 bool OptOption::operator==(const OptOption &obj) const {
33   if (obj.target_ == this->target_) {
34     return true;
35   } else {
36     return false;
37   }
38 }
39 
OptOption(PyCodeObject * code)40 OptOption::OptOption(PyCodeObject *code) { target_ = code; }
41 
OptOption(void * ptr)42 OptOption::OptOption(void *ptr) { target_ = ptr; }
43 
CreateOptionByCode(PyCodeObject * code)44 std::shared_ptr<OptOption> OptOption::CreateOptionByCode(PyCodeObject *code) {
45   OptOptionPtr ret(new OptOption(code));
46   return ret;
47 }
48 
CreateOptionByPoint(void * ptr)49 std::shared_ptr<OptOption> OptOption::CreateOptionByPoint(void *ptr) {
50   OptOptionPtr ret(new OptOption(ptr));
51   return ret;
52 }
53 
OptCode()54 OptCode::OptCode() : phase_(""), compiled_code_(), call_count_(0) {
55   guard_ = std::make_shared<OptGuard>();
56   graph_perf_ = std::make_shared<OptPerf>();
57   pynative_perf_ = std::make_shared<OptPerf>();
58   compiled_func_ = nullptr;
59 }
60 
~OptCode()61 OptCode::~OptCode() {}
62 
SetNativeFunc(const std::string & phase,NativeFunc cFunc,ReleaseFunc rFunc)63 void OptCode::SetNativeFunc(const std::string &phase, NativeFunc cFunc, ReleaseFunc rFunc) {
64   phase_ = phase;
65   compiled_func_ = std::make_shared<OptFunc>(cFunc, rFunc);
66 }
67 
GetNativeFunc() const68 NativeFunc OptCode::GetNativeFunc() const {
69   if (compiled_func_ != nullptr) {
70     return compiled_func_->GetFunc();
71   } else {
72     return nullptr;
73   }
74 }
75 
GetPhase() const76 std::string OptCode::GetPhase() const { return phase_; }
77 
SetPythonCode(const py::object & code)78 void OptCode::SetPythonCode(const py::object &code) {
79   MS_EXCEPTION_IF_CHECK_FAIL(code.ptr() != nullptr && PyCode_Check(code.ptr()) && Py_REFCNT(code.ptr()) == 1,
80                              "code handler must be only one");
81   compiled_code_ = code;
82 }
83 
GetPythonCode() const84 PyCodeObject *OptCode::GetPythonCode() const { return reinterpret_cast<PyCodeObject *>(compiled_code_.ptr()); }
85 
SetGuard(OptGuardPtr guard)86 void OptCode::SetGuard(OptGuardPtr guard) { guard_ = guard; }
87 
GetGuard()88 OptGuardPtr OptCode::GetGuard() { return guard_; }
89 
SetOption(OptOptionPtr option)90 void OptCode::SetOption(OptOptionPtr option) { option_ = option; }
91 
GetOption()92 OptOptionPtr OptCode::GetOption() { return option_; }
93 
GetPerf(OptPerf::PerfKind kind)94 OptPerfPtr OptCode::GetPerf(OptPerf::PerfKind kind) {
95   switch (kind) {
96     case OptPerf::PerfKind::kPerfGraph:
97       return graph_perf_;
98     case OptPerf::PerfKind::kPerfPyNative:
99       return pynative_perf_;
100     default:
101       return nullptr;
102   }
103 }
104 
Copy(OptCodePtr dst)105 void OptCode::Copy(OptCodePtr dst) {
106   dst->graph_perf_ = graph_perf_;
107   dst->pynative_perf_ = pynative_perf_;
108   dst->phase_ = phase_;
109   dst->compiled_func_ = compiled_func_;
110 }
111 
Inc()112 void OptCode::Inc() { call_count_++; }
113 
Count()114 uint64_t OptCode::Count() { return call_count_; }
115 
AddOptTarget(OptOptionPtr option)116 OptCodePtr OptCodeHub::AddOptTarget(OptOptionPtr option) {
117   OptCodePtr ret;
118   for (auto &item : codeMap_) {
119     if (*(item.first.get()) == *(option.get())) {
120       ret = std::make_shared<OptCode>();
121       item.second.push_back(ret);
122       return ret;
123     }
124   }
125   ret = std::make_shared<OptCode>();
126   codeMap_[option].push_back(ret);
127   ret->SetOption(option);
128   return ret;
129 }
130 
GetOptTarget(OptOptionPtr option)131 OptCodeSet OptCodeHub::GetOptTarget(OptOptionPtr option) {
132   for (auto &item : codeMap_) {
133     if (*(item.first.get()) == *(option.get())) {
134       return item.second;
135     }
136   }
137   return {};
138 }
139 
UpdateOptTarget(OptOptionPtr option,OptCodePtr code)140 void OptCodeHub::UpdateOptTarget(OptOptionPtr option, OptCodePtr code) {
141   for (auto &item : codeMap_) {
142     if (*(item.first.get()) == *(option.get())) {
143       auto it = std::find(item.second.begin(), item.second.end(), code);
144       if (it != item.second.end()) {
145         item.second.erase(it);
146         item.second.push_back(code);
147       }
148       break;
149     }
150   }
151 }
152 
DelOptTarget(OptOptionPtr option,OptCodePtr code)153 void OptCodeHub::DelOptTarget(OptOptionPtr option, OptCodePtr code) {
154   for (auto &item : codeMap_) {
155     if (*(item.first.get()) == *(option.get())) {
156       auto it = std::find(item.second.begin(), item.second.end(), code);
157       if (it != item.second.end()) {
158         item.second.erase(it);
159       }
160       if (item.second.size() == 0) {
161         codeMap_.erase(item.first);
162       }
163       break;
164     }
165   }
166 }
167 
DelOptTarget(OptCodePtr code)168 void OptCodeHub::DelOptTarget(OptCodePtr code) {
169   for (auto &item : codeMap_) {
170     auto it = std::find(item.second.begin(), item.second.end(), code);
171     if (it != item.second.end()) {
172       item.second.erase(it);
173       if (item.second.size() == 0) {
174         codeMap_.erase(item.first);
175       }
176       break;
177     }
178   }
179 }
180 
GetAllOptTarget()181 std::vector<OptCodeSet> OptCodeHub::GetAllOptTarget() {
182   std::vector<OptCodeSet> ret;
183   std::transform(codeMap_.begin(), codeMap_.end(), std::back_inserter(ret),
184                  [](const auto &item) { return item.second; });
185   return ret;
186 }
187 
188 using OptCodeWPtr = std::weak_ptr<OptCode>;
189 using OptCodeWSet = std::vector<OptCodeWPtr>;
190 static std::map<std::string, OptCodeWSet> code_set;
191 
Register(std::string key,OptCodePtr code)192 void OptCodeHub::Register(std::string key, OptCodePtr code) { code_set[key].emplace_back(code); }
Filter(std::string key,OptCodeFilterFunc filter)193 OptCodePtr OptCodeHub::Filter(std::string key, OptCodeFilterFunc filter) {
194   if (code_set.find(key) != code_set.end()) {
195     OptCodeWSet &codes = code_set[key];
196     for (size_t idx = 0; idx < codes.size();) {
197       OptCodePtr ptr = codes[idx].lock();
198       if (ptr != nullptr) {
199         if (filter(ptr)) {
200           return ptr;
201         }
202         idx++;
203       } else {
204         codes.erase(codes.begin() + idx);
205       }
206     }
207   }
208   return nullptr;
209 }
210 }  // namespace pijit
211 }  // namespace mindspore
212