• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_PI_JIT_FUNC_GRAPH_BUILDER_H_
18 #define MINDSPORE_PI_JIT_FUNC_GRAPH_BUILDER_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include "ir/func_graph.h"
25 #include "ir/value.h"
26 #include "pipeline/jit/pi/graph_compiler/pi_ir/ir_mutator.h"
27 #include "utils/trace_info.h"
28 
29 namespace mindspore::pijit {
30 namespace py = pybind11;
31 
32 class MindNode : public ir::Node {
33  public:
MindNode(const AnfNodePtr & node)34   explicit MindNode(const AnfNodePtr &node) : node_(node) {}
35 
36   // Destructor.
37   ~MindNode() override = default;
38   JIT_DECLARE_PARENT(MindNode, Node);
39 
GetAnfNode()40   const AnfNodePtr &GetAnfNode() const { return node_; }
41 
42   /**
43    * \brief Get the description of this Mind Node.
44    * \return The description.
45    */
ToString()46   std::string ToString() const override { return node_->DebugString(); }
47 
48  private:
49   AnfNodePtr node_;
50 };
51 
52 using MindNodePtr = std::shared_ptr<MindNode>;
53 
54 // FuncGraphBuilder to convert ir graph to function graph
55 class FuncGraphBuilder : public ir::IRMutator {
56  public:
FuncGraphBuilder(const ir::FunctionNodePtr & func)57   explicit FuncGraphBuilder(const ir::FunctionNodePtr &func) : FuncGraphBuilder(func, {}, NewValueNode(kNone)) {}
FuncGraphBuilder(const ir::FunctionNodePtr & func,const AnfNodePtrList & args,const AnfNodePtr & kwargs)58   FuncGraphBuilder(const ir::FunctionNodePtr &func, const AnfNodePtrList &args, const AnfNodePtr &kwargs)
59       : func_(func),
60         args_(args),
61         kwargs_(kwargs),
62         func_graph_(std::make_shared<FuncGraph>()),
63         last_line_no_(func->GetFirstLineNo()) {}
64   virtual ~FuncGraphBuilder() = default;
65   static FuncGraphPtr BuildFuncGraph(const ir::FunctionNodePtr &func, const py::tuple &args, const py::dict &kwargs);
66   static FuncGraphPtr BuildFuncGraph(const ir::FunctionNodePtr &func, const AnfNodePtrList &args,
67                                      const AnfNodePtr &kwargs);
68 
69   // overloadable Mutate function.
70   ir::NodePtr Mutate_(const ir::RefNodePtr &node) override;
71   ir::NodePtr Mutate_(const ir::ParameterPtr &node) override;
72   ir::NodePtr Mutate_(const ir::FunctionNodePtr &node) override;
73   ir::NodePtr Mutate_(const ir::ValuePtr &node) override;
74   ir::NodePtr Mutate_(const ir::IfNodePtr &node) override;
75   ir::NodePtr Mutate_(const ir::BinaryOperationPtr &node) override;
76   ir::NodePtr Mutate_(const ir::NegativeNodePtr &node) override;
77   ir::NodePtr Mutate_(const ir::NotNodePtr &node) override;
78   ir::NodePtr Mutate_(const ir::InvertNodePtr &node) override;
79   ir::NodePtr Mutate_(const ir::ReturnNodePtr &node) override;
80   ir::NodePtr Mutate_(const ir::CastNodePtr &node) override;
81   ir::NodePtr Mutate_(const ir::FormatNodePtr &node) override;
82   ir::NodePtr Mutate_(const ir::AddNodePtr &node) override;
83   ir::NodePtr Mutate_(const ir::SubNodePtr &node) override;
84   ir::NodePtr Mutate_(const ir::MulNodePtr &node) override;
85   ir::NodePtr Mutate_(const ir::DivNodePtr &node) override;
86   ir::NodePtr Mutate_(const ir::BitwiseNodePtr &node) override;
87   ir::NodePtr Mutate_(const ir::IsNodePtr &node) override;
88   ir::NodePtr Mutate_(const ir::ContainsNodePtr &node) override;
89   ir::NodePtr Mutate_(const ir::StoreNodePtr &node) override;
90   ir::NodePtr Mutate_(const ir::CompareNodePtr &node) override;
91   ir::NodePtr Mutate_(const ir::LoadValueNodePtr &node) override;
92   ir::NodePtr Mutate_(const ir::LoadFieldNodePtr &node) override;
93   ir::NodePtr Mutate_(const ir::BuildNodePtr &node) override;
94   ir::NodePtr Mutate_(const ir::CallNodePtr &node) override;
95   ir::NodePtr Mutate_(const ir::UpdateNodePtr &node) override;
96   ir::NodePtr Mutate_(const ir::SubscrNodePtr &node) override;
97   ir::NodePtr Mutate_(const ir::AttrNodePtr &node) override;
98 
99  private:
100   void UpdateLocation(const AnfNodePtr &anf_node, const ir::NodePtr &node);
101   AnfNodePtr ConvertListOrTupleToCNode(const py::object &obj);
102   AnfNodePtr GetAnfNode(const ir::NodePtr &node);
103   AnfNodePtr MergeList(const AnfNodePtr &left, const AnfNodePtr &right);
104   std::pair<AnfNodePtrList, AnfNodePtrList> GetKeysAndValueOfDict(const AnfNodePtr &node);
105   AnfNodePtr MergeDict(const AnfNodePtr &left, const AnfNodePtr &right);
106 
107   const ir::FunctionNodePtr func_;
108   AnfNodePtrList args_;
109   AnfNodePtr kwargs_;
110   FuncGraphPtr func_graph_;
111   int last_line_no_;
112   bool enable_debug_info_{false};
113 
114   // Store variable's name, variable's node.
115   std::map<std::string, AnfNodePtr> assigned_vars_;
116 };
117 
118 using FuncGraphBuilderPtr = std::shared_ptr<FuncGraphBuilder>;
119 }  // namespace mindspore::pijit
120 
121 #endif  // MINDSPORE_PI_JIT_FUNC_GRAPH_BUILDER_H_
122