/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/graph_util/get_parallel_info.h"

#include <memory>
#include <sstream>  // std::ostringstream, used below for name formatting
#include <string>
#include <vector>
#include <tuple>
#include <unordered_map>

#include "ir/func_graph.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/tensor_layout/shared_parameter.h"

namespace mindspore {
namespace parallel {
namespace {
constexpr char INPUTS[] = "inputs";
constexpr char ATTRS[] = "attrs";
using FuncGraphNameMap = const std::unordered_map<FuncGraphPtr, std::string>;
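// Per-operator occurrence counters and the cnode -> display-name cache; both are
// reset for each sub-graph by GetParallelCNodeInfoFromSubGraph before traversal.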
static std::unordered_map<std::string, size_t> op_count = {};
static std::unordered_map<CNodePtr, std::string> name_map = {};

// Extract the op name and the topological occurrence count of same-named nodes in the graph,
// e.g., Default/Mul-op32 -> Mul-op0, Default/Mul-op35 -> Mul-op1
std::string GetNodeNameWithCount(const CNodePtr &cnode) {
  if (name_map.find(cnode) != name_map.end()) {
    return name_map[cnode];
  }

  std::string node_name;
  auto is_call_fullname_with_scope = [](const CNodePtr &cnode) {
    auto value_ptr = cnode->input(0)->cast<ValueNodePtr>();
    ValuePtr input_value = nullptr;
    if (value_ptr != nullptr) {
      input_value = value_ptr->value();
    }
    if (input_value != nullptr && input_value->cast<PrimitivePtr>() == nullptr &&
        input_value->cast<FuncGraphPtr>() == nullptr) {
      return false;
    }
    return true;
  };
  if (is_call_fullname_with_scope(cnode)) {
    auto node_name_with_scope = cnode->fullname_with_scope();
    size_t left = node_name_with_scope.rfind('/');
    size_t right = node_name_with_scope.find("-op");
    node_name = node_name_with_scope.substr(left + 1, right - left - 1);
  } else {
    node_name = cnode->ToString();
  }

  std::ostringstream oss;
  oss << node_name << '-' << op_count[node_name];
  name_map[cnode] = oss.str();
  ++op_count[node_name];
  return name_map[cnode];
}

// Renames sub-graphs according to the topological order, e.g., @5_construct.395 -> @graph_0
FuncGraphNameMap GetAllFuncGraphNameMap(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto anf_nodes = TopoSort(graph->get_return(), SuccDeeperSimple, AlwaysInclude);
  std::unordered_map<FuncGraphPtr, std::string> graph_name_map;
  size_t graph_count = 0;
  for (const auto &anf_node : anf_nodes) {
    auto belong_graph = anf_node->func_graph();
    if (belong_graph == nullptr) {
      continue;
    }
    if (graph_name_map.find(belong_graph) == graph_name_map.end()) {
      std::ostringstream oss;
      oss << "@graph_" << graph_count++;
      graph_name_map[belong_graph] = oss.str();
    }
  }
  return graph_name_map;
}

// Extract the operator name from a cnode; calls to sub-graphs are rendered as "call @graph_N"
std::string GetCNodeOperatorNameWithCount(const CNodePtr &cnode, const FuncGraphNameMap &func_name_map) {
  AnfNodePtr op = cnode->input(0);
  MS_EXCEPTION_IF_NULL(op);
  std::string op_name;
  if (IsValueNode<FuncGraph>(op)) {
    const FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(op);
    op_name = "call " + func_name_map.at(fg);
  } else {
    op_name = GetNodeNameWithCount(cnode);
    name_map[cnode] = op_name;
  }
  return op_name;
}

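// Convert an IntegerImm value to a py::int_, dispatching on the concrete integer
// type id so that signedness and width are preserved.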
py::int_ GetPyIntValueFromIntegerImm(const ValuePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  if (!value_node->isa<IntegerImm>()) {
    MS_LOG(EXCEPTION) << "value_node is not IntegerImm";
  }

  TypePtr data_type = value_node->type();
  MS_EXCEPTION_IF_NULL(data_type);
  TypeId type_id = data_type->type_id();
  switch (type_id) {
    case kNumberTypeInt8:
      return py::int_(GetValue<int8_t>(value_node));
    case kNumberTypeInt16:
      return py::int_(GetValue<int16_t>(value_node));
    case kNumberTypeInt32:
      return py::int_(GetValue<int32_t>(value_node));
    case kNumberTypeInt64:
      return py::int_(GetValue<int64_t>(value_node));
    case kNumberTypeUInt8:
      return py::int_(GetValue<uint8_t>(value_node));
    case kNumberTypeUInt16:
      return py::int_(GetValue<uint16_t>(value_node));
    case kNumberTypeUInt32:
      return py::int_(GetValue<uint32_t>(value_node));
    case kNumberTypeUInt64:
      return py::int_(GetValue<uint64_t>(value_node));
    default:
      MS_LOG(EXCEPTION) << "The data type: " << data_type << " is invalid.";
  }
}

// Extract the list of operand names/values from a cnode
py::list GetCNodeOperandNameList(const CNodePtr &cnode, const FuncGraphNameMap &func_name_map) {
  MS_EXCEPTION_IF_NULL(cnode);

  py::list cnode_inputs_name_list;
  auto cnode_inputs = cnode->inputs();

  // Skip cnode_inputs[0], which is the primitive or sub-graph being called
  for (size_t i = 1; i < cnode_inputs.size(); ++i) {
    const AnfNodePtr &input = cnode_inputs[i];
    MS_EXCEPTION_IF_NULL(input);

    if (input->isa<Parameter>()) {
      cnode_inputs_name_list.append(py::str(std::static_pointer_cast<Parameter>(input)->name()));
    } else if (IsValueNode<FuncGraph>(input)) {
      FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(input);
      cnode_inputs_name_list.append(func_name_map.at(fg));
    } else if (input->isa<CNode>()) {
      cnode_inputs_name_list.append(py::str(GetNodeNameWithCount(input->cast<CNodePtr>())));
    } else if (input->isa<ValueNode>()) {
      auto value_node = GetValueNode(input);
      if (value_node->isa<IntegerImm>()) {
        cnode_inputs_name_list.append(GetPyIntValueFromIntegerImm(value_node));
      } else if (value_node->isa<FP32Imm>()) {
        cnode_inputs_name_list.append(GetValue<float>(value_node));
      } else if (value_node->isa<FP64Imm>()) {
        cnode_inputs_name_list.append(GetValue<double>(value_node));
      } else if (value_node->isa<BoolImm>()) {
        cnode_inputs_name_list.append(GetValue<bool>(value_node));
      } else if (value_node->isa<StringImm>()) {
        cnode_inputs_name_list.append(py::str(GetValue<std::string>(value_node)));
      } else {
        cnode_inputs_name_list.append(py::str(value_node->ToString()));
      }
    } else {
      cnode_inputs_name_list.append(py::str(input->ToString()));
    }
  }
  return cnode_inputs_name_list;
}

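// Collect the attributes of the cnode's primitive into a Python dict, converting scalar
// attribute values to native Python types; other values fall back to their string form.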
py::dict GetCNodeAttrs(const CNodePtr &cnode) {
  AnfNodePtr op = cnode->input(0);
  if (op == nullptr || !IsValueNode<Primitive>(op)) {
    return py::dict();
  }

  PrimitivePtr primitive = GetValueNode<PrimitivePtr>(op);
  auto attrs = primitive->attrs();
  py::dict cnode_attrs_dict;
  for (const auto &attr : attrs) {
    auto key = attr.first;
    auto value = attr.second;
    if (value->isa<BoolImm>()) {
      cnode_attrs_dict[py::str(key)] = GetValue<bool>(value);
    } else if (value->isa<IntegerImm>()) {
      cnode_attrs_dict[py::str(key)] = GetPyIntValueFromIntegerImm(value);
    } else if (value->isa<FP32Imm>()) {
      cnode_attrs_dict[py::str(key)] = GetValue<float>(value);
    } else if (value->isa<FP64Imm>()) {
      cnode_attrs_dict[py::str(key)] = GetValue<double>(value);
    } else {
      cnode_attrs_dict[py::str(key)] = py::str(value->ToString());
    }
  }
  return cnode_attrs_dict;
}

// Get the cnode info dict of a sub-graph
py::dict GetParallelCNodeInfoFromSubGraph(const FuncGraphPtr &sub_graph, const FuncGraphNameMap &func_name_map) {
  MS_EXCEPTION_IF_NULL(sub_graph);
  op_count.clear();
  name_map.clear();

  py::dict cnode_info_dict;
  auto cnodes = sub_graph->GetOrderedCnodes();
  for (auto cnode = cnodes.cbegin(); cnode != cnodes.cend(); ++cnode) {
    std::string op_name_with_count = GetCNodeOperatorNameWithCount(*cnode, func_name_map);
    py::dict cnode_info;
    cnode_info[INPUTS] = GetCNodeOperandNameList(*cnode, func_name_map);
    cnode_info[ATTRS] = GetCNodeAttrs(*cnode);
    cnode_info_dict[py::str(op_name_with_count)] = cnode_info;
  }
  return cnode_info_dict;
}

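// Return {is_pipeline_shared, is_send, peer_rank, sr_tag} for a parameter shared across
// pipeline stages; all fields keep their defaults when no SharedParameter user data is attached.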
std::tuple<bool, bool, int64_t, int64_t> GetSharedParameterInfo(const AnfNodePtr &param) {
  MS_EXCEPTION_IF_NULL(param);
  bool is_pipeline_shared = false;
  bool is_send = false;
  int64_t peer_rank = 0;
  int64_t sr_tag = 0;

  auto shared_params = param->user_data<parallel::SharedParameter>();
  if (shared_params) {
    is_pipeline_shared = shared_params->pipeline_shared();
    is_send = shared_params->is_send();
    peer_rank = shared_params->peer_rank();
    sr_tag = shared_params->sr_tag();
  }
  return std::tuple(is_pipeline_shared, is_send, peer_rank, sr_tag);
}
}  // namespace

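// Map every parameter name (and the names of its cloned parameters) to its layout tuple:
// (device_arrangement, tensor_map, slice_shape, field_size, uniform_split, opt_shard_group,
//  before_full_shape, after_slice_shape, is_pipeline_shared, is_send, peer_rank, sr_tag).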
py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  py::dict dict;
  std::vector<AnfNodePtr> graph_params = graph->parameters();

  for (const auto &para : graph_params) {
    auto param_ptr = para->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    std::vector<std::string> names = {param_ptr->name()};
    auto param_info = param_ptr->param_info();
    if (param_info) {
      auto cloned_obj = GetPyParameterObj(param_info, CLONED_OBJ);
      if (!py::isinstance<py::none>(cloned_obj) && py::isinstance<py::list>(cloned_obj)) {
        auto obj_list = py::cast<py::list>(cloned_obj);
        for (size_t i = 0; i < obj_list.size(); ++i) {
          auto each_obj = obj_list[i];
          if (py::hasattr(each_obj, "name")) {
            auto name_obj = python_adapter::GetPyObjAttr(each_obj, "name");
            names.push_back(py::cast<std::string>(name_obj));
          }
        }
      }
    }
    auto tensor_layout = para->user_data<parallel::TensorLayout>();
    if (tensor_layout == nullptr) {
      MS_LOG(INFO) << "GetParameterLayout nullptr parameter: " << para->DebugString();
    } else {
      const auto &device_arrangement = tensor_layout->device_arrangement().array();
      const auto &tensor_map = tensor_layout->tensor_map().array();
      const auto &slice_shape = tensor_layout->base_slice_shape().array();
      int64_t field_size = tensor_layout->get_field_size();
      bool uniform_split = tensor_layout->uniform_split();
      const std::string &opt_shard_group = tensor_layout->opt_shard_group();
      auto [is_pipeline_shared, is_send, peer_rank, sr_tag] = GetSharedParameterInfo(para);
      const auto &before_full_shape = tensor_layout->tensor_shape_before().array();
      const auto &after_slice_shape = tensor_layout->slice_shape().array();
      py::tuple layout = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
                                        opt_shard_group, before_full_shape, after_slice_shape,
                                        is_pipeline_shared, is_send, peer_rank, sr_tag);
      for (const auto &name : names) {
        dict[py::str(name)] = layout;
      }
      MS_LOG(INFO) << "GetParameterLayout parameter: " << para->DebugString() << ", layout "
                   << tensor_layout->ToString();
    }
  }
  return dict;
}

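// Build the same per-parameter layout dict from the layouts cached in the pipeline resource;
// before_full_shape and after_slice_shape are not cached there, so they are left empty.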
py::dict GetParameterLayoutFromResource(const pipeline::ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  py::dict dict;
  const auto &layout_map = resource->layout_map();
  for (auto iter = layout_map.begin(); iter != layout_map.end(); ++iter) {
    auto name = iter->first;
    auto layout = iter->second;
    const auto &device_arrangement = layout->get_device_arrangement();
    const auto &tensor_map = layout->get_tensor_map();
    const auto &slice_shape = layout->get_slice_shape();
    int64_t field_size = layout->get_field_size();
    bool uniform_split = layout->get_uniform_split();
    std::vector<int64_t> before_full_shape;
    std::vector<int64_t> after_slice_shape;
    const std::string &opt_shard_group = layout->get_opt_shard_group();
    bool is_pipeline_shared = layout->pipeline_shared();
    bool is_send = layout->is_send();
    int64_t peer_rank = layout->peer_rank();
    int64_t sr_tag = layout->sr_tag();
    py::tuple layout_tuple = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
                                            opt_shard_group, before_full_shape, after_slice_shape,
                                            is_pipeline_shared, is_send, peer_rank, sr_tag);
    dict[py::str(name)] = layout_tuple;
  }
  return dict;
}

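// Map each parameter name recorded on an AllReduce primitive to that primitive's fusion id,
// yielding e.g. a dict like {"fc1.weight": 1} (the name and id here are illustrative).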
py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  py::dict dict;
  auto allreduce_prim_list = FindPrimtive(graph, ALL_REDUCE);

  for (const auto &prim : allreduce_prim_list) {
    auto name_ptr = prim->GetAttr("parameter");
    auto fusion_ptr = prim->GetAttr("fusion");
    if (fusion_ptr == nullptr) {
      MS_LOG(EXCEPTION) << "fusion_ptr is nullptr";
    }
    if (name_ptr == nullptr) {
      continue;
    }
    if (!name_ptr->isa<StringImm>()) {
      MS_LOG(EXCEPTION) << "name is not StringImm";
    }
    auto name = name_ptr->cast<StringImmPtr>()->value();
    if (!fusion_ptr->isa<Int64Imm>()) {
      MS_LOG(EXCEPTION) << "fusion is not Int64Imm";
    }
    int64_t fusion = fusion_ptr->cast<Int64ImmPtr>()->value();
    dict[py::str(name)] = fusion;
  }
  return dict;
}

// In pipeline parallel mode, many parameters are unused on the local stage and need to be
// deleted; return the names of the parameters the graph still holds
py::list GetParallelParameterNameListFromGraph(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  py::list parallel_parameter_name_list;
  std::vector<AnfNodePtr> graph_params = graph->parameters();

  for (const auto &param : graph_params) {
    auto param_ptr = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    std::string name = param_ptr->name();
    parallel_parameter_name_list.append(name);
  }
  return parallel_parameter_name_list;
}

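// Same as above, but reads the parameter names from the layout map cached in the pipeline resource.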
py::list GetParallelParameterNameListFromResource(const pipeline::ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto &layout_map = resource->layout_map();
  py::list parallel_parameter_name_list;
  for (auto iter = layout_map.begin(); iter != layout_map.end(); ++iter) {
    auto name = iter->first;
    parallel_parameter_name_list.append(name);
  }
  return parallel_parameter_name_list;
}

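// Entry point: map each sub-graph name ("@graph_N") to a dict of its cnodes, where every
// cnode entry holds its "inputs" operand list and "attrs" dict.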
py::dict GetParallelCNodeInfoFromGraph(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Search and map all sub-graph names
  auto func_name_map = GetAllFuncGraphNameMap(graph);
  py::dict parallel_cnode_info_dict;

  // Get the cnode info dict of each sub-graph in turn
  for (const auto &kv : func_name_map) {
    auto sub_graph_cnode_info_dict = GetParallelCNodeInfoFromSubGraph(kv.first, func_name_map);
    parallel_cnode_info_dict[py::str(kv.second)] = sub_graph_cnode_info_dict;
  }
  op_count.clear();
  name_map.clear();
  return parallel_cnode_info_dict;
}
}  // namespace parallel
}  // namespace mindspore