1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_SYMBOL_ENGINE_JIT_CPP_VISITOR_H_ 17 #define MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_SYMBOL_ENGINE_JIT_CPP_VISITOR_H_ 18 19 #include <memory> 20 #include <sstream> 21 #include <string> 22 #include <thread> 23 #include <unordered_map> 24 #include <vector> 25 26 #include "mindspore/core/symbolic_shape/symbol.h" 27 #include "mindspore/core/symbolic_shape/operation.h" 28 #include "mindspore/core/symbolic_shape/symbol_visitor.h" 29 #include "backend/common/graph_kernel/symbol_engine/jit/syntax.h" 30 31 namespace mindspore::graphkernel::symshape { 32 class CppVisitor : public ast::Visitor { 33 public: 34 using DynFuncType = void (*)(const int64_t **, int64_t **); 35 CppVisitor(); CppVisitor(const std::string & name)36 explicit CppVisitor(const std::string &name) : name_(name) {} 37 ~CppVisitor(); 38 39 /// \brief Generate c++ function corresponding to the ast 40 /// \note func_name should be valid c++ function name 41 /// \return name of the function 42 std::string CodeGen(const std::vector<ast::ShapePtr> &shapes, const ast::SymbolTable &symbol_table, 43 const std::string &func_name = ""); 44 void Compile(); 45 DynFuncType LoadFunc(const std::string &func_name); 46 47 //------ override ast::Visitor ----------------- 48 49 void Visit(const ast::IntImm &intimm) override; 50 void Visit(const ast::Var &intimm) override; 51 void Visit(const ast::BinOp &op) override; 52 void Visit(const ast::Shape &shape) override; 53 void Visit(const ast::Input &input_smbl) override; 54 // ------------------------------------------------ 55 UniqueName()56 std::string UniqueName() { 57 static size_t idx = 1; 58 return "s_" + std::to_string(idx++); 59 } 60 61 protected: 62 // Do the actual compile work 63 void CompileImpl(); 64 65 public: 66 // for codegen 67 const ast::SymbolTable *symbols_table_ = nullptr; // a map: id -> symbol 68 std::vector<std::string> cpp_sentences_; 69 std::vector<int32_t> var_tag_; // indicate if Var already generated code 70 std::string name_; 71 std::vector<std::string> func_blocks_; // store generated functions 72 std::thread compile_thread_; 73 bool null_ = true; // indicate whether no code is generated 74 75 // for dynamic library 76 void *dynlib_ = nullptr; 77 }; 78 using CppVisitorPtr = std::shared_ptr<CppVisitor>; 79 } // namespace mindspore::graphkernel::symshape 80 #endif // MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_SYMBOL_ENGINE_JIT_CPP_VISITOR_H_ 81