• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_NODE_H
17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_NODE_H
18 
19 #include <string>
20 #include <map>
21 #include <vector>
22 #include <algorithm>
23 #include <utility>
24 #include <optional>
25 #include <memory>
26 #include "utils/log_adapter.h"
27 #include "pipeline/jit/pi/graph_capture/abstract_object.h"
28 #include "pipeline/jit/pi/graph_capture/constant_info.h"
29 #include "pipeline/jit/pi/utils/utils.h"
30 #include "pipeline/jit/pi/graph_guard/trace.h"
31 
32 namespace mindspore {
33 namespace pijit {
34 class Graph;
35 class Block;
36 
37 class AbstractNode {
38  public:
39   enum Type {
40     Abstract,
41     kInstr,
42     Value,
43     Call,      // call node, it is also a value produced operation
44     Param,     // parameter value node
45     CellVar,   // cell value node
46     FreeVar,   // free value node
47     kUnbound,  // unbound value node
48   };
AbstractNode(Type t)49   explicit AbstractNode(Type t) : type_(t), graph_(nullptr), block_(nullptr), marker_(0) {}
~AbstractNode()50   virtual ~AbstractNode() {}
51 
GetType()52   Type GetType() const { return type_; }
GetGraph()53   Graph *GetGraph() const { return graph_; }
SetGraph(Graph * g)54   void SetGraph(Graph *g) { graph_ = g; }
GetBlock()55   Block *GetBlock() { return block_; }
SetBlock(Block * b)56   void SetBlock(Block *b) { block_ = b; }
57 
58   virtual std::string ToString() const;
59 
60  private:
61   const Type type_;
62   Graph *graph_;
63   Block *block_;
64 
65  public:
66   // remove it
67   int marker_;  // for visit
68 };
69 
70 class InstrNode : public AbstractNode {
71  public:
InstrNode(int op,int arg)72   InstrNode(int op, int arg) : AbstractNode(kInstr), op_(op), arg_(arg) {}
~InstrNode()73   virtual ~InstrNode() {}
GetOpcode()74   int GetOpcode() const { return op_; }
GetOparg()75   int GetOparg() const { return arg_; }
GetLineNo()76   int GetLineNo() const { return line_; }
SetOparg(int arg)77   void SetOparg(int arg) { this->arg_ = arg; }
SetOpcode(int op)78   void SetOpcode(int op) { this->op_ = op; }
SetLineNo(int l)79   void SetLineNo(int l) { this->line_ = l; }
SetName(const std::string & n)80   void SetName(const std::string &n) { name_ = n; }
GetName()81   const std::string &GetName() const { return name_; }
82   std::string ToString() const override;
83 
bci()84   int bci() const { return bci_; }
set_bci(int i)85   void set_bci(int i) { bci_ = i; }
86 
87  protected:
InstrNode(Type t,int op,int arg)88   InstrNode(Type t, int op, int arg) : AbstractNode(t), op_(op), arg_(arg), line_(-1) {}
89 
90  private:
91   int bci_ = -1;
92   int op_;
93   int arg_;
94   int line_ = -1;
95   std::string name_;
96 };
97 
98 class ValueNode : public InstrNode {
99  public:
100   static ValueNode kUnboundLocal;
101 
102   ValueNode(AObject *vobj, int opcode, int oparg, const std::vector<ValueNode *> &inputs = {})
InstrNode(Value,opcode,oparg)103       : InstrNode(Value, opcode, oparg), vobj_(vobj), inputs_(inputs) {}
~ValueNode()104   virtual ~ValueNode() {}
105 
getInputs()106   std::vector<ValueNode *> &getInputs() { return inputs_; }
getInputs()107   const std::vector<ValueNode *> &getInputs() const { return inputs_; }
input(int i)108   ValueNode *input(int i) const { return inputs_[i]; }
AddInput(ValueNode * v)109   void AddInput(ValueNode *v) { inputs_.push_back(v); }
ClearInputs()110   void ClearInputs() { inputs_.clear(); }
111 
112   void SetVobj(AObject *vobj);
GetVobj()113   const auto &GetVobj() const { return vobj_; }
114 
115   AObject *get_attr(const std::string &nam);
116 
117   AObject *binary_subscr(ValueNode *sub);
118 
119   std::string ToString() const override;
GetParent()120   ValueNode *GetParent() { return parent_.value_or(nullptr); }
121   void SetParent(ValueNode *parent);
122 
123   bool IsConstantValue() const;
124   void SetConstantValue(bool constant);
125   const std::unique_ptr<ConstantInfo> &MakeConstantInfo();
GetConstantInfo()126   const std::unique_ptr<ConstantInfo> &GetConstantInfo() const { return constant_info_; }
127 
GetTrace()128   TracePtr GetTrace() { return trace_; }
SetTrace(TracePtr t)129   void SetTrace(TracePtr t) { trace_ = t; }
130 
131  protected:
132   ValueNode(Type type, AObject *vobj, int opcode, int oparg, const std::vector<ValueNode *> &inputs = {})
InstrNode(type,opcode,oparg)133       : InstrNode(type, opcode, oparg), vobj_(vobj), inputs_(inputs) {}
134 
135  private:
136   // value info
137   AObject *vobj_;
138 
139   // constant info
140   std::unique_ptr<ConstantInfo> constant_info_;
141 
142   // which nodes are used, ordered parameter
143   std::vector<ValueNode *> inputs_;
144 
145   // recode relationship between local and CallNode
146   std::optional<ValueNode *> parent_;
147 
148   // Trace cache to be reused
149   TracePtr trace_;
150 };
151 
152 // simulate PyCellObject, oparg is index
153 class CellVarNode : public ValueNode {
154  public:
CellVarNode(Type t)155   explicit CellVarNode(Type t) : ValueNode(t, nullptr, -1, CO_CELL_NOT_AN_ARG), val_(nullptr) {}
SetFromParam(int i)156   void SetFromParam(int i) { SetOparg(i); }
GetFromParam()157   int GetFromParam() const { return GetOparg(); }
SetIndex(int i)158   void SetIndex(int i) { return SetOpcode(i); }
GetIndex()159   int GetIndex() const { return GetOpcode(); }
GetValue()160   auto GetValue() const { return val_; }
SetValue(ValueNode * v)161   void SetValue(ValueNode *v) { val_ = v; }
GetCellOper()162   const auto &GetCellOper() const { return cell_oper_; }
GetCellOper()163   auto &GetCellOper() { return cell_oper_; }
AddCellOper(ValueNode * i)164   void AddCellOper(ValueNode *i) { cell_oper_.push_back(i); }
~CellVarNode()165   virtual ~CellVarNode() {}
166   std::string ToString() const override;
167 
168  private:
169   ValueNode *val_;
170   std::vector<ValueNode *> cell_oper_;  // record cell operation
171 };
172 
173 class ParamNode : public ValueNode {
174  public:
ParamNode(AObject * o,int index)175   ParamNode(AObject *o, int index) : ValueNode(Param, o, 0, index, {}) {}
176   std::string ToString() const override;
IsMixedPrecisionType()177   bool IsMixedPrecisionType() { return mixedPrecisionType_ != nullptr; }
GetMixedPrecisionType()178   PyObject *GetMixedPrecisionType() { return mixedPrecisionType_; }
SetMixedPrecisionType(PyObject * type)179   void SetMixedPrecisionType(PyObject *type) { mixedPrecisionType_ = type; }
~ParamNode()180   virtual ~ParamNode() {}
181 
182  protected:
183   PyObject *mixedPrecisionType_{nullptr};
184 };
185 
186 class CallNode : public ValueNode {
187  public:
CallNode(int opcode,int oparg,const std::vector<ValueNode * > & inputs)188   CallNode(int opcode, int oparg, const std::vector<ValueNode *> &inputs)
189       : ValueNode(Call, nullptr, opcode, oparg, inputs), sub_graph_(nullptr) {}
~CallNode()190   virtual ~CallNode() {}
191 
GetSubGraph()192   Graph *GetSubGraph() const { return sub_graph_; }
193   void SetSubGraph(Graph *n);
194   std::string ToString() const override;
SetInlineReason(InlineReason r)195   void SetInlineReason(InlineReason r) { reason_ = r; }
GetInlineReason()196   InlineReason GetInlineReason() { return reason_; }
197 
AddParam(ValueNode * p)198   void AddParam(ValueNode *p) {
199     params_.push_back(p);
200     if (p) {
201       p->SetParent(this);
202     }
203   }
204 
GetParams()205   const auto &GetParams() const { return params_; }
GetArgs()206   std::vector<py::object> GetArgs() {
207     std::vector<py::object> args;
208     std::transform(getInputs().begin() + 1, getInputs().end(), std::back_inserter(args),
209                    [](ValueNode *n) { return n->GetVobj() ? n->GetVobj()->GetPyObject() : py::object(); });
210     return args;
211   }
212 
213  private:
214   // sub-graph if traced function
215   Graph *sub_graph_;
216 
217   InlineReason reason_ = InlineReason::kInlineUnknown;
218 
219   std::vector<ValueNode *> params_;  // extra values for inline function
220 };
221 
222 bool IsNonLocalValue(ValueNode *i);
223 }  // namespace pijit
224 }  // namespace mindspore
225 
226 #endif  // MINDSPORE_PI_JIT_GRAPH_CAPTURE_NODE_H
227