• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_PI_JIT_FUNC_INLINER_H_
18 #define MINDSPORE_PI_JIT_FUNC_INLINER_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include "pipeline/jit/pi/graph_compiler/pi_ir/ir_mutator.h"
24 #include "pipeline/jit/pi/graph_compiler/pi_ir/ir_visitor.h"
25 #include "utils/convert_utils_base.h"
26 
27 namespace mindspore {
28 namespace pijit {
29 class FuncInlineDetector : public ir::IRVisitor {
30  public:
FuncInlineDetector(const ir::FunctionNodePtr & func)31   explicit FuncInlineDetector(const ir::FunctionNodePtr &func) : func_(func), index_(0), cur_root_node_(nullptr) {}
32   virtual ~FuncInlineDetector() = default;
33   void Run();
34 
35   void Visit_(const ir::FunctionNodePtr &node) override;
36   void Visit_(const ir::CallNodePtr &node) override;
37   size_t GetRootNodeIndex(const ir::CallNodePtr &node) const;
38   const ir::NodePtr &GetRootNode(const ir::CallNodePtr &node) const;
39 
40  private:
41   bool CanBeInlined(const ir::NodePtr &node) const;
42   void EvolvingFunction(const ir::FunctionNodePtr &func_node, const ir::NodePtrList args) const;
43 
44   const ir::FunctionNodePtr func_;
45   size_t index_;
46   ir::NodePtr cur_root_node_;
47   std::map<ir::CallNodePtr, size_t> node_2_index_;
48   std::map<ir::CallNodePtr, ir::NodePtr> node_2_root_;
49 };
50 
51 using FuncInlineDetectorPtr = std::shared_ptr<FuncInlineDetector>;
52 
53 class FuncLocalVarRenamer : public ir::IRVisitor {
54  public:
FuncLocalVarRenamer(const ir::FunctionNodePtr & func)55   explicit FuncLocalVarRenamer(const ir::FunctionNodePtr &func) : func_(func) {}
56   virtual ~FuncLocalVarRenamer() = default;
57   void Run();
58 
59   void Visit_(const ir::ParameterPtr &node) override;
60   void Visit_(const ir::ValuePtr &node) override;
61 
62  private:
63   const ir::FunctionNodePtr func_;
64 };
65 
66 using FuncLocalVarRenamerPtr = std::shared_ptr<FuncLocalVarRenamer>;
67 
68 class FuncParameterEliminator : public ir::IRMutator {
69  public:
FuncParameterEliminator(const ir::FunctionNodePtr & func,const ir::NodePtrList & args)70   explicit FuncParameterEliminator(const ir::FunctionNodePtr &func, const ir::NodePtrList &args)
71       : func_(func), args_(args) {}
72   virtual ~FuncParameterEliminator() = default;
73   void Run();
74 
75   ir::NodePtr Mutate_(const ir::ParameterPtr &node) override;
76   ir::NodePtr Mutate_(const ir::LoadValueNodePtr &node) override;
77   ir::NodePtr Mutate_(const ir::StoreNodePtr &node) override;
78 
79  private:
80   const ir::FunctionNodePtr func_;
81   const ir::NodePtrList args_;
82   std::map<std::string, ir::NodePtr> name_2_node_;
83 };
84 
85 using FuncParameterEliminatorPtr = std::shared_ptr<FuncParameterEliminator>;
86 
87 // FuncInliner to convert ir graph to function graph
88 class FuncInliner : public ir::IRMutator {
89  public:
FuncInliner(const ir::FunctionNodePtr & func)90   explicit FuncInliner(const ir::FunctionNodePtr &func)
91       : func_(func), detector_(std::make_shared<FuncInlineDetector>(func)), inserted_nodes_cnt_(0) {}
92   virtual ~FuncInliner() = default;
93   void Run();
94   void InsertSubFunction();
95 
96   // overloadable Mutate function.
97   ir::NodePtr Mutate_(const ir::CallNodePtr &node) override;
98 
99  private:
100   const ir::FunctionNodePtr func_;
101   const FuncInlineDetectorPtr detector_;
102   size_t inserted_nodes_cnt_;
103   std::map<size_t, ir::FunctionNodePtr> index_2_function_;
104 };
105 
106 using FuncInlinerPtr = std::shared_ptr<FuncInliner>;
107 }  // namespace pijit
108 }  // namespace mindspore
109 
110 #endif  // MINDSPORE_PI_JIT_FUNC_INLINER_H_
111