1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/symbol_engine/jit/cpp_visitor.h"
17
18 #if !(defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER))
#include <dlfcn.h>
#include <unistd.h>

#include <array>
#include <atomic>
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <exception>
#include <fstream>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
28
29 #include "backend/common/graph_kernel/graph_kernel_flags.h"
30 #include "mindspore/core/symbolic_shape/symbol.h"
31 #include "kernel/framework_utils.h"
32 #include "include/common/debug/common.h"
33 #include "utils/file_utils.h"
34
35 namespace mindspore::graphkernel::symshape {
36 constexpr const char *kPrefix = "symbol_engine_jit_";
37 using ast::Shape, ast::BinOpType;
38
KernelMetaPath()39 static string KernelMetaPath() {
40 static string result = "";
41 if (result != "") {
42 return result;
43 }
44 auto config_path = kernel::GetCompilerCachePath();
45 auto kernel_meta_path = config_path + std::string(kernel::kAkgKernelMeta);
46 auto real_path = FileUtils::GetRealPath(kernel_meta_path.c_str());
47 if (!real_path.has_value()) {
48 MS_LOG(EXCEPTION) << "get real path failed: " << kernel_meta_path;
49 }
50 result = real_path.value() + "/";
51 return result;
52 }
53
CppVisitor()54 CppVisitor::CppVisitor() {
55 time_t t = time(nullptr);
56 name_ = "cppvisitor_" + std::to_string(t);
57 }
58
~CppVisitor()59 CppVisitor::~CppVisitor() {
60 if (dynlib_) {
61 dlclose(dynlib_);
62 }
63 }
64
CodeGen(const std::vector<ast::ShapePtr> & shapes,const ast::SymbolTable & symbol_table,const std::string & func_name)65 std::string CppVisitor::CodeGen(const std::vector<ast::ShapePtr> &shapes, const ast::SymbolTable &symbol_table,
66 const std::string &func_name) {
67 symbols_table_ = &symbol_table;
68 static int64_t func_idx = 1;
69 std::stringstream func;
70 std::string final_func_name = "func_" + std::to_string(func_idx) + "_" + func_name;
71 func_idx++;
72 var_tag_ = std::vector<int32_t>(symbols_table_->size(), 0);
73 // function implementation
74 func << "extern \"C\" void " << final_func_name << "(const int64_t **input, int64_t** res){\n";
75 // assemble shape expression
76 std::stringstream res_expr;
77 for (size_t i = 0; i < shapes.size(); ++i) {
78 const auto &shape = shapes[i];
79 for (size_t j = 0; j < shape->smbls_.size(); ++j) {
80 shape->smbls_[j]->Accept(this);
81 res_expr << "res[" << i << "][" << j << "] = " << cpp_sentences_.back() << ";\n";
82 cpp_sentences_.pop_back();
83 }
84 }
85 cpp_sentences_.push_back(res_expr.str());
86 for (auto &sentence : cpp_sentences_) {
87 func << sentence << '\n';
88 }
89 func << "\n}\n";
90 func_blocks_.push_back(func.str());
91 // clear shape_ and cpp_sentence;
92 cpp_sentences_.clear();
93 symbols_table_ = nullptr;
94 var_tag_.clear();
95 null_ = false;
96
97 return final_func_name;
98 }
99
Compile()100 void CppVisitor::Compile() {
101 if (null_) {
102 // skip compile if no function is generated
103 return;
104 }
105 MS_LOG(DEBUG) << "Start to compile cpp file used to infer shape";
106 compile_thread_ = std::thread(&CppVisitor::CompileImpl, this);
107 }
108
CompileImpl()109 void CppVisitor::CompileImpl() {
110 auto kernel_meta_path = KernelMetaPath();
111 (void)FileUtils::CreateNotExistDirs(kernel_meta_path);
112 std::string cpp_file_name(kernel_meta_path + kPrefix + name_ + ".cc");
113 auto real_filename = FileUtils::GetRealPath(cpp_file_name.c_str());
114 if (!real_filename.has_value()) {
115 MS_LOG(EXCEPTION) << "Failed to get real name for " << cpp_file_name;
116 }
117 MS_LOG(DEBUG) << "SymbolEngineJit c++ function saved to: " << real_filename.value();
118 std::ofstream cpp_file(real_filename.value());
119
120 // --- generate .cc file
121 const string header = R"(
122 #include <vector>
123 #include <cstdint>
124
125 )";
126
127 cpp_file << header;
128 for (auto &func : func_blocks_) {
129 cpp_file << func << "\n";
130 }
131 cpp_file.close();
132
133 // compile to dyn lib
134 std::stringstream cmd;
135 std::string so_name(kernel_meta_path + kPrefix + name_ + ".so");
136 cmd << "g++ -fPIC -shared -std=c++17 " << real_filename.value() << " -o " << so_name << " 2>&1";
137
138 // create library
139 constexpr size_t kBufferSize = 256;
140 std::array<char, kBufferSize> buffer{};
141 string result;
142 FILE *pipe = popen(cmd.str().c_str(), "r");
143 if (!pipe) {
144 MS_LOG(EXCEPTION) << "fail to run command to compile c++ code, error:" << strerror(errno);
145 return;
146 }
147 while (fgets(buffer.data(), kBufferSize, pipe)) {
148 result += buffer.data();
149 }
150 void(pclose(pipe));
151 if (!Common::FileExists(so_name)) {
152 MS_LOG(EXCEPTION) << "compile failed: no .so file: " << so_name << "\n Information from pipe: " << result;
153 }
154 MS_LOG(DEBUG) << "Finished compiling, information from pipe: \n" << result;
155 }
156
LoadFunc(const std::string & func_name)157 CppVisitor::DynFuncType CppVisitor::LoadFunc(const std::string &func_name) {
158 MS_LOG(DEBUG) << "CppVisitor trying to load function: " << func_name;
159 if (compile_thread_.joinable()) {
160 compile_thread_.join();
161 }
162 if (!dynlib_) {
163 auto so_name = FileUtils::GetRealPath((KernelMetaPath() + kPrefix + name_ + ".so").c_str());
164 if (!so_name.has_value()) {
165 MS_LOG(EXCEPTION) << "Failed to get real path for " << KernelMetaPath() << kPrefix << name_ << ".so";
166 }
167 dynlib_ = dlopen(so_name.value().c_str(), RTLD_LAZY);
168 if (!dynlib_) {
169 MS_LOG(EXCEPTION) << "Cannot open dynamic library " << so_name.value() << ".so :" << dlerror() << '\n';
170 }
171 }
172
173 auto fn = (DynFuncType)(dlsym(dynlib_, func_name.c_str()));
174 if (!fn) {
175 MS_LOG(EXCEPTION) << "Cannot find function " << func_name << " :" << dlerror() << '\n';
176 }
177 return fn;
178 }
179
Visit(const ast::IntImm & imm)180 void CppVisitor::Visit(const ast::IntImm &imm) { cpp_sentences_.push_back(std::to_string(imm.shape_int)); }
181
Visit(const ast::BinOp & op)182 void CppVisitor::Visit(const ast::BinOp &op) {
183 std::stringstream sentence;
184
185 std::string ope_string = "";
186 bool prefix = true;
187
188 switch (op.optype_) {
189 case BinOpType::ScalarMax:
190 ope_string = "std::max";
191 break;
192 case BinOpType::ScalarMin:
193 ope_string = "std::min";
194 break;
195 case BinOpType::ScalarDiv:
196 ope_string = "/";
197 prefix = false;
198 break;
199 case BinOpType::ScalarAdd:
200 ope_string = "+";
201 prefix = false;
202 break;
203 case BinOpType::ScalarSub:
204 ope_string = "-";
205 prefix = false;
206 break;
207 case BinOpType::ScalarMul:
208 ope_string = "*";
209 prefix = false;
210 break;
211 default:
212 MS_LOG(EXCEPTION) << "Unexpected operation";
213 break;
214 }
215
216 if (prefix) {
217 sentence << ope_string << "(";
218 op.a_->Accept(this);
219 sentence << cpp_sentences_.back() << ", ";
220 cpp_sentences_.pop_back();
221 op.b_->Accept(this);
222 sentence << cpp_sentences_.back() << ")";
223 cpp_sentences_.pop_back();
224
225 cpp_sentences_.push_back(sentence.str());
226 } else {
227 op.a_->Accept(this);
228 sentence << cpp_sentences_.back() << ope_string;
229 cpp_sentences_.pop_back();
230 op.b_->Accept(this);
231 sentence << cpp_sentences_.back();
232 cpp_sentences_.pop_back();
233 cpp_sentences_.push_back(sentence.str());
234 return;
235 }
236 }
237
Visit(const ast::Var & input_smbl)238 void CppVisitor::Visit(const ast::Var &input_smbl) {
239 if (var_tag_[input_smbl.id_]) {
240 cpp_sentences_.push_back(input_smbl.ToString());
241 return;
242 }
243
244 var_tag_[input_smbl.id_] = 1;
245 // assume no recurive call
246 auto smbl_p = (*symbols_table_)[input_smbl.id_];
247 std::stringstream sentence;
248 sentence << "int64_t " << input_smbl.ToString() << " = ";
249 smbl_p->Accept(this);
250 sentence << cpp_sentences_.back() << ";";
251 cpp_sentences_.pop_back();
252 cpp_sentences_.push_back(sentence.str());
253 cpp_sentences_.push_back(input_smbl.ToString());
254 }
255
Visit(const ast::Input & input_smbl)256 void CppVisitor::Visit(const ast::Input &input_smbl) {
257 std::stringstream sentence;
258 sentence << "input[" << input_smbl.i_ << "][" << input_smbl.j_ << "]";
259 cpp_sentences_.push_back(sentence.str());
260 }
261
Visit(const ast::Shape & shape)262 void CppVisitor::Visit(const ast::Shape &shape) {
263 MS_LOG(DEBUG) << "CppVisitor Visit a Shape: " << shape.ToString();
264 std::stringstream sentence;
265 std::string name = UniqueName();
266 sentence << "std::vector<int64_t> " << name << " {";
267
268 for (auto smbl : shape.smbls_) {
269 smbl->Accept(this);
270 sentence << cpp_sentences_.back() << ", ";
271 cpp_sentences_.pop_back();
272 }
273 if (!shape.smbls_.empty()) {
274 // remove the last ", "
275 constexpr int remove_len = 2;
276 sentence.seekp(-remove_len, sentence.cur);
277 }
278
279 sentence << "};";
280
281 cpp_sentences_.push_back(sentence.str());
282 cpp_sentences_.push_back(name);
283 }
284 } // namespace mindspore::graphkernel::symshape
285 #else
286 namespace mindspore::graphkernel::symshape {
287 using ast::Shape, ast::BinOpType;
CppVisitor()288 CppVisitor::CppVisitor() {}
~CppVisitor()289 CppVisitor::~CppVisitor() {}
CodeGen(const std::vector<ast::ShapePtr> &,const ast::SymbolTable &,const std::string &)290 std::string CppVisitor::CodeGen(const std::vector<ast::ShapePtr> &, const ast::SymbolTable &, const std::string &) {
291 return "";
292 }
Compile()293 void CppVisitor::Compile() {}
LoadFunc(const std::string &)294 CppVisitor::DynFuncType CppVisitor::LoadFunc(const std::string &) { return nullptr; }
Visit(const ast::IntImm & imm)295 void CppVisitor::Visit(const ast::IntImm &imm) {}
Visit(const ast::BinOp & op)296 void CppVisitor::Visit(const ast::BinOp &op) {}
Visit(const ast::Var & input_smbl)297 void CppVisitor::Visit(const ast::Var &input_smbl) {}
Visit(const ast::Input & input_smbl)298 void CppVisitor::Visit(const ast::Input &input_smbl) {}
Visit(const ast::Shape & shape)299 void CppVisitor::Visit(const ast::Shape &shape) {}
300 } // namespace mindspore::graphkernel::symshape
301 #endif
302