1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_PI_JIT_OPERATION_H_
17 #define MINDSPORE_PI_JIT_OPERATION_H_
18
19 #include <memory>
20 #include <string>
21 #include <vector>
22 #include "pipeline/jit/pi/graph_compiler/pi_ir/node.h"
23 #include "pybind11/stl.h"
24
25 namespace mindspore {
26 namespace pijit {
27 namespace ir {
28 using OpCode = int;
29
30 namespace py = pybind11;
31
GetOpName(OpCode op)32 static std::string GetOpName(OpCode op) {
33 static const std::vector<std::string> op_names =
34 py::cast<std::vector<std::string>>(py::module::import("opcode").attr("opname"));
35 return op_names[op];
36 }
37
38 /// \brief Operation is the parent class of all class which represent the operation of a instruction.
39 class Operation : public Node {
40 public:
41 /**
42 * \brief The constructor of operation.
43 *
44 * \param[in] op the opcode of this operation.
45 *
46 * \return The instance of operation.
47 */
Operation(OpCode op)48 explicit Operation(OpCode op) : Operation(op, {}) {}
49
50 /**
51 * \brief The constructor of operation.
52 *
53 * \param[in] op the opcode of this operation.
54 * \param[in] args the opands of this operation.
55 *
56 * \return The instance of operation.
57 */
Operation(OpCode op,const NodePtrList & args)58 explicit Operation(OpCode op, const NodePtrList &args) : opcode_(op), need_ext_instr_(false), args_(args) {}
59
60 /// \brief Destructor.
61 ~Operation() override = default;
62 JIT_DECLARE_PARENT(Operation, Node);
63
64 /**
65 * \brief Judge whether this node is an operation(instruction).
66 *
67 * \return The result of the judgment.
68 */
IsOperation()69 bool IsOperation() const override { return true; }
70
71 /**
72 * \brief Set the id of this operation.
73 *
74 * \note This method should not be actively called by the program writer, it should only be called by the method
75 * Sort()
76 */
SetNodeId(size_t * id)77 void SetNodeId(size_t *id) override {
78 for (const auto &arg : args_) {
79 arg->SetNodeId(id);
80 }
81 Node::SetNodeId(id);
82 }
83
84 /**
85 * \brief Set the offset of this operation.
86 *
87 * \note This method should not be actively called by the program writer, it should only be called by the method
88 * Sort()
89 */
SetOffset(size_t * offset)90 void SetOffset(size_t *offset) override {
91 for (const auto &arg : args_) {
92 arg->SetOffset(offset);
93 }
94 Node::SetOffset(offset);
95 }
96
97 /**
98 * \brief Get opcode of this operation.
99 *
100 * \return the opcode of this operation.
101 */
GetOpCode()102 OpCode GetOpCode() const { return opcode_; }
103
104 /**
105 * \brief Judge whether need to insert a EXTENDED_ARG instruction before this operation.
106 *
107 * \return The result of the judgment.
108 */
NeedExtInstr()109 bool NeedExtInstr() const override { return need_ext_instr_; }
110
111 /**
112 * \brief Mark whether this operation need to insert a EXTENDED_ARG instruction.
113 *
114 * \param[in] need the result.
115 */
SetNeedExtInstr(bool need)116 void SetNeedExtInstr(bool need) override { need_ext_instr_ = need; }
117
118 /**
119 * \brief Get the count of args.
120 *
121 * \return the count of args.
122 */
GetArgsCnt()123 size_t GetArgsCnt() const { return args_.size(); }
124
125 /**
126 * \brief Get the specified positional operand of this operation.
127 *
128 * \return The specified positional operand
129 */
130 const NodePtr &GetArg(size_t index = 0) const { return args_[index]; }
131
132 /**
133 * \brief Set the operand of this operation.
134 *
135 * \param[in] index the position of the arg.
136 * \param[in] arg the value of the arg.
137 */
SetArg(size_t index,const NodePtr & arg)138 void SetArg(size_t index, const NodePtr &arg) { args_[index] = arg; }
139
140 /**
141 * \brief Get the operands of this operation.
142 *
143 * \return the operands of this operation.
144 */
GetArgs()145 const NodePtrList &GetArgs() const { return args_; }
146
147 /**
148 * \brief Get the operands of this operation.
149 *
150 * \return the operands of this operation.
151 */
GetArgs()152 NodePtrList &GetArgs() { return args_; }
153
154 /**
155 * \brief Set the operands of this operation.
156 *
157 * \param[in] args the new value of the operation.
158 */
SetArgs(const NodePtrList & args)159 void SetArgs(const NodePtrList &args) { args_ = args; }
160
161 private:
162 /// \brief The opcode of this operation.
163 OpCode opcode_;
164 /// \brief The EXTENDED_ARG instruction is required.
165 bool need_ext_instr_;
166 /// \brief The operands of this operation.
167 NodePtrList args_;
168 };
169
170 using OperationPtr = std::shared_ptr<Operation>;
171
172 /// \brief UnaryOperation is is the parent class of all class which represent the operation of instruction with one
173 /// operand.
174 class UnaryOperation : public Operation {
175 public:
176 /**
177 * \brief The constructor of unary operation.
178 *
179 * \param[in] op the opcode of this unary operation.
180 * \param[in] arg the operand of this unary operation.
181 *
182 * \return The instance of unary operation.
183 */
UnaryOperation(OpCode op,const NodePtr & arg)184 UnaryOperation(OpCode op, const NodePtr &arg) : Operation(op, {arg}) {}
185
186 /// \brief Destructor.
187 ~UnaryOperation() override = default;
188 JIT_DECLARE_PARENT(UnaryOperation, Operation);
189
190 /**
191 * \brief Set the operand of this unary operation.
192 *
193 * \param[in] arg the value of the arg.
194 */
SetArg(const NodePtr & arg)195 void SetArg(const NodePtr &arg) { Operation::SetArg(0, arg); }
196
197 /**
198 * \brief Get the description of this unary operation.
199 * \return The description.
200 */
ToString()201 std::string ToString() const override {
202 return GetArg()->ToString() + "\n%" + std::to_string(GetNodeId()) + " = " + GetNodeName() + "[" +
203 GetType()->GetName() + "](" + GetOpName(GetOpCode()) + ", %" + std::to_string(GetArg()->GetNodeId()) + ")\n";
204 }
205 };
206
207 using UnaryOperationPtr = std::shared_ptr<UnaryOperation>;
208
209 /// \brief BinaryOperation is is the parent class of all class which represent the operation of instruction with two
210 /// operand.
211 class BinaryOperation : public Operation {
212 public:
213 /**
214 * \brief The constructor of binary operation node.
215 *
216 * \param[in] op the opcode of this binary operation node.
217 * \param[in] left the first operand of this binary operation node.
218 * \param[in] right the second operand of this binary operation node.
219 *
220 * \return The instance of binary operation node.
221 */
BinaryOperation(OpCode op,const NodePtr & left,const NodePtr & right)222 BinaryOperation(OpCode op, const NodePtr &left, const NodePtr &right) : Operation(op, {left, right}) {}
223
224 /// \brief Destructor.
225 ~BinaryOperation() override = default;
226 JIT_DECLARE_PARENT(BinaryOperation, Operation);
227
228 /**
229 * \brief Get the operand of this binary operation.
230 *
231 * \return the operand of this binary operation.
232 */
GetLeftArg()233 const NodePtr &GetLeftArg() const { return GetArg(0); }
234
235 /**
236 * \brief Set the first operand of this binary operation.
237 */
SetLeftArg(const NodePtr & arg)238 void SetLeftArg(const NodePtr &arg) { SetArg(0, arg); }
239
240 /**
241 * \brief Get the operand of this binary operation.
242 *
243 * \return the operand of this binary operation.
244 */
GetRightArg()245 const NodePtr &GetRightArg() const { return GetArg(1); }
246
247 /**
248 * \brief Set the second operand of this binary operation.
249 */
SetRightArg(const NodePtr & arg)250 void SetRightArg(const NodePtr &arg) { SetArg(1, arg); }
251
252 /**
253 * \brief Get the description of this binary operation.
254 * \return The description.
255 */
ToString()256 std::string ToString() const override {
257 return GetArg(0)->ToString() + "\n" + GetArg(1)->ToString() + "\n%" + std::to_string(GetNodeId()) + " = " +
258 GetNodeName() + "[" + GetType()->GetName() + "](" + GetOpName(GetOpCode()) + ", %" +
259 std::to_string(GetArg(0)->GetNodeId()) + ", %" + std::to_string(GetArg(1)->GetNodeId()) + ")\n";
260 }
261 };
262
263 using BinaryOperationPtr = std::shared_ptr<BinaryOperation>;
264
265 /// \brief NaryOperation is is the parent class of all class which represent the operation of instruction with
266 /// indeterminate number of operands.
267 class NaryOperation : public Operation {
268 public:
269 /**
270 * \brief The constructor of nary operation node.
271 *
272 * \param[in] op the opcode of this nary operation node.
273 *
274 * \return The instance of nary operation node.
275 */
NaryOperation(OpCode op)276 explicit NaryOperation(OpCode op) : NaryOperation(op, {}) {}
277
278 /**
279 * \brief The constructor of nary operation node.
280 *
281 * \param[in] op the opcode of this nary operation node.
282 * \param[in] args the operands of this nary operation node.
283 *
284 * \return The instance of nary operation node.
285 */
NaryOperation(OpCode op,const NodePtrList & args)286 NaryOperation(OpCode op, const NodePtrList &args) : Operation(op, args) {}
287
288 /// \brief Destructor.
289 ~NaryOperation() override = default;
290 JIT_DECLARE_PARENT(NaryOperation, Operation);
291
292 /**
293 * \brief Get the description of this nary operation.
294 * \return The description.
295 */
ToString()296 std::string ToString() const override {
297 std::string str;
298 for (const auto &arg : GetArgs()) {
299 str += arg->ToString() + "\n";
300 }
301 str += "%" + std::to_string(GetNodeId()) + " = " + GetNodeName() + "[" + GetType()->GetName() + "](" +
302 GetOpName(GetOpCode());
303 for (const auto &arg : GetArgs()) {
304 str += ", %" + std::to_string(arg->GetNodeId());
305 }
306 str += ")\n";
307 return str;
308 }
309 };
310
311 using NaryOperationPtr = std::shared_ptr<NaryOperation>;
312 } // namespace ir
313 } // namespace pijit
314 } // namespace mindspore
315
316 #endif // MINDSPORE_PI_JIT_OPERATION_H_
317