1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_PI_JIT_GRAPH_CAPTURE_NODE_H 17 #define MINDSPORE_PI_JIT_GRAPH_CAPTURE_NODE_H 18 19 #include <string> 20 #include <map> 21 #include <vector> 22 #include <algorithm> 23 #include <utility> 24 #include <optional> 25 #include <memory> 26 #include "utils/log_adapter.h" 27 #include "pipeline/jit/pi/graph_capture/abstract_object.h" 28 #include "pipeline/jit/pi/graph_capture/constant_info.h" 29 #include "pipeline/jit/pi/utils/utils.h" 30 #include "pipeline/jit/pi/graph_guard/trace.h" 31 32 namespace mindspore { 33 namespace pijit { 34 class Graph; 35 class Block; 36 37 class AbstractNode { 38 public: 39 enum Type { 40 Abstract, 41 kInstr, 42 Value, 43 Call, // call node, it is also a value produced operation 44 Param, // parameter value node 45 CellVar, // cell value node 46 FreeVar, // free value node 47 kUnbound, // unbound value node 48 }; AbstractNode(Type t)49 explicit AbstractNode(Type t) : type_(t), graph_(nullptr), block_(nullptr), marker_(0) {} ~AbstractNode()50 virtual ~AbstractNode() {} 51 GetType()52 Type GetType() const { return type_; } GetGraph()53 Graph *GetGraph() const { return graph_; } SetGraph(Graph * g)54 void SetGraph(Graph *g) { graph_ = g; } GetBlock()55 Block *GetBlock() { return block_; } SetBlock(Block * b)56 void SetBlock(Block *b) { block_ = b; } 57 58 virtual std::string ToString() const; 59 60 private: 61 const Type type_; 62 Graph *graph_; 63 Block *block_; 64 65 public: 66 // remove it 67 int marker_; // for visit 68 }; 69 70 class InstrNode : public AbstractNode { 71 public: InstrNode(int op,int arg)72 InstrNode(int op, int arg) : AbstractNode(kInstr), op_(op), arg_(arg) {} ~InstrNode()73 virtual ~InstrNode() {} GetOpcode()74 int GetOpcode() const { return op_; } GetOparg()75 int GetOparg() const { return arg_; } GetLineNo()76 int GetLineNo() const { return line_; } SetOparg(int arg)77 void SetOparg(int arg) { this->arg_ = arg; } SetOpcode(int op)78 void SetOpcode(int op) { this->op_ = op; } SetLineNo(int l)79 void SetLineNo(int l) { this->line_ = l; } SetName(const std::string & n)80 void SetName(const std::string &n) { name_ = n; } GetName()81 const std::string &GetName() const { return name_; } 82 std::string ToString() const override; 83 bci()84 int bci() const { return bci_; } set_bci(int i)85 void set_bci(int i) { bci_ = i; } 86 87 protected: InstrNode(Type t,int op,int arg)88 InstrNode(Type t, int op, int arg) : AbstractNode(t), op_(op), arg_(arg), line_(-1) {} 89 90 private: 91 int bci_ = -1; 92 int op_; 93 int arg_; 94 int line_ = -1; 95 std::string name_; 96 }; 97 98 class ValueNode : public InstrNode { 99 public: 100 static ValueNode kUnboundLocal; 101 102 ValueNode(AObject *vobj, int opcode, int oparg, const std::vector<ValueNode *> &inputs = {}) InstrNode(Value,opcode,oparg)103 : InstrNode(Value, opcode, oparg), vobj_(vobj), inputs_(inputs) {} ~ValueNode()104 virtual ~ValueNode() {} 105 getInputs()106 std::vector<ValueNode *> &getInputs() { return inputs_; } getInputs()107 const std::vector<ValueNode *> &getInputs() const { return inputs_; } input(int i)108 ValueNode *input(int i) const { return inputs_[i]; } AddInput(ValueNode * v)109 void AddInput(ValueNode *v) { inputs_.push_back(v); } ClearInputs()110 void ClearInputs() { inputs_.clear(); } 111 112 void SetVobj(AObject *vobj); GetVobj()113 const auto &GetVobj() const { return vobj_; } 114 115 AObject *get_attr(const std::string &nam); 116 117 AObject *binary_subscr(ValueNode *sub); 118 119 std::string ToString() const override; GetParent()120 ValueNode *GetParent() { return parent_.value_or(nullptr); } 121 void SetParent(ValueNode *parent); 122 123 bool IsConstantValue() const; 124 void SetConstantValue(bool constant); 125 const std::unique_ptr<ConstantInfo> &MakeConstantInfo(); GetConstantInfo()126 const std::unique_ptr<ConstantInfo> &GetConstantInfo() const { return constant_info_; } 127 GetTrace()128 TracePtr GetTrace() { return trace_; } SetTrace(TracePtr t)129 void SetTrace(TracePtr t) { trace_ = t; } 130 131 protected: 132 ValueNode(Type type, AObject *vobj, int opcode, int oparg, const std::vector<ValueNode *> &inputs = {}) InstrNode(type,opcode,oparg)133 : InstrNode(type, opcode, oparg), vobj_(vobj), inputs_(inputs) {} 134 135 private: 136 // value info 137 AObject *vobj_; 138 139 // constant info 140 std::unique_ptr<ConstantInfo> constant_info_; 141 142 // which nodes are used, ordered parameter 143 std::vector<ValueNode *> inputs_; 144 145 // recode relationship between local and CallNode 146 std::optional<ValueNode *> parent_; 147 148 // Trace cache to be reused 149 TracePtr trace_; 150 }; 151 152 // simulate PyCellObject, oparg is index 153 class CellVarNode : public ValueNode { 154 public: CellVarNode(Type t)155 explicit CellVarNode(Type t) : ValueNode(t, nullptr, -1, CO_CELL_NOT_AN_ARG), val_(nullptr) {} SetFromParam(int i)156 void SetFromParam(int i) { SetOparg(i); } GetFromParam()157 int GetFromParam() const { return GetOparg(); } SetIndex(int i)158 void SetIndex(int i) { return SetOpcode(i); } GetIndex()159 int GetIndex() const { return GetOpcode(); } GetValue()160 auto GetValue() const { return val_; } SetValue(ValueNode * v)161 void SetValue(ValueNode *v) { val_ = v; } GetCellOper()162 const auto &GetCellOper() const { return cell_oper_; } GetCellOper()163 auto &GetCellOper() { return cell_oper_; } AddCellOper(ValueNode * i)164 void AddCellOper(ValueNode *i) { cell_oper_.push_back(i); } ~CellVarNode()165 virtual ~CellVarNode() {} 166 std::string ToString() const override; 167 168 private: 169 ValueNode *val_; 170 std::vector<ValueNode *> cell_oper_; // record cell operation 171 }; 172 173 class ParamNode : public ValueNode { 174 public: ParamNode(AObject * o,int index)175 ParamNode(AObject *o, int index) : ValueNode(Param, o, 0, index, {}) {} 176 std::string ToString() const override; IsMixedPrecisionType()177 bool IsMixedPrecisionType() { return mixedPrecisionType_ != nullptr; } GetMixedPrecisionType()178 PyObject *GetMixedPrecisionType() { return mixedPrecisionType_; } SetMixedPrecisionType(PyObject * type)179 void SetMixedPrecisionType(PyObject *type) { mixedPrecisionType_ = type; } ~ParamNode()180 virtual ~ParamNode() {} 181 182 protected: 183 PyObject *mixedPrecisionType_{nullptr}; 184 }; 185 186 class CallNode : public ValueNode { 187 public: CallNode(int opcode,int oparg,const std::vector<ValueNode * > & inputs)188 CallNode(int opcode, int oparg, const std::vector<ValueNode *> &inputs) 189 : ValueNode(Call, nullptr, opcode, oparg, inputs), sub_graph_(nullptr) {} ~CallNode()190 virtual ~CallNode() {} 191 GetSubGraph()192 Graph *GetSubGraph() const { return sub_graph_; } 193 void SetSubGraph(Graph *n); 194 std::string ToString() const override; SetInlineReason(InlineReason r)195 void SetInlineReason(InlineReason r) { reason_ = r; } GetInlineReason()196 InlineReason GetInlineReason() { return reason_; } 197 AddParam(ValueNode * p)198 void AddParam(ValueNode *p) { 199 params_.push_back(p); 200 if (p) { 201 p->SetParent(this); 202 } 203 } 204 GetParams()205 const auto &GetParams() const { return params_; } GetArgs()206 std::vector<py::object> GetArgs() { 207 std::vector<py::object> args; 208 std::transform(getInputs().begin() + 1, getInputs().end(), std::back_inserter(args), 209 [](ValueNode *n) { return n->GetVobj() ? n->GetVobj()->GetPyObject() : py::object(); }); 210 return args; 211 } 212 213 private: 214 // sub-graph if traced function 215 Graph *sub_graph_; 216 217 InlineReason reason_ = InlineReason::kInlineUnknown; 218 219 std::vector<ValueNode *> params_; // extra values for inline function 220 }; 221 222 bool IsNonLocalValue(ValueNode *i); 223 } // namespace pijit 224 } // namespace mindspore 225 226 #endif // MINDSPORE_PI_JIT_GRAPH_CAPTURE_NODE_H 227