/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/graph_util/generate_graph.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include <utility>

#include "base/base.h"
#include "include/common/utils/python_adapter.h"
#include "include/common/utils/convert_utils_py.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "ir/anf.h"
#include "ir/value.h"
#include "mindspore/ccsrc/pipeline/jit/ps/parse/parse_base.h"
#include "utils/log_adapter.h"
#include "utils/anf_utils.h"
#include "ir/primitive.h"
#include "ops/op_utils.h"
#include "ops/op_def.h"

using mindspore::tensor::Tensor;

namespace mindspore {
namespace parallel {
namespace {
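// Builds a Primitive for op_name with the given attrs attached directly. For ops without a new-style OpDef,
// this falls back to CreateOpInstance, which instantiates the op through its Python class.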
ValuePtr CreateOpPrimtiveWithAttrs(const OperatorAttrs &attrs, const OperatorName &op_name,
                                   const std::string &instance_name) {
  auto op_def = mindspore::ops::GetOpDef(op_name);
  if (op_def == nullptr) {
    return CreateOpInstance(attrs, op_name, instance_name);
  }

  auto prim = std::make_shared<Primitive>(op_name);
  MS_EXCEPTION_IF_NULL(prim);
  prim->set_instance_name(instance_name);
  for (const auto &[name, value] : attrs) {
    prim->set_attr(name, value);
  }

  return prim;
}

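// For primitives that have a new-style OpDef, attrs that became positional inputs are moved from the
// primitive into the input list at the index recorded in the OpDef, and the moved attrs are erased afterwards.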
std::vector<AnfNodePtr> RectifyInputsForNewCNode(const std::vector<AnfNodePtr> &inputs) {
  if (inputs.size() <= 1) {
    MS_LOG(INTERNAL_EXCEPTION) << "For NewCNode, the number of inputs should not be less than two!";
  }

  auto value_node = inputs[0]->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  auto prim = value->cast<PrimitivePtr>();
  MS_EXCEPTION_IF_NULL(prim);

  auto op_name = prim->name();
  auto op_def = mindspore::ops::GetOpDef(op_name);
  if (op_def == nullptr) {
    return inputs;
  }

  std::vector<AnfNodePtr> new_inputs(inputs.begin(), inputs.end());
  auto op_inputs_num = op_def->indexes_.size();
  new_inputs.resize(op_inputs_num + 1);  // 1 for primitive.

  // For a new-defined op, almost all old attrs are changed to inputs.
  std::vector<std::string> latter_erase;
  auto attrs = prim->attrs();
  for (const auto &[name, value] : attrs) {
    auto [is_input, node_input_idx] = CheckAndGetValidIdxByOpDef(op_def, op_name, name, new_inputs.size());
    if (!is_input) {
      continue;
    }
    new_inputs[node_input_idx] = NewValueNode(value);
    latter_erase.push_back(name);
  }

  for (const auto &name : latter_erase) {
    prim->EraseAttr(name);
  }

  return new_inputs;
}
}  // namespace

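// Looks up attr_name in the OpDef input index table. Returns {true, index + 1} (the +1 accounts for the
// primitive occupying input 0 of a CNode) when the attr maps to an input, and {false, SIZE_MAX} otherwise.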
std::pair<bool, size_t> CheckAndGetValidIdxByOpDef(const ops::OpDefPtr &op_def, const std::string &op_name,
                                                   const std::string &attr_name, size_t limit_size) {
  auto ks_iter = op_def->indexes_.find(attr_name);
  if (ks_iter == op_def->indexes_.end()) {
    MS_LOG(DEBUG) << "For " << op_name << ", cannot find a valid index for input " << attr_name
                  << " in operator-definition.";
    return std::make_pair(false, SIZE_MAX);
  }

  auto idx = ks_iter->second;
  auto real_idx = idx + 1;
  if (real_idx >= limit_size) {
    MS_LOG(INTERNAL_EXCEPTION) << "For " << op_name << ", " << idx << " is not a valid index for input " << attr_name;
  }
  return std::make_pair(true, real_idx);
}

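// Returns the Python module path that exports op_name, probing INNER_OP_PATH, OP_PATH, GRAD_OP_PATH,
// NN_OPS_PATH and FUNCTIONAL_OP_PATH in that order; raises an exception if none of them defines the op.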
const char *GetOpPythonPath(const char *op_name) {
  static const py::module inner_mod = py::module::import(INNER_OP_PATH);
  if (py::hasattr(inner_mod, op_name)) {
    return INNER_OP_PATH;
  }

  static const py::module mod = py::module::import(OP_PATH);
  if (py::hasattr(mod, op_name)) {
    return OP_PATH;
  }

  static const py::module grad_mod = py::module::import(GRAD_OP_PATH);
  if (py::hasattr(grad_mod, op_name)) {
    return GRAD_OP_PATH;
  }

  static const py::module nn_mod = py::module::import(NN_OPS_PATH);
  if (py::hasattr(nn_mod, op_name)) {
    return NN_OPS_PATH;
  }

  static const py::module functional_mod = py::module::import(FUNCTIONAL_OP_PATH);
  if (!py::hasattr(functional_mod, op_name)) {
    MS_LOG(EXCEPTION) << OP_PATH << " and " << INNER_OP_PATH << " and " << GRAD_OP_PATH << " and " << NN_OPS_PATH
                      << " and " << FUNCTIONAL_OP_PATH << " don't have op:" << op_name;
  }
  return FUNCTIONAL_OP_PATH;
}

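// Instantiates an operator through the Python layer: each attr value is converted to a py::object, the
// GET_OP_FUNCTION helper builds the Python op, and the result is parsed back into a ValuePtr (nullptr on failure).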
ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
  const auto op_path = GetOpPythonPath(op_name.c_str());
  std::vector<py::object> arg_list;
  (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
                       [](const Attr &attr) { return ValueToPyData(attr.second); });
  py::object obj =
    python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
  ValuePtr op_instance = nullptr;
  bool succ = parse::ConvertData(obj, &op_instance);
  if (!succ) {
    MS_LOG(ERROR) << "Failure: failed to get Python op " << op_name << " from " << op_path;
    return nullptr;
  }
  return op_instance;
}

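// Assembles the full CNode input list {primitive, inputs..., attrs-as-inputs...} for op_name. Ops without an
// OpDef keep their attrs on a Python-created primitive; for new-defined ops, each attr is inserted as a
// value-node input at the position given by the OpDef.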
std::vector<AnfNodePtr> ConvertToRealInputs(const OperatorName &op_name, const std::string &instance_name,
                                            const AnfNodePtrList &inputs, const OperatorAttrs &attrs) {
  auto op_def = mindspore::ops::GetOpDef(op_name);
  if (op_def == nullptr) {
    // Create the old-style op through Python so that attrs can be created in __init__.
    auto prim_value = CreateOpInstance(attrs, op_name, instance_name);
    AnfNodePtrList node_inputs = {NewValueNode(prim_value)};
    node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
    return node_inputs;
  }

  size_t op_inputs_num = inputs.size() + attrs.size();
  if (op_inputs_num != op_def->indexes_.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "For " << op_name << ", the number of inputs should be " << op_def->indexes_.size()
                               << ", but got " << inputs.size() << " inputs and " << attrs.size() << " attrs";
  }

  auto prim = std::make_shared<Primitive>(op_name);
  MS_EXCEPTION_IF_NULL(prim);
  prim->set_instance_name(instance_name);

  AnfNodePtrList node_inputs;
  node_inputs.resize(1 + op_inputs_num);  // 1 for primitive value node.
  node_inputs[0] = NewValueNode(prim);

  for (size_t i = 0; i < inputs.size(); ++i) {
    node_inputs[i + 1] = inputs[i];
  }

  // For a new-defined op, almost all attrs are inputs now; insert each value as an input at the right position.
  for (size_t i = 0; i < attrs.size(); ++i) {
    auto [attr_name, attr_value] = attrs[i];
    auto [is_input, node_input_idx] = CheckAndGetValidIdxByOpDef(op_def, op_name, attr_name, node_inputs.size());
    if (!is_input) {
      continue;
    }
    node_inputs[node_input_idx] = NewValueNode(attr_value);
  }

  return node_inputs;
}

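// Convenience wrapper: converts (inputs, attrs) into the real CNode input list and creates the CNode on func_graph.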
CNodePtr CreateCNodeByInputsAndAttr(const FuncGraphPtr &func_graph, const OperatorName &op_name,
                                    const std::string &instance_name, const AnfNodePtrList &inputs,
                                    const OperatorAttrs &attrs) {
  auto real_inputs = ConvertToRealInputs(op_name, instance_name, inputs, attrs);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto cnode = func_graph->NewCNode(real_inputs);
  return cnode;
}

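// Rebuilds origin_node with new_prim as input 0. If new_prim has an OpDef, attrs that should be inputs are
// written into the corresponding input slots (which must already be value nodes) and erased from the primitive.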
CNodePtr CreateNewCNodeForReplace(const CNodePtr &origin_node, const PrimitivePtr &new_prim) {
  MS_EXCEPTION_IF_NULL(origin_node);
  auto func_graph = origin_node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto inputs = origin_node->inputs();
  AnfNodePtrList new_inputs(inputs.begin(), inputs.end());

  MS_EXCEPTION_IF_NULL(new_prim);
  auto op_name = new_prim->name();
  auto op_def = mindspore::ops::GetOpDef(op_name);
  if (op_def != nullptr) {
    // For a new-defined op, almost all old attrs are changed to inputs.
    std::vector<std::string> latter_erase;
    auto attrs = new_prim->attrs();
    for (const auto &[name, value] : attrs) {
      auto [is_input, node_input_idx] = CheckAndGetValidIdxByOpDef(op_def, op_name, name, new_inputs.size());
      if (!is_input) {
        continue;
      }
      if (!new_inputs[node_input_idx]->isa<ValueNode>()) {
        MS_LOG(INTERNAL_EXCEPTION) << "For auto parallel, the " << (node_input_idx - 1) << " input of " << op_name
                                   << " must be a value node!";
      }

      new_inputs[node_input_idx] = NewValueNode(value);
      latter_erase.push_back(name);
    }

    for (const auto &name : latter_erase) {
      new_prim->EraseAttr(name);
    }
  }

  new_inputs[0] = NewValueNode(new_prim);
  return func_graph->NewCNode(new_inputs);
}

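// The helpers below wrap constants (tensors, types, scalars, strings and int64 tuples) into value nodes so
// they can be used directly as CNode inputs when parallel operators are generated.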
AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) {
  auto value_node = NewValueNode(value_ptr);
  MS_EXCEPTION_IF_NULL(value_node);
  return value_node->cast<AnfNodePtr>();
}

AnfNodePtr CreateInt32Tensor(int64_t value, bool int64_type) {
  mindspore::tensor::TensorPtr tensor_ptr;
  if (int64_type) {
    tensor_ptr = std::make_shared<tensor::Tensor>(value, kInt64);
  } else {
    tensor_ptr = std::make_shared<tensor::Tensor>(value, kInt32);
  }

  ValuePtr value_ptr = MakeValue(tensor_ptr);
  auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
  return anf_node_ptr;
}

AnfNodePtr CreateFP32Tensor(float value) {
  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(value, kFloat32);
  ValuePtr value_ptr = MakeValue(tensor_ptr);
  auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
  return anf_node_ptr;
}

AnfNodePtr CreateTypeInt(int64_t nbits) {
  ValuePtr value_ptr = MakeValue(std::make_shared<Int>(nbits));
  return ValuePtrToAnfNodePtr(value_ptr);
}

AnfNodePtr CreateTypeFloat(int64_t nbits) {
  ValuePtr value_ptr = MakeValue(std::make_shared<Float>(nbits));
  return ValuePtrToAnfNodePtr(value_ptr);
}

AnfNodePtr CreatInt64Imm(int64_t value) {
  ValuePtr value_ptr = MakeValue(std::make_shared<Int64Imm>(value));
  return ValuePtrToAnfNodePtr(value_ptr);
}

AnfNodePtr CreateFP32Imm(float value) {
  ValuePtr value_ptr = MakeValue(std::make_shared<FP32Imm>(value));
  return ValuePtrToAnfNodePtr(value_ptr);
}

AnfNodePtr CreateBoolImm(bool value) {
  ValuePtr value_ptr = MakeValue(std::make_shared<BoolImm>(value));
  return ValuePtrToAnfNodePtr(value_ptr);
}

AnfNodePtr CreateStringImm(std::string value) {
  ValuePtr value_ptr = MakeValue(std::make_shared<StringImm>(value));
  return ValuePtrToAnfNodePtr(value_ptr);
}

AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple) {
  std::vector<ValuePtr> value_list;
  (void)std::transform(tuple.begin(), tuple.end(), std::back_inserter(value_list),
                       [](const int64_t value) { return MakeValue(value); });
  ValueTuplePtr value_tuple_ptr = std::make_shared<ValueTuple>(value_list);
  return ValuePtrToAnfNodePtr(value_tuple_ptr);
}

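// Derives the instance-name prefix for newly generated nodes from the primitive of an existing cnode.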
std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (!prim) {
    MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr.";
  }
  std::string instance_name = prim->instance_name();
  return HashInstanceName(instance_name);
}

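// Hashes the instance name unless hashing is disabled by setting USING_HASH_NAME to a value other than "on".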
std::string HashInstanceName(const std::string &name) {
  auto using_hash_name = common::GetEnv(USING_HASH_NAME);
  std::string instance_name;
  if ((using_hash_name.empty()) || (using_hash_name == "on")) {
    instance_name = HashName(name);
  } else {
    instance_name = name;
  }
  return instance_name;
}

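// Inserts a _VirtualPipelineEnd node between cnode and its index-th input, copying the input's abstract and
// MICRO primal attr, and tagging the segment index when fold-pipeline is enabled.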
void InsertVirtualPipelineEndNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, size_t index,
                                  std::string end_flag) {
  auto pre_cnode = cnode->input(index)->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_cnode);
  auto graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  OperatorAttrs attrs_;
  auto op = CreateOpInstance(attrs_, "_VirtualPipelineEnd", "end_node");
  auto value_node = NewValueNode(op);
  auto virtual_end = graph->NewCNode({value_node, pre_cnode});
  virtual_end->set_abstract(pre_cnode->abstract());
  virtual_end->AddPrimalAttr(end_flag, pre_cnode->GetPrimalAttr(MICRO));
  virtual_end->AddPrimalAttr(MICRO, pre_cnode->GetPrimalAttr(MICRO));
  manager->SetEdge(cnode, SizeToInt(index), virtual_end);
  if (ParallelContext::GetInstance()->enable_fold_pipeline()) {
    auto seg = ParallelContext::GetInstance()->pipeline_segment_split_num();
    virtual_end->AddPrimalAttr(SEGMENT, MakeValue(seg - 1));
  }
}

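// Creates a _VirtualConverterBegin node on top of input_cnode with the requested number of outputs.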
CNodePtr CreateVirtualConverterBeginNode(const CNodePtr &input_cnode, size_t output_nums) {
  auto graph = input_cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  Attr output_nums_attr = {"output_nums", MakeValue(output_nums)};
  OperatorAttrs attrs_ = {output_nums_attr};
  auto op = CreateOpInstance(attrs_, "_VirtualConverterBegin", "virtual_converter_begin");
  auto value_node = NewValueNode(op);
  auto virtual_begin = graph->NewCNode({value_node, input_cnode});
  return virtual_begin;
}

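// Creates a _VirtualConverterEnd node that gathers all input_cnodes as its inputs, with input_nums set accordingly.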
CNodePtr CreateVirtualConverterEndNode(const FuncGraphPtr &graph, const std::vector<CNodePtr> &input_cnodes) {
  if (input_cnodes.empty()) {
    MS_LOG(EXCEPTION) << "The input cnodes for _VirtualConverterEnd are empty.";
  }
  Attr input_nums_attr = {"input_nums", MakeValue(input_cnodes.size())};
  OperatorAttrs attrs_ = {input_nums_attr};
  auto op = CreateOpInstance(attrs_, "_VirtualConverterEnd", "virtual_converter_End");
  auto value_node = NewValueNode(op);
  std::vector<AnfNodePtr> virtual_end_input = {value_node};
  std::copy(input_cnodes.begin(), input_cnodes.end(), std::back_inserter(virtual_end_input));
  auto virtual_end = graph->NewCNode(virtual_end_input);
  return virtual_end;
}

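// Binds the generator to a target cnode: caches its graph, manager and scope, prepares a placeholder virtual
// input node, and derives the instance-name prefix used for every operator created afterwards.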
Status GenerateGraph::Init(const CNodePtr &cnode) {
  if (!cnode) {
    MS_LOG(ERROR) << "Init:cnode is nullptr";
    return FAILED;
  }
  cnode_ = cnode;
  func_graph_ = cnode->func_graph();
  if (!func_graph_) {
    MS_LOG(ERROR) << "Init:func_graph_ is nullptr";
    return FAILED;
  }
  manager_ = func_graph_->manager();
  if (!manager_) {
    MS_LOG(ERROR) << "Init:manager_ is nullptr";
    return FAILED;
  }
  scope_ = cnode_->scope();
  if (!scope_) {
    MS_LOG(ERROR) << "Init:scope_ is nullptr";
    return FAILED;
  }
  virtual_input_node_ = std::make_shared<AnfNode>(nullptr);
  virtual_input_node_->set_scope(scope_);
  instance_name_base_ = GetInstanceNameByCNode(cnode_);
  name_idx_ = 0;
  return SUCCESS;
}

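// Creates a CNode from the given inputs (rectifying attrs into inputs for new-defined ops), copies the
// user-specified attrs onto its primitive, and places it in the cached scope.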
AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
  auto new_inputs = RectifyInputsForNewCNode(inputs);
  for (auto &input : new_inputs) {
    MS_EXCEPTION_IF_NULL(input);  // If an error is raised here, check whether the inputs need to include the attrs.
  }
  CNodePtr cnode = func_graph_->NewCNode(new_inputs);  // Use NewCNode to create the anfnode.
  MS_EXCEPTION_IF_NULL(cnode);
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  SetUserAttrs(origin_attrs_, prim);
  cnode->set_scope(scope_);
  auto new_anf_node_ptr = cnode->cast<AnfNodePtr>();
  MS_EXCEPTION_IF_NULL(new_anf_node_ptr);
  return new_anf_node_ptr;
}

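// NewOpInst builds a value node holding a primitive (or Python op instance) for op_name; the overload without
// attrs uses an empty attribute list. A per-generator counter keeps successive instance names distinct.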
AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) {
  name_idx_++;
  ValuePtr op_prim_instance =
    CreateOpPrimtiveWithAttrs(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
  if (op_prim_instance == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:" << op_name << " NewOpInst failed";
  }
  auto value_node = NewValueNode(op_prim_instance);
  return value_node->cast<AnfNodePtr>();
}

AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) {
  name_idx_++;
  OperatorAttrs attrs;
  ValuePtr op_prim_instance =
    CreateOpPrimtiveWithAttrs(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
  if (op_prim_instance == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:" << op_name << " NewOpInst failed";
  }
  auto value_node = NewValueNode(op_prim_instance);
  return value_node->cast<AnfNodePtr>();
}
}  // namespace parallel
}  // namespace mindspore