1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/graph_util/generate_graph.h"
18
19 #include <algorithm>
20 #include <cstdint>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 #include <utility>
25
26 #include "base/base.h"
27 #include "include/common/utils/python_adapter.h"
28 #include "include/common/utils/convert_utils_py.h"
29 #include "include/common/utils/parallel_context.h"
30 #include "frontend/parallel/graph_util/node_info.h"
31 #include "ir/anf.h"
32 #include "ir/value.h"
33 #include "mindspore/ccsrc/pipeline/jit/ps/parse/parse_base.h"
34 #include "utils/log_adapter.h"
35 #include "utils/anf_utils.h"
36 #include "ir/primitive.h"
37 #include "ops/op_utils.h"
38 #include "ops/op_def.h"
39
40 using mindspore::tensor::Tensor;
41
42 namespace mindspore {
43 namespace parallel {
44 namespace {
CreateOpPrimtiveWithAttrs(const OperatorAttrs & attrs,const OperatorName & op_name,const std::string & instance_name)45 ValuePtr CreateOpPrimtiveWithAttrs(const OperatorAttrs &attrs, const OperatorName &op_name,
46 const std::string &instance_name) {
47 auto op_def = mindspore::ops::GetOpDef(op_name);
48 if (op_def == nullptr) {
49 return CreateOpInstance(attrs, op_name, instance_name);
50 }
51
52 auto prim = std::make_shared<Primitive>(op_name);
53 MS_EXCEPTION_IF_NULL(prim);
54 prim->set_instance_name(instance_name);
55 for (const auto &[name, value] : attrs) {
56 prim->set_attr(name, value);
57 }
58
59 return prim;
60 }
61
RectifyInputsForNewCNode(const std::vector<AnfNodePtr> & inputs)62 std::vector<AnfNodePtr> RectifyInputsForNewCNode(const std::vector<AnfNodePtr> &inputs) {
63 if (inputs.size() <= 1) {
64 MS_LOG(INTERNAL_EXCEPTION) << "For NewCNode, the inputs should not less than two!";
65 }
66
67 auto value_node = inputs[0]->cast<ValueNodePtr>();
68 MS_EXCEPTION_IF_NULL(value_node);
69 auto value = value_node->value();
70 MS_EXCEPTION_IF_NULL(value);
71 auto prim = value->cast<PrimitivePtr>();
72 MS_EXCEPTION_IF_NULL(prim);
73
74 auto op_name = prim->name();
75 auto op_def = mindspore::ops::GetOpDef(op_name);
76 if (op_def == nullptr) {
77 return inputs;
78 }
79
80 std::vector<AnfNodePtr> new_inputs(inputs.begin(), inputs.end());
81 auto op_inputs_num = op_def->indexes_.size();
82 new_inputs.resize(op_inputs_num + 1); // 1 for primitive.
83
84 // For new defined op, almost all old attrs is changed to inputs.
85 std::vector<std::string> latter_erase;
86 auto attrs = prim->attrs();
87 for (const auto &[name, value] : attrs) {
88 auto [is_input, node_input_idx] = CheckAndGetValidIdxByOpDef(op_def, op_name, name, new_inputs.size());
89 if (!is_input) {
90 continue;
91 }
92 new_inputs[node_input_idx] = NewValueNode(value);
93 latter_erase.push_back(name);
94 }
95
96 for (const auto &name : latter_erase) {
97 prim->EraseAttr(name);
98 }
99
100 return new_inputs;
101 }
102 } // namespace
103
CheckAndGetValidIdxByOpDef(const ops::OpDefPtr & op_def,const std::string & op_name,const std::string & attr_name,size_t limit_size)104 std::pair<bool, size_t> CheckAndGetValidIdxByOpDef(const ops::OpDefPtr &op_def, const std::string &op_name,
105 const std::string &attr_name, size_t limit_size) {
106 auto ks_iter = op_def->indexes_.find(attr_name);
107 if (ks_iter == op_def->indexes_.end()) {
108 MS_LOG(DEBUG) << "For " << op_name << ", cannot find a valid index for input " << attr_name
109 << " in operator-definition.";
110 return std::make_pair(false, SIZE_MAX);
111 }
112
113 auto idx = ks_iter->second;
114 auto real_idx = idx + 1;
115 if (real_idx >= limit_size) {
116 MS_LOG(INTERNAL_EXCEPTION) << "For " << op_name << ", " << idx << " is not a valid index for input " << attr_name;
117 }
118 return std::make_pair(true, real_idx);
119 }
120
GetOpPythonPath(const char * op_name)121 const char *GetOpPythonPath(const char *op_name) {
122 static const py::module inner_mod = py::module::import(INNER_OP_PATH);
123 if (py::hasattr(inner_mod, op_name)) {
124 return INNER_OP_PATH;
125 }
126
127 static const py::module mod = py::module::import(OP_PATH);
128 if (py::hasattr(mod, op_name)) {
129 return OP_PATH;
130 }
131
132 static const py::module grad_mod = py::module::import(GRAD_OP_PATH);
133 if (py::hasattr(grad_mod, op_name)) {
134 return GRAD_OP_PATH;
135 }
136
137 static const py::module nn_mod = py::module::import(NN_OPS_PATH);
138 if (py::hasattr(nn_mod, op_name)) {
139 return NN_OPS_PATH;
140 }
141
142 static const py::module functional_mod = py::module::import(FUNCTIONAL_OP_PATH);
143 if (!py::hasattr(functional_mod, op_name)) {
144 MS_LOG(EXCEPTION) << OP_PATH << " and " << INNER_OP_PATH << " and " << GRAD_OP_PATH << " and " << NN_OPS_PATH
145 << "and" << FUNCTIONAL_OP_PATH << " don't have op:" << op_name;
146 }
147 return FUNCTIONAL_OP_PATH;
148 }
149
CreateOpInstance(const OperatorAttrs & attrs,const OperatorName & op_name,const std::string & instance_name)150 ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
151 const auto op_path = GetOpPythonPath(op_name.c_str());
152 std::vector<py::object> arg_list;
153 (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
154 [](const Attr &attr) { return ValueToPyData(attr.second); });
155 py::object obj =
156 python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
157 ValuePtr op_instance = nullptr;
158 bool succ = parse::ConvertData(obj, &op_instance);
159 if (!succ) {
160 MS_LOG(ERROR) << "Failure:get Python op " << op_path << " from " << op_name << " fail";
161 return nullptr;
162 }
163 return op_instance;
164 }
165
ConvertToRealInputs(const OperatorName & op_name,const std::string & instance_name,const AnfNodePtrList & inputs,const OperatorAttrs & attrs)166 std::vector<AnfNodePtr> ConvertToRealInputs(const OperatorName &op_name, const std::string &instance_name,
167 const AnfNodePtrList &inputs, const OperatorAttrs &attrs) {
168 auto op_def = mindspore::ops::GetOpDef(op_name);
169 if (op_def == nullptr) {
170 // Create old op from python for creating some attr in __init__.
171 auto prim_value = CreateOpInstance(attrs, op_name, instance_name);
172 AnfNodePtrList node_inputs = {NewValueNode(prim_value)};
173 node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
174 return node_inputs;
175 }
176
177 size_t op_inputs_num = inputs.size() + attrs.size();
178 if (op_inputs_num != op_def->indexes_.size()) {
179 MS_LOG(INTERNAL_EXCEPTION) << "For " << op_name << ", inputs should be " << op_def->indexes_.size()
180 << ", but got given inputs num " << inputs.size() << " and attrs num " << attrs.size();
181 }
182
183 auto prim = std::make_shared<Primitive>(op_name);
184 MS_EXCEPTION_IF_NULL(prim);
185 prim->set_instance_name(instance_name);
186
187 AnfNodePtrList node_inputs;
188 node_inputs.resize(1 + op_inputs_num); // 1 for primitive value node.
189 node_inputs[0] = NewValueNode(prim);
190
191 for (size_t i = 0; i < inputs.size(); ++i) {
192 node_inputs[i + 1] = inputs[i];
193 }
194
195 // For new-defined op, almost all attrs are inputs now, here should insert the value as input in right position.
196 for (size_t i = 0; i < attrs.size(); ++i) {
197 auto [attr_name, attr_value] = attrs[i];
198 auto [is_input, node_input_idx] = CheckAndGetValidIdxByOpDef(op_def, op_name, attr_name, node_inputs.size());
199 if (!is_input) {
200 continue;
201 }
202 node_inputs[node_input_idx] = NewValueNode(attr_value);
203 }
204
205 return node_inputs;
206 }
207
CreateCNodeByInputsAndAttr(const FuncGraphPtr & func_graph,const OperatorName & op_name,const std::string & instance_name,const AnfNodePtrList & inputs,const OperatorAttrs & attrs)208 CNodePtr CreateCNodeByInputsAndAttr(const FuncGraphPtr &func_graph, const OperatorName &op_name,
209 const std::string &instance_name, const AnfNodePtrList &inputs,
210 const OperatorAttrs &attrs) {
211 auto real_inputs = ConvertToRealInputs(op_name, instance_name, inputs, attrs);
212 MS_EXCEPTION_IF_NULL(func_graph);
213 auto cnode = func_graph->NewCNode(real_inputs);
214 return cnode;
215 }
216
// Creates a CNode in the graph of `origin_node` that reuses its inputs but runs `new_prim`.
//
// When `new_prim` has a registered op-definition, every attr of `new_prim` that the op-def lists
// as an input replaces the matching input of the new node with a value node holding the attr
// value, and that attr is erased from `new_prim`. The input being replaced must already be a
// value node. Input 0 of the new node is always a value node holding `new_prim`.
// Raises an exception when `origin_node`, its func graph, `new_prim` or a replaced input is null,
// or when a replaced input is not a value node.
CNodePtr CreateNewCNodeForReplace(const CNodePtr &origin_node, const PrimitivePtr &new_prim) {
  MS_EXCEPTION_IF_NULL(origin_node);
  auto func_graph = origin_node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  const auto &origin_inputs = origin_node->inputs();
  AnfNodePtrList new_inputs(origin_inputs.begin(), origin_inputs.end());

  MS_EXCEPTION_IF_NULL(new_prim);
  auto op_name = new_prim->name();
  auto op_def = mindspore::ops::GetOpDef(op_name);
  if (op_def != nullptr) {
    // For new defined op, almost all old attrs is changed to inputs.
    // Erasing is postponed until after the loop so the attrs are not modified while being visited.
    std::vector<std::string> latter_erase;
    auto attrs = new_prim->attrs();
    for (const auto &[name, value] : attrs) {
      auto [is_input, node_input_idx] = CheckAndGetValidIdxByOpDef(op_def, op_name, name, new_inputs.size());
      if (!is_input) {
        continue;
      }
      const auto &old_input = new_inputs[node_input_idx];
      MS_EXCEPTION_IF_NULL(old_input);
      if (!old_input->isa<ValueNode>()) {
        MS_LOG(INTERNAL_EXCEPTION) << "For auto parallel, the " << (node_input_idx - 1) << " input of " << op_name
                                   << " must be a value node!";
      }

      // Write into the list that is actually passed to NewCNode. The value used to be stored in a
      // separate local copy of the inputs, so it was dropped while the attr was still erased.
      new_inputs[node_input_idx] = NewValueNode(value);
      latter_erase.push_back(name);
    }

    for (const auto &name : latter_erase) {
      new_prim->EraseAttr(name);
    }
  }

  new_inputs[0] = NewValueNode(new_prim);
  return func_graph->NewCNode(new_inputs);
}
253
ValuePtrToAnfNodePtr(const ValuePtr & value_ptr)254 AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) {
255 auto value_node = NewValueNode(value_ptr);
256 MS_EXCEPTION_IF_NULL(value_node);
257 return value_node->cast<AnfNodePtr>();
258 }
259
CreateInt32Tensor(int64_t value,bool int64_type)260 AnfNodePtr CreateInt32Tensor(int64_t value, bool int64_type) {
261 mindspore::tensor::TensorPtr tensor_ptr;
262 if (int64_type) {
263 tensor_ptr = std::make_shared<tensor::Tensor>(value, kInt64);
264 } else {
265 tensor_ptr = std::make_shared<tensor::Tensor>(value, kInt32);
266 }
267
268 ValuePtr value_ptr = MakeValue(tensor_ptr);
269 auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
270 return anf_node_ptr;
271 }
272
CreateFP32Tensor(float value)273 AnfNodePtr CreateFP32Tensor(float value) {
274 mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(value, kFloat32);
275 ValuePtr value_ptr = MakeValue(tensor_ptr);
276 auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
277 return anf_node_ptr;
278 }
279
CreateTypeInt(int64_t nbits)280 AnfNodePtr CreateTypeInt(int64_t nbits) {
281 ValuePtr value_ptr = MakeValue(std::make_shared<Int>(nbits));
282 return ValuePtrToAnfNodePtr(value_ptr);
283 }
284
CreateTypeFloat(int64_t nbits)285 AnfNodePtr CreateTypeFloat(int64_t nbits) {
286 ValuePtr value_ptr = MakeValue(std::make_shared<Float>(nbits));
287 return ValuePtrToAnfNodePtr(value_ptr);
288 }
289
CreatInt64Imm(int64_t value)290 AnfNodePtr CreatInt64Imm(int64_t value) {
291 ValuePtr value_ptr = MakeValue(std::make_shared<Int64Imm>(value));
292 return ValuePtrToAnfNodePtr(value_ptr);
293 }
294
CreateFP32Imm(float value)295 AnfNodePtr CreateFP32Imm(float value) {
296 ValuePtr value_ptr = MakeValue(std::make_shared<FP32Imm>(value));
297 return ValuePtrToAnfNodePtr(value_ptr);
298 }
299
CreateBoolImm(bool value)300 AnfNodePtr CreateBoolImm(bool value) {
301 ValuePtr value_ptr = MakeValue(std::make_shared<BoolImm>(value));
302 return ValuePtrToAnfNodePtr(value_ptr);
303 }
304
CreateStringImm(std::string value)305 AnfNodePtr CreateStringImm(std::string value) {
306 ValuePtr value_ptr = MakeValue(std::make_shared<StringImm>(value));
307 return ValuePtrToAnfNodePtr(value_ptr);
308 }
309
CreateTuple(const std::vector<int64_t> & tuple)310 AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple) {
311 std::vector<ValuePtr> value_list;
312 (void)std::transform(tuple.begin(), tuple.end(), std::back_inserter(value_list),
313 [](const int64_t value) { return MakeValue(value); });
314 ValueTuplePtr value_tuple_ptr = std::make_shared<ValueTuple>(value_list);
315 return ValuePtrToAnfNodePtr(value_tuple_ptr);
316 }
317
GetInstanceNameByCNode(const CNodePtr & cnode)318 std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
319 PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
320 if (!prim) {
321 MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr.";
322 }
323 std::string instance_name = prim->instance_name();
324 return HashInstanceName(instance_name);
325 }
326
HashInstanceName(const std::string & name)327 std::string HashInstanceName(const std::string &name) {
328 auto using_hash_name = common::GetEnv(USING_HASH_NAME);
329 std::string instance_name;
330 if ((using_hash_name.empty()) || (using_hash_name == "on")) {
331 instance_name = HashName(name);
332 } else {
333 instance_name = name;
334 }
335 return instance_name;
336 }
337
InsertVirtualPipelineEndNode(const CNodePtr & cnode,const FuncGraphManagerPtr & manager,size_t index,std::string end_flag)338 void InsertVirtualPipelineEndNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, size_t index,
339 std::string end_flag) {
340 auto pre_cnode = cnode->input(index)->cast<CNodePtr>();
341 MS_EXCEPTION_IF_NULL(pre_cnode);
342 auto graph = cnode->func_graph();
343 MS_EXCEPTION_IF_NULL(graph);
344 OperatorAttrs attrs_;
345 auto op = CreateOpInstance(attrs_, "_VirtualPipelineEnd", "end_node");
346 auto value_node = NewValueNode(op);
347 auto virtual_end = graph->NewCNode({value_node, pre_cnode});
348 virtual_end->set_abstract(pre_cnode->abstract());
349 virtual_end->AddPrimalAttr(end_flag, pre_cnode->GetPrimalAttr(MICRO));
350 virtual_end->AddPrimalAttr(MICRO, pre_cnode->GetPrimalAttr(MICRO));
351 manager->SetEdge(cnode, SizeToInt(index), virtual_end);
352 if (ParallelContext::GetInstance()->enable_fold_pipeline()) {
353 auto seg = ParallelContext::GetInstance()->pipeline_segment_split_num();
354 virtual_end->AddPrimalAttr(SEGMENT, MakeValue(seg - 1));
355 }
356 }
357
// Creates a `_VirtualConverterBegin` node in the graph of `input_cnode`, with `input_cnode` as its
// only operand. `output_nums` is passed to the operator as attr "output_nums".
// Raises an exception when `input_cnode` or its func graph is null.
CNodePtr CreateVirtualConverterBeginNode(const CNodePtr &input_cnode, size_t output_nums) {
  // input_cnode is dereferenced right below, so reject null explicitly instead of crashing.
  MS_EXCEPTION_IF_NULL(input_cnode);
  auto graph = input_cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  Attr output_nums_attr = {"output_nums", MakeValue(output_nums)};
  OperatorAttrs attrs_ = {output_nums_attr};
  auto op = CreateOpInstance(attrs_, "_VirtualConverterBegin", "virtual_converter_begin");
  auto value_node = NewValueNode(op);
  auto virtual_begin = graph->NewCNode({value_node, input_cnode});
  return virtual_begin;
}
368
CreateVirtualConverterEndNode(const FuncGraphPtr & graph,const std::vector<CNodePtr> & input_cnodes)369 CNodePtr CreateVirtualConverterEndNode(const FuncGraphPtr &graph, const std::vector<CNodePtr> &input_cnodes) {
370 if (input_cnodes.empty()) {
371 MS_LOG(EXCEPTION) << "input cnodes for _VirtualConverterEnd is empty.";
372 }
373 Attr input_nums_attr = {"input_nums", MakeValue(input_cnodes.size())};
374 OperatorAttrs attrs_ = {input_nums_attr};
375 auto op = CreateOpInstance(attrs_, "_VirtualConverterEnd", "virtual_converter_End");
376 auto value_node = NewValueNode(op);
377 std::vector<AnfNodePtr> virtual_end_input = {value_node};
378 std::copy(input_cnodes.begin(), input_cnodes.end(), std::back_inserter(virtual_end_input));
379 auto virtual_end = graph->NewCNode(virtual_end_input);
380 return virtual_end;
381 }
382
Init(const CNodePtr & cnode)383 Status GenerateGraph::Init(const CNodePtr &cnode) {
384 if (!cnode) {
385 MS_LOG(ERROR) << "Init:cnode is nullptr";
386 return FAILED;
387 }
388 cnode_ = cnode;
389 func_graph_ = cnode->func_graph();
390 if (!func_graph_) {
391 MS_LOG(ERROR) << "Init:func_graph_ is nullptr";
392 return FAILED;
393 }
394 manager_ = func_graph_->manager();
395 if (!manager_) {
396 MS_LOG(ERROR) << "Init:manager_ is nullptr";
397 return FAILED;
398 }
399 scope_ = cnode_->scope();
400 if (!scope_) {
401 MS_LOG(ERROR) << "Init:scope_ is nullptr";
402 return FAILED;
403 }
404 virtual_input_node_ = std::make_shared<AnfNode>(nullptr);
405 virtual_input_node_->set_scope(scope_);
406 instance_name_base_ = GetInstanceNameByCNode(cnode_);
407 name_idx_ = 0;
408 return SUCCESS;
409 }
410
PushBack(const std::vector<AnfNodePtr> & inputs)411 AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
412 auto new_inputs = RectifyInputsForNewCNode(inputs);
413 for (auto &input : new_inputs) {
414 MS_EXCEPTION_IF_NULL(input); // if error raise here, check if inputs need include attrs
415 }
416 CNodePtr cnode = func_graph_->NewCNode(new_inputs); // using NewCNode to create anfnode
417 MS_EXCEPTION_IF_NULL(cnode);
418 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
419 SetUserAttrs(origin_attrs_, prim);
420 cnode->set_scope(scope_);
421 auto new_anf_node_ptr = cnode->cast<AnfNodePtr>();
422 MS_EXCEPTION_IF_NULL(new_anf_node_ptr);
423 return new_anf_node_ptr;
424 }
425
NewOpInst(const OperatorName & op_name,const OperatorAttrs & attrs)426 AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) {
427 name_idx_++;
428 ValuePtr op_prim_instance =
429 CreateOpPrimtiveWithAttrs(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
430 if (op_prim_instance == nullptr) {
431 MS_LOG(EXCEPTION) << "Failure:" << op_name << " NewOpInst failed";
432 }
433 auto value_node = NewValueNode(op_prim_instance);
434 return value_node->cast<AnfNodePtr>();
435 }
436
NewOpInst(const OperatorName & op_name)437 AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) {
438 name_idx_++;
439 OperatorAttrs attrs;
440 ValuePtr op_prim_instance =
441 CreateOpPrimtiveWithAttrs(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
442 if (op_prim_instance == nullptr) {
443 MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreateOpInstance failed";
444 }
445 auto value_node = NewValueNode(op_prim_instance);
446 return value_node->cast<AnfNodePtr>();
447 }
448 } // namespace parallel
449 } // namespace mindspore
450