• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_SYMBOL_ENGINE_JIT_CPP_VISITOR_H_
17 #define MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_SYMBOL_ENGINE_JIT_CPP_VISITOR_H_
18 
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "mindspore/core/symbolic_shape/symbol.h"
#include "mindspore/core/symbolic_shape/operation.h"
#include "mindspore/core/symbolic_shape/symbol_visitor.h"
#include "backend/common/graph_kernel/symbol_engine/jit/syntax.h"
30 
31 namespace mindspore::graphkernel::symshape {
32 class CppVisitor : public ast::Visitor {
33  public:
34   using DynFuncType = void (*)(const int64_t **, int64_t **);
35   CppVisitor();
CppVisitor(const std::string & name)36   explicit CppVisitor(const std::string &name) : name_(name) {}
37   ~CppVisitor();
38 
39   /// \brief Generate c++ function corresponding to the ast
40   /// \note func_name should be valid c++ function name
41   /// \return name of the function
42   std::string CodeGen(const std::vector<ast::ShapePtr> &shapes, const ast::SymbolTable &symbol_table,
43                       const std::string &func_name = "");
44   void Compile();
45   DynFuncType LoadFunc(const std::string &func_name);
46 
47   //------ override  ast::Visitor -----------------
48 
49   void Visit(const ast::IntImm &intimm) override;
50   void Visit(const ast::Var &intimm) override;
51   void Visit(const ast::BinOp &op) override;
52   void Visit(const ast::Shape &shape) override;
53   void Visit(const ast::Input &input_smbl) override;
54   // ------------------------------------------------
55 
UniqueName()56   std::string UniqueName() {
57     static size_t idx = 1;
58     return "s_" + std::to_string(idx++);
59   }
60 
61  protected:
62   // Do the actual compile work
63   void CompileImpl();
64 
65  public:
66   // for codegen
67   const ast::SymbolTable *symbols_table_ = nullptr;  // a map: id -> symbol
68   std::vector<std::string> cpp_sentences_;
69   std::vector<int32_t> var_tag_;  // indicate if Var already generated code
70   std::string name_;
71   std::vector<std::string> func_blocks_;  // store generated functions
72   std::thread compile_thread_;
73   bool null_ = true;  // indicate whether no code is generated
74 
75   // for dynamic library
76   void *dynlib_ = nullptr;
77 };
78 using CppVisitorPtr = std::shared_ptr<CppVisitor>;
79 }  // namespace mindspore::graphkernel::symshape
80 #endif  // MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_SYMBOL_ENGINE_JIT_CPP_VISITOR_H_
81