/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/graph_util/get_parallel_info.h"

#include <memory>
#include <string>
#include <vector>
#include <tuple>
#include <unordered_map>

#include "ir/func_graph.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/tensor_layout/shared_parameter.h"

namespace mindspore {
namespace parallel {
namespace {
constexpr char INPUTS[] = "inputs";
constexpr char ATTRS[] = "attrs";
using FuncGraphNameMap = const std::unordered_map<FuncGraphPtr, std::string>;
static std::unordered_map<std::string, size_t> op_count = {};
static std::unordered_map<CNodePtr, std::string> name_map = {};

// Extract the op name and number nodes of the same op by their topological order in the graph,
// e.g., Default/Mul-op32 -> Mul-0, Default/Mul-op35 -> Mul-1
std::string GetNodeNameWithCount(const CNodePtr &cnode) {
  if (name_map.find(cnode) != name_map.end()) {
    return name_map[cnode];
  }

  std::string node_name;
  auto is_call_fullname_with_scope = [](const CNodePtr &cnode) {
    auto value_ptr = cnode->input(0)->cast<ValueNodePtr>();
    ValuePtr input_value = nullptr;
    if (value_ptr != nullptr) {
      input_value = value_ptr->value();
    }
    if (input_value != nullptr && input_value->cast<PrimitivePtr>() == nullptr &&
        input_value->cast<FuncGraphPtr>() == nullptr) {
      return false;
    }
    return true;
  };
  if (is_call_fullname_with_scope(cnode)) {
    auto node_name_with_scope = cnode->fullname_with_scope();
    size_t left = node_name_with_scope.rfind('/');
    size_t right = node_name_with_scope.find("-op");
    node_name = node_name_with_scope.substr(left + 1, right - left - 1);
  } else {
    node_name = cnode->ToString();
  }

  std::ostringstream oss;
  oss << node_name << '-' << op_count[node_name];
  name_map[cnode] = oss.str();
  ++op_count[node_name];
  return name_map[cnode];
}

// Renames sub-graphs according to topological order, e.g., @5_construct.395 -> @graph_0
FuncGraphNameMap GetAllFuncGraphNameMap(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto anf_nodes = TopoSort(graph->get_return(), SuccDeeperSimple, AlwaysInclude);
  std::unordered_map<FuncGraphPtr, std::string> graph_name_map;
  size_t graph_count = 0;
  for (const auto &anf_node : anf_nodes) {
    auto belong_graph = anf_node->func_graph();
    if (belong_graph == nullptr) {
      continue;
    }
    if (graph_name_map.find(belong_graph) == graph_name_map.end()) {
      std::ostringstream oss;
      oss << "@graph_" << graph_count++;
      graph_name_map[belong_graph] = oss.str();
    }
  }
  return graph_name_map;
}
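
// Example (illustrative): for a root graph whose return feeds one call to a sub-graph,
// the result is {root graph -> "@graph_0", sub-graph -> "@graph_1"}, numbered in the
// order the graphs are first met during the topological traversal.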

// Extract the operator name (with count) from a cnode; graph calls become "call @graph_N"
std::string GetCNodeOperatorNameWithCount(const CNodePtr &cnode, const FuncGraphNameMap &func_name_map) {
  AnfNodePtr op = cnode->input(0);
  MS_EXCEPTION_IF_NULL(op);
  std::string op_name;
  if (IsValueNode<FuncGraph>(op)) {
    const FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(op);
    op_name = "call " + func_name_map.at(fg);
  } else {
    op_name = GetNodeNameWithCount(cnode);
    name_map[cnode] = op_name;
  }
  return op_name;
}
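
// Example usage (illustrative): for a cnode that calls a sub-graph mapped to "@graph_1"
// this returns "call @graph_1"; for an ordinary primitive cnode it returns a counted
// name such as "Mul-0" produced by GetNodeNameWithCount above.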

// Convert an IntegerImm value of any signedness and width into a py::int_
py::int_ GetPyIntValueFromIntegerImm(const ValuePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  if (!value_node->isa<IntegerImm>()) {
    MS_LOG(EXCEPTION) << "value_node is not IntegerImm";
  }

  TypePtr data_type = value_node->type();
  MS_EXCEPTION_IF_NULL(data_type);
  TypeId type_id = data_type->type_id();
  switch (type_id) {
    case kNumberTypeInt8:
      return py::int_(GetValue<int8_t>(value_node));
    case kNumberTypeInt16:
      return py::int_(GetValue<int16_t>(value_node));
    case kNumberTypeInt32:
      return py::int_(GetValue<int32_t>(value_node));
    case kNumberTypeInt64:
      return py::int_(GetValue<int64_t>(value_node));
    case kNumberTypeUInt8:
      return py::int_(GetValue<uint8_t>(value_node));
    case kNumberTypeUInt16:
      return py::int_(GetValue<uint16_t>(value_node));
    case kNumberTypeUInt32:
      return py::int_(GetValue<uint32_t>(value_node));
    case kNumberTypeUInt64:
      return py::int_(GetValue<uint64_t>(value_node));
    default:
      MS_LOG(EXCEPTION) << "The data type: " << data_type << " is invalid.";
  }
}
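
// Example (a minimal sketch; MakeValue is the usual MindSpore helper for building a
// ValuePtr from a C++ scalar):
//   auto v = MakeValue(static_cast<int64_t>(42));      // Int64Imm
//   py::int_ i = GetPyIntValueFromIntegerImm(v);       // py::int_(42)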

// Extract the list of operand names from a cnode
py::list GetCNodeOperandNameList(const CNodePtr &cnode, const FuncGraphNameMap &func_name_map) {
  MS_EXCEPTION_IF_NULL(cnode);

  py::list cnode_inputs_name_list;
  auto cnode_inputs = cnode->inputs();

  // Skip cnode_inputs[0], which is the operator (a primitive or func-graph value node)
  for (size_t i = 1; i < cnode_inputs.size(); ++i) {
    const AnfNodePtr &input = cnode_inputs[i];
    MS_EXCEPTION_IF_NULL(input);

    if (input->isa<Parameter>()) {
      cnode_inputs_name_list.append(py::str(std::static_pointer_cast<Parameter>(input)->name()));
    } else if (IsValueNode<FuncGraph>(input)) {
      FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(input);
      cnode_inputs_name_list.append(func_name_map.at(fg));
    } else if (input->isa<CNode>()) {
      cnode_inputs_name_list.append(py::str(GetNodeNameWithCount(input->cast<CNodePtr>())));
    } else if (input->isa<ValueNode>()) {
      auto value_node = GetValueNode(input);
      if (value_node->isa<IntegerImm>()) {
        cnode_inputs_name_list.append(GetPyIntValueFromIntegerImm(value_node));
      } else if (value_node->isa<FP32Imm>()) {
        cnode_inputs_name_list.append(GetValue<float>(value_node));
      } else if (value_node->isa<FP64Imm>()) {
        cnode_inputs_name_list.append(GetValue<double>(value_node));
      } else if (value_node->isa<BoolImm>()) {
        cnode_inputs_name_list.append(GetValue<bool>(value_node));
      } else if (value_node->isa<StringImm>()) {
        cnode_inputs_name_list.append(py::str(GetValue<std::string>(value_node)));
      } else {
        cnode_inputs_name_list.append(py::str(value_node->ToString()));
      }
    } else {
      cnode_inputs_name_list.append(py::str(input->ToString()));
    }
  }
  return cnode_inputs_name_list;
}
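
// Example (illustrative): a cnode computing Mul(x, 2), where x is a Parameter, yields
// ["x", 2]; an operand that is itself a cnode appears by its counted name, e.g.
// "ReLU-0", and a sub-graph operand appears by its graph name, e.g. "@graph_1".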

py::dict GetCNodeAttrs(const CNodePtr &cnode) {
  AnfNodePtr op = cnode->input(0);
  if (op == nullptr || !IsValueNode<Primitive>(op)) {
    return py::dict();
  }

  PrimitivePtr primitive = GetValueNode<PrimitivePtr>(op);
  auto attrs = primitive->attrs();
  py::dict cnode_attrs_dict;
  for (const auto &attr : attrs) {
    auto key = attr.first;
    auto value = attr.second;
    if (value->isa<BoolImm>()) {
      cnode_attrs_dict[py::str(key)] = GetValue<bool>(value);
    } else if (value->isa<IntegerImm>()) {
      cnode_attrs_dict[py::str(key)] = GetPyIntValueFromIntegerImm(value);
    } else if (value->isa<FP32Imm>()) {
      cnode_attrs_dict[py::str(key)] = GetValue<float>(value);
    } else if (value->isa<FP64Imm>()) {
      cnode_attrs_dict[py::str(key)] = GetValue<double>(value);
    } else {
      cnode_attrs_dict[py::str(key)] = py::str(value->ToString());
    }
  }
  return cnode_attrs_dict;
}
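
// Example (illustrative): for a ReduceSum primitive with keep_dims=False the dict
// would contain {"keep_dims": False, ...}; attr values of other kinds fall back to
// their ToString() text.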

// Get the cnode info dict of one subgraph.
py::dict GetParallelCNodeInfoFromSubGraph(const FuncGraphPtr &sub_graph, const FuncGraphNameMap &func_name_map) {
  MS_EXCEPTION_IF_NULL(sub_graph);
  op_count.clear();
  name_map.clear();

  py::dict cnode_info_dict;
  auto cnodes = sub_graph->GetOrderedCnodes();
  for (auto cnode = cnodes.cbegin(); cnode != cnodes.cend(); ++cnode) {
    std::string op_name_with_count = GetCNodeOperatorNameWithCount(*cnode, func_name_map);
    py::dict cnode_info;
    cnode_info[INPUTS] = GetCNodeOperandNameList(*cnode, func_name_map);
    cnode_info[ATTRS] = GetCNodeAttrs(*cnode);
    cnode_info_dict[py::str(op_name_with_count)] = cnode_info;
  }
  return cnode_info_dict;
}
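
// Example shape of the result (illustrative):
//   {"Mul-0": {"inputs": ["x", 2], "attrs": {...}}, "call @graph_1": {"inputs": [...], "attrs": {}}}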

std::tuple<bool, bool, int64_t, int64_t> GetSharedParameterInfo(const AnfNodePtr &param) {
  MS_EXCEPTION_IF_NULL(param);
  bool is_pipeline_shared = false;
  bool is_send = false;
  int64_t peer_rank = 0;
  int64_t sr_tag = 0;

  auto shared_params = param->user_data<parallel::SharedParameter>();
  if (shared_params) {
    is_pipeline_shared = shared_params->pipeline_shared();
    is_send = shared_params->is_send();
    peer_rank = shared_params->peer_rank();
    sr_tag = shared_params->sr_tag();
  }
  return std::tuple(is_pipeline_shared, is_send, peer_rank, sr_tag);
}
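
// Callers typically unpack the tuple with structured bindings, as in
// GetParameterLayoutFromGraph below:
//   auto [is_pipeline_shared, is_send, peer_rank, sr_tag] = GetSharedParameterInfo(param);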
}  // namespace

py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  py::dict dict;
  std::vector<AnfNodePtr> graph_params = graph->parameters();

  for (auto para : graph_params) {
    auto param_ptr = para->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_ptr);
    std::vector<std::string> names = {param_ptr->name()};
    auto param_info = param_ptr->param_info();
    if (param_info) {
      // Cloned parameters share the layout of the original one, so collect their names too.
      auto cloned_obj = GetPyParameterObj(param_info, CLONED_OBJ);
      if (!py::isinstance<py::none>(cloned_obj) && py::isinstance<py::list>(cloned_obj)) {
        auto obj_list = py::cast<py::list>(cloned_obj);
        for (size_t i = 0; i < obj_list.size(); ++i) {
          auto each_obj = obj_list[i];
          if (py::hasattr(each_obj, "name")) {
            auto name_obj = python_adapter::GetPyObjAttr(each_obj, "name");
            names.push_back(py::cast<std::string>(name_obj));
          }
        }
      }
    }
    auto tensor_layout = para->user_data<parallel::TensorLayout>();
    if (tensor_layout == nullptr) {
      MS_LOG(INFO) << "GetParameterLayout: parameter has no tensor layout: " << para->DebugString();
    } else {
      const auto &device_arrangement = tensor_layout->device_arrangement().array();
      const auto &tensor_map = tensor_layout->tensor_map().array();
      const auto &slice_shape = tensor_layout->base_slice_shape().array();
      int64_t field_size = tensor_layout->get_field_size();
      bool uniform_split = tensor_layout->uniform_split();
      const std::string &opt_shard_group = tensor_layout->opt_shard_group();
      auto [is_pipeline_shared, is_send, peer_rank, sr_tag] = GetSharedParameterInfo(para);
      const auto &before_full_shape = tensor_layout->tensor_shape_before().array();
      const auto &after_slice_shape = tensor_layout->slice_shape().array();
      py::tuple layout = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
                                        opt_shard_group, before_full_shape, after_slice_shape,
                                        is_pipeline_shared, is_send, peer_rank, sr_tag);
      for (auto &name : names) {
        dict[py::str(name)] = layout;
      }
      MS_LOG(INFO) << "GetParameterLayout parameter: " << para->DebugString() << ", layout "
                   << tensor_layout->ToString();
    }
  }
  return dict;
}
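
// Layout tuple order, as consumed on the Python side (taken from the make_tuple above):
//   (device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
//    opt_shard_group, before_full_shape, after_slice_shape,
//    is_pipeline_shared, is_send, peer_rank, sr_tag)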

py::dict GetParameterLayoutFromResource(const pipeline::ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  py::dict dict;
  const auto &layout_map = resource->layout_map();
  for (auto iter = layout_map.begin(); iter != layout_map.end(); ++iter) {
    auto name = iter->first;
    auto layout = iter->second;
    const auto &device_arrangement = layout->get_device_arrangement();
    const auto &tensor_map = layout->get_tensor_map();
    const auto &slice_shape = layout->get_slice_shape();
    int64_t field_size = layout->get_field_size();
    bool uniform_split = layout->get_uniform_split();
    // The resource layout map does not carry these two shapes, so they are left empty here.
    std::vector<int64_t> before_full_shape;
    std::vector<int64_t> after_slice_shape;
    const std::string &opt_shard_group = layout->get_opt_shard_group();
    bool is_pipeline_shared = layout->pipeline_shared();
    bool is_send = layout->is_send();
    int64_t peer_rank = layout->peer_rank();
    int64_t sr_tag = layout->sr_tag();
    py::tuple layout_tuple = py::make_tuple(device_arrangement, tensor_map, slice_shape, field_size, uniform_split,
                                            opt_shard_group, before_full_shape, after_slice_shape,
                                            is_pipeline_shared, is_send, peer_rank, sr_tag);
    dict[py::str(name)] = layout_tuple;
  }
  return dict;
}

py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  py::dict dict;
  auto allreduce_prim_list = FindPrimtive(graph, ALL_REDUCE);

  for (auto prim : allreduce_prim_list) {
    auto name_ptr = prim->GetAttr("parameter");
    auto fusion_ptr = prim->GetAttr("fusion");
    if (fusion_ptr == nullptr) {
      MS_LOG(EXCEPTION) << "fusion_ptr is nullptr";
    } else if (name_ptr == nullptr) {
      continue;
    }
    if (!name_ptr->isa<StringImm>()) {
      MS_LOG(EXCEPTION) << "name is not StringImm";
    }
    auto name = name_ptr->cast<StringImmPtr>()->value();
    if (!fusion_ptr->isa<Int64Imm>()) {
      MS_LOG(EXCEPTION) << "fusion is not Int64Imm";
    }
    int64_t fusion = fusion_ptr->cast<Int64ImmPtr>()->value();
    dict[py::str(name)] = fusion;
  }
  return dict;
}
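
// Example (illustrative): if the AllReduce applied to the gradient of "conv1.weight"
// carries fusion attr 2, the result contains {"conv1.weight": 2}, i.e. a map from
// parameter name to fusion group id.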

// In pipeline parallel mode, many parameters are not used on the current stage and need to be
// removed; this returns the names of the parameters the graph actually keeps.
py::list GetParallelParameterNameListFromGraph(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  py::list parallel_parameter_name_list;
  std::vector<AnfNodePtr> graph_params = graph->parameters();

  for (auto param : graph_params) {
    auto param_ptr = std::static_pointer_cast<Parameter>(param);
    MS_EXCEPTION_IF_NULL(param_ptr);
    std::string name = param_ptr->name();
    parallel_parameter_name_list.append(name);
  }
  return parallel_parameter_name_list;
}

py::list GetParallelParameterNameListFromResource(const pipeline::ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto &layout_map = resource->layout_map();
  py::list parallel_parameter_name_list;
  for (auto iter = layout_map.begin(); iter != layout_map.end(); ++iter) {
    auto name = iter->first;
    parallel_parameter_name_list.append(name);
  }
  return parallel_parameter_name_list;
}

py::dict GetParallelCNodeInfoFromGraph(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Discover all subgraphs and assign each a topological name
  auto func_name_map = GetAllFuncGraphNameMap(graph);
  py::dict parallel_cnode_info_dict;

  // Collect the cnode info dict of each subgraph in turn
  for (const auto &kv : func_name_map) {
    auto sub_graph_cnode_info_dict = GetParallelCNodeInfoFromSubGraph(kv.first, func_name_map);
    parallel_cnode_info_dict[py::str(kv.second)] = sub_graph_cnode_info_dict;
  }
  op_count.clear();
  name_map.clear();
  return parallel_cnode_info_dict;
}
}  // namespace parallel
}  // namespace mindspore