• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/symbol_engine/jit/cpp_visitor.h"
17 
18 #if !(defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER))
19 #include <dlfcn.h>
20 
#include <array>
#include <atomic>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
28 
29 #include "backend/common/graph_kernel/graph_kernel_flags.h"
30 #include "mindspore/core/symbolic_shape/symbol.h"
31 #include "kernel/framework_utils.h"
32 #include "include/common/debug/common.h"
33 #include "utils/file_utils.h"
34 
35 namespace mindspore::graphkernel::symshape {
36 constexpr const char *kPrefix = "symbol_engine_jit_";
37 using ast::Shape, ast::BinOpType;
38 
KernelMetaPath()39 static string KernelMetaPath() {
40   static string result = "";
41   if (result != "") {
42     return result;
43   }
44   auto config_path = kernel::GetCompilerCachePath();
45   auto kernel_meta_path = config_path + std::string(kernel::kAkgKernelMeta);
46   auto real_path = FileUtils::GetRealPath(kernel_meta_path.c_str());
47   if (!real_path.has_value()) {
48     MS_LOG(EXCEPTION) << "get real path failed: " << kernel_meta_path;
49   }
50   result = real_path.value() + "/";
51   return result;
52 }
53 
CppVisitor()54 CppVisitor::CppVisitor() {
55   time_t t = time(nullptr);
56   name_ = "cppvisitor_" + std::to_string(t);
57 }
58 
~CppVisitor()59 CppVisitor::~CppVisitor() {
60   if (dynlib_) {
61     dlclose(dynlib_);
62   }
63 }
64 
CodeGen(const std::vector<ast::ShapePtr> & shapes,const ast::SymbolTable & symbol_table,const std::string & func_name)65 std::string CppVisitor::CodeGen(const std::vector<ast::ShapePtr> &shapes, const ast::SymbolTable &symbol_table,
66                                 const std::string &func_name) {
67   symbols_table_ = &symbol_table;
68   static int64_t func_idx = 1;
69   std::stringstream func;
70   std::string final_func_name = "func_" + std::to_string(func_idx) + "_" + func_name;
71   func_idx++;
72   var_tag_ = std::vector<int32_t>(symbols_table_->size(), 0);
73   // function implementation
74   func << "extern \"C\" void " << final_func_name << "(const int64_t **input, int64_t** res){\n";
75   //  assemble shape expression
76   std::stringstream res_expr;
77   for (size_t i = 0; i < shapes.size(); ++i) {
78     const auto &shape = shapes[i];
79     for (size_t j = 0; j < shape->smbls_.size(); ++j) {
80       shape->smbls_[j]->Accept(this);
81       res_expr << "res[" << i << "][" << j << "] = " << cpp_sentences_.back() << ";\n";
82       cpp_sentences_.pop_back();
83     }
84   }
85   cpp_sentences_.push_back(res_expr.str());
86   for (auto &sentence : cpp_sentences_) {
87     func << sentence << '\n';
88   }
89   func << "\n}\n";
90   func_blocks_.push_back(func.str());
91   // clear shape_ and cpp_sentence;
92   cpp_sentences_.clear();
93   symbols_table_ = nullptr;
94   var_tag_.clear();
95   null_ = false;
96 
97   return final_func_name;
98 }
99 
Compile()100 void CppVisitor::Compile() {
101   if (null_) {
102     // skip compile if no function is generated
103     return;
104   }
105   MS_LOG(DEBUG) << "Start to compile cpp file used to infer shape";
106   compile_thread_ = std::thread(&CppVisitor::CompileImpl, this);
107 }
108 
CompileImpl()109 void CppVisitor::CompileImpl() {
110   auto kernel_meta_path = KernelMetaPath();
111   (void)FileUtils::CreateNotExistDirs(kernel_meta_path);
112   std::string cpp_file_name(kernel_meta_path + kPrefix + name_ + ".cc");
113   auto real_filename = FileUtils::GetRealPath(cpp_file_name.c_str());
114   if (!real_filename.has_value()) {
115     MS_LOG(EXCEPTION) << "Failed to get real name for " << cpp_file_name;
116   }
117   MS_LOG(DEBUG) << "SymbolEngineJit c++ function saved to: " << real_filename.value();
118   std::ofstream cpp_file(real_filename.value());
119 
120   // --- generate  .cc file
121   const string header = R"(
122 #include <vector>
123 #include <cstdint>
124 
125 )";
126 
127   cpp_file << header;
128   for (auto &func : func_blocks_) {
129     cpp_file << func << "\n";
130   }
131   cpp_file.close();
132 
133   // compile to dyn lib
134   std::stringstream cmd;
135   std::string so_name(kernel_meta_path + kPrefix + name_ + ".so");
136   cmd << "g++ -fPIC -shared -std=c++17 " << real_filename.value() << " -o " << so_name << " 2>&1";
137 
138   // create library
139   constexpr size_t kBufferSize = 256;
140   std::array<char, kBufferSize> buffer{};
141   string result;
142   FILE *pipe = popen(cmd.str().c_str(), "r");
143   if (!pipe) {
144     MS_LOG(EXCEPTION) << "fail to run command to compile c++ code, error:" << strerror(errno);
145     return;
146   }
147   while (fgets(buffer.data(), kBufferSize, pipe)) {
148     result += buffer.data();
149   }
150   void(pclose(pipe));
151   if (!Common::FileExists(so_name)) {
152     MS_LOG(EXCEPTION) << "compile failed: no .so file: " << so_name << "\n Information from pipe: " << result;
153   }
154   MS_LOG(DEBUG) << "Finished compiling, information from pipe: \n" << result;
155 }
156 
LoadFunc(const std::string & func_name)157 CppVisitor::DynFuncType CppVisitor::LoadFunc(const std::string &func_name) {
158   MS_LOG(DEBUG) << "CppVisitor trying to load function: " << func_name;
159   if (compile_thread_.joinable()) {
160     compile_thread_.join();
161   }
162   if (!dynlib_) {
163     auto so_name = FileUtils::GetRealPath((KernelMetaPath() + kPrefix + name_ + ".so").c_str());
164     if (!so_name.has_value()) {
165       MS_LOG(EXCEPTION) << "Failed to get real path for " << KernelMetaPath() << kPrefix << name_ << ".so";
166     }
167     dynlib_ = dlopen(so_name.value().c_str(), RTLD_LAZY);
168     if (!dynlib_) {
169       MS_LOG(EXCEPTION) << "Cannot open dynamic library " << so_name.value() << ".so :" << dlerror() << '\n';
170     }
171   }
172 
173   auto fn = (DynFuncType)(dlsym(dynlib_, func_name.c_str()));
174   if (!fn) {
175     MS_LOG(EXCEPTION) << "Cannot find function " << func_name << " :" << dlerror() << '\n';
176   }
177   return fn;
178 }
179 
Visit(const ast::IntImm & imm)180 void CppVisitor::Visit(const ast::IntImm &imm) { cpp_sentences_.push_back(std::to_string(imm.shape_int)); }
181 
Visit(const ast::BinOp & op)182 void CppVisitor::Visit(const ast::BinOp &op) {
183   std::stringstream sentence;
184 
185   std::string ope_string = "";
186   bool prefix = true;
187 
188   switch (op.optype_) {
189     case BinOpType::ScalarMax:
190       ope_string = "std::max";
191       break;
192     case BinOpType::ScalarMin:
193       ope_string = "std::min";
194       break;
195     case BinOpType::ScalarDiv:
196       ope_string = "/";
197       prefix = false;
198       break;
199     case BinOpType::ScalarAdd:
200       ope_string = "+";
201       prefix = false;
202       break;
203     case BinOpType::ScalarSub:
204       ope_string = "-";
205       prefix = false;
206       break;
207     case BinOpType::ScalarMul:
208       ope_string = "*";
209       prefix = false;
210       break;
211     default:
212       MS_LOG(EXCEPTION) << "Unexpected operation";
213       break;
214   }
215 
216   if (prefix) {
217     sentence << ope_string << "(";
218     op.a_->Accept(this);
219     sentence << cpp_sentences_.back() << ", ";
220     cpp_sentences_.pop_back();
221     op.b_->Accept(this);
222     sentence << cpp_sentences_.back() << ")";
223     cpp_sentences_.pop_back();
224 
225     cpp_sentences_.push_back(sentence.str());
226   } else {
227     op.a_->Accept(this);
228     sentence << cpp_sentences_.back() << ope_string;
229     cpp_sentences_.pop_back();
230     op.b_->Accept(this);
231     sentence << cpp_sentences_.back();
232     cpp_sentences_.pop_back();
233     cpp_sentences_.push_back(sentence.str());
234     return;
235   }
236 }
237 
Visit(const ast::Var & input_smbl)238 void CppVisitor::Visit(const ast::Var &input_smbl) {
239   if (var_tag_[input_smbl.id_]) {
240     cpp_sentences_.push_back(input_smbl.ToString());
241     return;
242   }
243 
244   var_tag_[input_smbl.id_] = 1;
245   // assume no recurive call
246   auto smbl_p = (*symbols_table_)[input_smbl.id_];
247   std::stringstream sentence;
248   sentence << "int64_t " << input_smbl.ToString() << " = ";
249   smbl_p->Accept(this);
250   sentence << cpp_sentences_.back() << ";";
251   cpp_sentences_.pop_back();
252   cpp_sentences_.push_back(sentence.str());
253   cpp_sentences_.push_back(input_smbl.ToString());
254 }
255 
Visit(const ast::Input & input_smbl)256 void CppVisitor::Visit(const ast::Input &input_smbl) {
257   std::stringstream sentence;
258   sentence << "input[" << input_smbl.i_ << "][" << input_smbl.j_ << "]";
259   cpp_sentences_.push_back(sentence.str());
260 }
261 
Visit(const ast::Shape & shape)262 void CppVisitor::Visit(const ast::Shape &shape) {
263   MS_LOG(DEBUG) << "CppVisitor Visit a Shape: " << shape.ToString();
264   std::stringstream sentence;
265   std::string name = UniqueName();
266   sentence << "std::vector<int64_t> " << name << " {";
267 
268   for (auto smbl : shape.smbls_) {
269     smbl->Accept(this);
270     sentence << cpp_sentences_.back() << ", ";
271     cpp_sentences_.pop_back();
272   }
273   if (!shape.smbls_.empty()) {
274     // remove the last ", "
275     constexpr int remove_len = 2;
276     sentence.seekp(-remove_len, sentence.cur);
277   }
278 
279   sentence << "};";
280 
281   cpp_sentences_.push_back(sentence.str());
282   cpp_sentences_.push_back(name);
283 }
284 }  // namespace mindspore::graphkernel::symshape
285 #else
286 namespace mindspore::graphkernel::symshape {
287 using ast::Shape, ast::BinOpType;
CppVisitor()288 CppVisitor::CppVisitor() {}
~CppVisitor()289 CppVisitor::~CppVisitor() {}
CodeGen(const std::vector<ast::ShapePtr> &,const ast::SymbolTable &,const std::string &)290 std::string CppVisitor::CodeGen(const std::vector<ast::ShapePtr> &, const ast::SymbolTable &, const std::string &) {
291   return "";
292 }
Compile()293 void CppVisitor::Compile() {}
LoadFunc(const std::string &)294 CppVisitor::DynFuncType CppVisitor::LoadFunc(const std::string &) { return nullptr; }
Visit(const ast::IntImm & imm)295 void CppVisitor::Visit(const ast::IntImm &imm) {}
Visit(const ast::BinOp & op)296 void CppVisitor::Visit(const ast::BinOp &op) {}
Visit(const ast::Var & input_smbl)297 void CppVisitor::Visit(const ast::Var &input_smbl) {}
Visit(const ast::Input & input_smbl)298 void CppVisitor::Visit(const ast::Input &input_smbl) {}
Visit(const ast::Shape & shape)299 void CppVisitor::Visit(const ast::Shape &shape) {}
300 }  // namespace mindspore::graphkernel::symshape
301 #endif
302