/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/allreduce_fusion/allreduce_fusion.h"
#include <memory>
#include <queue>
#include <string>
#include <functional>
#include <numeric>  // for std::accumulate used below
#include <utility>
#include <vector>
#include <unordered_map>
#include "mindspore/core/ops/other_ops.h"
#include "utils/hash_set.h"
#include "ir/func_graph.h"
#include "frontend/parallel/costmodel_context.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
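// Records the fusion group of a mirror/communication node by attaching the FUSION id and the
// owning parameter name to the node's primitive, so later passes can fuse ops sharing an id.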
void SetMirrorFusion(const CNodePtr &mirror_cnode, int64_t fusion, const std::string &parameter_name) {
  MS_EXCEPTION_IF_NULL(mirror_cnode);
  MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion;
  auto node_prim = GetValueNode<PrimitivePtr>(mirror_cnode->input(0));
  (void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared<Int64Imm>(fusion)));
  (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared<StringImm>(parameter_name)));
}

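// Walks the whole graph and, for every node whose kRelatedCommNodeId points at a communication
// node that just received a FUSION id, rewrites the trailing "_<id>" part of its kRelatedFusionKey
// primal attribute so the key stays consistent with the newly assigned fusion id.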
void AdjustRelatedFusionNode(const CNodePtr &ret, const std::unordered_map<std::string, CNodePtr> &comm_node_map) {
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  for (const auto &related_node : all_nodes) {
    if (!IsPrimitiveCNode(related_node)) {
      continue;
    }
    auto related_cnode = related_node->cast<CNodePtr>();
    if (!related_cnode->HasAttr(kRelatedCommNodeId)) {
      continue;
    }
    auto related_comm_node_id = GetValue<std::string>(related_cnode->GetAttr(kRelatedCommNodeId));
    if (comm_node_map.find(related_comm_node_id) == comm_node_map.end()) {
      continue;
    }
    auto comm_cnode = comm_node_map.at(related_comm_node_id);
    if (!IsPrimitiveCNode(comm_cnode)) {
      continue;
    }
    auto node_prim = GetValueNode<PrimitivePtr>(comm_cnode->input(0));
    if (!node_prim->HasAttr(FUSION)) {
      continue;
    }
    if (!related_cnode->HasPrimalAttr(kRelatedNodeId) || !related_cnode->HasPrimalAttr(kRelatedFusionKey)) {
      continue;
    }
    auto related_fusion_key = GetValue<std::string>(related_cnode->GetPrimalAttr(kRelatedFusionKey));
    auto fusion_id_pos = related_fusion_key.rfind("_");
    if (fusion_id_pos != std::string::npos) {
      auto sub_str = related_fusion_key.substr(0, fusion_id_pos);
      auto auto_fusion_id = GetValue<int64_t>(node_prim->GetAttr(FUSION));
      auto new_related_fusion_key = sub_str + "_" + std::to_string(auto_fusion_id);
      MS_LOG(INFO) << "replace related fusion key to: " << new_related_fusion_key;
      related_cnode->AddPrimalAttr(kRelatedFusionKey, MakeValue<std::string>(new_related_fusion_key));
    }
  }
}

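// Assigns fusion ids to Mirror/AllGather communication nodes: traverses the graph, measures the
// element count of each node's first input, and packs nodes into buckets of at most `threshold`
// units; whenever the running budget is exhausted a new fusion id is started.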
Status AllCommFusion::SetFusionBySize(const CNodePtr &ret, int64_t threshold, const PrimitivePtr &primp) const {
  auto filter = [primp](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, primp); };
  auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
  auto temp = threshold;
  int64_t fusion = 1;
  bool init = true;
  std::string parameter_name;
  std::string name;
  std::unordered_map<std::string, CNodePtr> comm_node_map;
  for (auto &node : todo) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode->input(1)->Shape() == nullptr) {
      continue;
    }
    auto input_shapes = GetNodeShape(cnode->input(1));
    int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
    FuncGraphPtr func_graph = cnode->func_graph();
    if (IsPrimitiveEquals(primp, prim::kPrimMirror)) {
      name = ALL_REDUCE;
      std::pair<AnfNodePtr, bool> param_node_pair = FindParameter(cnode->input(1), func_graph);
      if (!param_node_pair.first) {
        continue;
      }
      parameter_name = ParameterName(param_node_pair.first);
    }

    if (IsPrimitiveEquals(primp, prim::kPrimMicroStepAllGather) || IsPrimitiveEquals(primp, prim::kPrimAllGather)) {
      name = ALL_GATHER;
      if (!cnode->input(0) || !cnode->input(1)) {
        continue;
      }
      PrimitivePtr primp1 = GetValueNode<PrimitivePtr>(cnode->input(0));
      if (!primp1->HasAttr(RECOMPUTE) || GetValue<bool>(primp1->GetAttr(RECOMPUTE))) {
        continue;
      }
      std::pair<AnfNodePtr, bool> param_node_pair = FindParameterWithAllgather(cnode->input(1), func_graph, name);
      if (!param_node_pair.first) {
        continue;
      }
      parameter_name = ParameterName(param_node_pair.first);
    }

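    // Spend the remaining budget on this node; once the budget would be exceeded,
    // reset it and move on to the next fusion id.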
    if (init || input_size < temp) {
      temp -= input_size;
      init = false;
    } else {
      temp = threshold;
      fusion++;
    }
    SetMirrorFusion(cnode, fusion, parameter_name);
    comm_node_map[cnode->UniqueId()] = cnode;
  }
  AdjustRelatedFusionNode(ret, comm_node_map);
  MS_LOG(INFO) << name << " fusion by size succeed.";
  return SUCCESS;
}

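// Same size-based bucketing for ReduceScatter: traverses the matching nodes (VirtualAssignAdd),
// looks up the parameter behind each one, and groups them by accumulated parameter size.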
Status AllCommFusion::SetFusionBySizeReduceScatter(const CNodePtr &ret, int64_t threshold,
                                                   const PrimitivePtr &primp) const {
  auto filter = [primp](const AnfNodePtr &node) { return !IsPrimitiveCNode(node, primp); };
  auto todo = DeepScopedGraphSearchWithFilter(ret, AlwaysInclude, filter);
  auto temp = threshold;
  int64_t fusion = 1;
  bool init = true;
  std::unordered_map<std::string, CNodePtr> comm_node_map;
  for (auto &node : todo) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode->input(1) == nullptr) {
      continue;
    }
    FuncGraphPtr func_graph = cnode->func_graph();
    std::pair<AnfNodePtr, bool> param_node_pair =
      FindParameterWithAllgather(cnode->input(1), func_graph, REDUCE_SCATTER);
    if (!param_node_pair.first) {
      continue;
    }
    auto parameter_name = ParameterName(param_node_pair.first);
    auto input_shapes = GetNodeShape(param_node_pair.first);
    int64_t input_size = std::accumulate(input_shapes[0].begin(), input_shapes[0].end(), 1, std::multiplies<int64_t>());
    if (init || input_size < temp) {
      temp -= input_size;
      init = false;
    } else {
      temp = threshold;
      fusion++;
    }
    SetMirrorFusion(cnode, fusion, parameter_name);
    comm_node_map[cnode->UniqueId()] = cnode;
  }
  AdjustRelatedFusionNode(ret, comm_node_map);
  MS_LOG(INFO) << "Reduce_Scatter fusion by size succeed.";
  return SUCCESS;
}

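// Entry point of the pass: resolves the forward graph, reads the fusion threshold configured for
// the given communication op (ALL_REDUCE, ALL_GATHER or REDUCE_SCATTER), and runs the matching
// size-based fusion routine.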
Status AllCommFusion::ProcessCommOpsFusion(const CNodePtr &ret, const std::string &comm_name) {
  if (ret == nullptr) {
    MS_LOG(ERROR) << "ret is nullptr.";
    return FAILED;
  }
  ret_ = ret;
  root_graph_ = ret_->func_graph();
  MS_EXCEPTION_IF_NULL(root_graph_);
  auto graph_set = ForwardGraph(root_graph_);
  if (graph_set.size() > 1) {
    MS_LOG(INFO) << comm_name << " fusion doesn't support multiple subgraphs now.";
    return SUCCESS;
  }
  auto forward_graph = *(graph_set.begin());
  MS_EXCEPTION_IF_NULL(forward_graph);
  forward_ret_ = forward_graph->get_return();
  MS_EXCEPTION_IF_NULL(forward_ret_);
  if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) {
    MS_LOG(ERROR) << comm_name << " Graph set_head_cnode failed.";
    return FAILED;
  }
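  // Pick the fusion threshold (in MB) that the parallel context holds for this kind of comm op.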
  int64_t threshold = 0;
  if (comm_name == ALL_REDUCE) {
    threshold = ParallelContext::GetInstance()->fusion_threshold_mb();
  } else if (comm_name == ALL_GATHER) {
    threshold = ParallelContext::GetInstance()->allgather_fusion_threshold_mb();
  } else if (comm_name == REDUCE_SCATTER) {
    threshold = ParallelContext::GetInstance()->reducescatter_fusion_threshold_mb();
  } else {
    MS_LOG(ERROR) << " Comm Ops must be ALL_REDUCE, ALL_GATHER or REDUCE_SCATTER, but got " << comm_name;
  }
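  // The configured value is in MB; convert it to bytes before comparing against tensor sizes.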
  threshold *= DEFAULT_THRESHOLD_MB_TO_BYTE;
  if (threshold <= 0) {
    MS_LOG(ERROR) << "The threshold of " << comm_name << " fusion must be larger than 0, but got " << threshold << ".";
    return FAILED;
  }
  if (comm_name == REDUCE_SCATTER) {
    (void)SetFusionBySizeReduceScatter(ret, threshold, prim::kPrimVirtualAssignAdd);
  }
  if (comm_name == ALL_REDUCE) {
    (void)SetFusionBySize(ret, threshold, prim::kPrimMirror);
  }
  if (comm_name == ALL_GATHER) {
    (void)SetFusionBySize(ret, threshold, prim::kPrimMicroStepAllGather);
    (void)SetFusionBySize(ret, threshold, prim::kPrimAllGather);
  }

  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore