• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/pass/matmul_add_comm_reduction.h"
18 #include <memory>
19 #include <list>
20 #include <vector>
21 #include <string>
22 #include <utility>
23 #include "include/common/utils/utils.h"
24 #include "frontend/optimizer/optimizer.h"
25 #include "frontend/parallel/step_parallel.h"
26 #include "frontend/parallel/step_parallel_utils.h"
27 #include "mindspore/core/ops/math_ops.h"
28 #include "mindspore/core/ops/other_ops.h"
29 
namespace mindspore {
namespace parallel {
namespace {
// Both inputs of the Add must carry a comm op for the fusion to pay off
// (checked against the pull-down list size in HandleAddNode).
constexpr size_t kCommReductionValidCommOpsNum = 2;
// Attribute expected on the MatMul/BatchMatMul feeding the comm op (pattern start).
constexpr auto MATMUL_ADD_COMM_BEGIN = "matmul_add_comm_begin";
// Attribute expected on the Add node that terminates the pattern.
constexpr auto MATMUL_ADD_COMM_END = "matmul_add_comm_end";
// Attribute stamped on the Mul node inserted to rescale the bias parameter.
constexpr auto MATMUL_ADD_COMM_MUL = "matmul_add_comm_mul";
// Attribute stamped on the comm node re-created below the Add after pull-down.
constexpr const char MATMUL_ADD_COMM_REDUCTION[] = "matmul_add_comm_reduction";
38 
IsSubRankList(const RankList & child_list,const RankList & parent_list)39 bool IsSubRankList(const RankList &child_list, const RankList &parent_list) {
40   for (auto &child : child_list) {
41     if (std::find(parent_list.begin(), parent_list.end(), child) == parent_list.end()) {
42       return false;
43     }
44   }
45   return true;
46 }
47 
IsPrimitiveAttrValid(const PrimitivePtr & prim,const std::string & attr_name)48 bool IsPrimitiveAttrValid(const PrimitivePtr &prim, const std::string &attr_name) {
49   MS_EXCEPTION_IF_NULL(prim);
50   return !prim->HasAttr(attr_name) || !GetValue<bool>(prim->GetAttr(attr_name));
51 }
52 
IsAddNodeValid(const AnfNodePtr & add_node,const AnfNodePtr & comm_node)53 bool IsAddNodeValid(const AnfNodePtr &add_node, const AnfNodePtr &comm_node) {
54   OperatorInfoPtr add_distribute_operator = add_node->user_data<OperatorInfo>();
55   if (add_distribute_operator == nullptr) {
56     return false;
57   }
58   TensorInfo node_add_tensor_in = add_distribute_operator->inputs_tensor_info()[LongToSize(1)];
59   TensorLayout node_add_tensor_layout = node_add_tensor_in.tensor_layout();
60   const auto node_add_rank_list = node_add_tensor_layout.InferRepeatedGroup();
61 
62   auto comm_prim = GetCNodePrimitive(comm_node);
63   if (!comm_prim->HasAttr(GROUP)) {
64     return false;
65   }
66   auto comm_group = GetValue<std::string>(comm_prim->GetAttr(GROUP));
67   MS_EXCEPTION_IF_NULL(g_device_manager);
68   auto comm_rank_list = g_device_manager->FindRankListByHashName(comm_group);
69   return IsSubRankList(comm_rank_list, node_add_rank_list);
70 }
71 
IsPrimitiveLinear(const AnfNodePtr & anode)72 bool IsPrimitiveLinear(const AnfNodePtr &anode) {
73   MS_EXCEPTION_IF_NULL(anode);
74   if (IsPrimitiveCNode(anode, prim::kPrimReduceAll) || IsPrimitiveCNode(anode, prim::kPrimReduceAny) ||
75       IsPrimitiveCNode(anode, prim::kPrimReduceMean) || IsPrimitiveCNode(anode, prim::kPrimReduceMax) ||
76       IsPrimitiveCNode(anode, prim::kPrimReduceMin) || IsPrimitiveCNode(anode, prim::kPrimReduceProd) ||
77       IsPrimitiveCNode(anode, prim::kPrimReduceSum) || IsPrimitiveCNode(anode, prim::kPrimSquareSumV1)) {
78     return false;
79   }
80   return true;
81 }
82 
// Walks upward from anode (via GetInputNodeWithFilter) to find the node that
// may be pulled down past the Add. The filter keeps climbing while the current
// primitive is in ALLREDUCE_PULL_DOWN_WHITE_LIST (or is MatMul/BatchMatMul)
// and the chosen input's primitive is not already marked MATMUL_ADD_COMM_BEGIN.
// Returns the node where the traversal stopped (the candidate comm node).
AnfNodePtr FindPullDownNode(const AnfNodePtr &anode) {
  auto pre_node = GetInputNodeWithFilter(anode, [&](const AnfNodePtr &cur_anode) {
    auto cur_cnode = cur_anode->cast<CNodePtr>();
    auto prim = GetCNodePrimitive(cur_cnode);
    if (prim == nullptr) {
      // Not a primitive CNode: stop here (false means "do not continue").
      return std::make_pair(false, LongToSize(0));
    }
    auto cur_node_input_list = cur_cnode->inputs();
    for (size_t i = 1; i < cur_node_input_list.size(); ++i) {
      auto cur_input_node = cur_node_input_list[i];
      // find first non Tensor CNode
      if (IsValueNode<tensor::Tensor>(cur_input_node)) {
        continue;
      }
      auto input_prim = GetCNodePrimitive(cur_input_node);
      if (input_prim == nullptr) {
        // Input is not a primitive CNode: stop, reporting which input was chosen.
        return std::make_pair(false, i);
      }
      // cur prim must in ALLREDUCE_PULL_DOWN_WHITE_LIST and input_prim is not marked or marked false
      bool filter = (ALLREDUCE_PULL_DOWN_WHITE_LIST.find(prim->name()) != ALLREDUCE_PULL_DOWN_WHITE_LIST.end() ||
                     prim->name() == MATMUL || prim->name() == BATCH_MATMUL) &&
                    IsPrimitiveAttrValid(input_prim, MATMUL_ADD_COMM_BEGIN);
      // Decision is made on the FIRST non-tensor input only; i tells the caller
      // which input edge to follow when filter is true.
      return std::make_pair(filter, i);
    }
    // All inputs were tensors: stop, defaulting to input index 1.
    return std::make_pair(false, LongToSize(1));
  });
  return pre_node;
}
111 
FindAllValidAddNode(const FuncGraphPtr & graph,HashMap<AnfNodePtr,std::vector<AnfNodePtr>> * pull_down_node_map)112 void FindAllValidAddNode(const FuncGraphPtr &graph, HashMap<AnfNodePtr, std::vector<AnfNodePtr>> *pull_down_node_map) {
113   std::list<CNodePtr> graph_orders = graph->GetOrderedCnodes();
114   std::vector<CNodePtr> origin_nodes_topological(graph_orders.cbegin(), graph_orders.cend());
115   for (const auto &node : origin_nodes_topological) {
116     // add node
117     auto prim = GetCNodePrimitive(node);
118     if (prim == nullptr || prim->name() != ADD || IsPrimitiveAttrValid(prim, MATMUL_ADD_COMM_END)) {
119       continue;
120     }
121     auto input_nodes = node->inputs();
122     for (size_t i = 1; i < input_nodes.size(); ++i) {
123       auto input_node = input_nodes[i];
124       if (!IsPrimitiveLinear(input_node)) {
125         continue;
126       }
127       auto comm_node = FindPullDownNode(input_node);
128       if (comm_node == nullptr) {
129         MS_LOG(INFO) << "For matmul add comm reduction, can not find valid comm node, node is "
130                      << input_node->DebugString();
131         continue;
132       }
133       if ((!IsPrimitiveCNode(comm_node, prim::kPrimAllReduce) &&
134            !IsPrimitiveCNode(comm_node, prim::kPrimReduceScatter))) {
135         MS_LOG(INFO) << "For matmul comm reduction, comm node is not allreduce or reduce scatter, node is "
136                      << comm_node->DebugString();
137         continue;
138       }
139 
140       auto comm_cnode = comm_node->cast<CNodePtr>();
141       MS_EXCEPTION_IF_NULL(comm_node);
142       auto pre_prim = GetCNodePrimitive(comm_cnode->input(1));
143       if (pre_prim == nullptr || IsPrimitiveAttrValid(pre_prim, MATMUL_ADD_COMM_BEGIN)) {
144         MS_LOG(INFO) << "For matmul comm reduction,  cannot find matmul/batch matmul node, "
145                      << "skip cur node: " << input_node->DebugString();
146         continue;
147       }
148       (*pull_down_node_map)[node].push_back(comm_node);
149       MS_LOG(INFO) << "For matmul comm reduction, find one side with matmul-allreduce structure, add node is: "
150                    << node->DebugString() << " comm node is: " << comm_node->DebugString();
151     }
152   }
153 }
154 
// Climbs upward from add_node_input through white-listed ops (stopping at Add
// nodes and at comm_node itself) and returns the node where the walk stopped,
// which the caller expects to be the bias Add.
AnfNodePtr FindBiasAdd(const AnfNodePtr &comm_node, const AnfNodePtr &add_node_input) {
  MS_EXCEPTION_IF_NULL(comm_node);
  auto bias_add = GetInputNodeWithFilter(add_node_input, [&](const AnfNodePtr &candidate) {
    auto candidate_prim = GetCNodePrimitive(candidate);
    if (candidate_prim == nullptr) {
      return std::make_pair(false, 0);
    }
    // Keep climbing while: op is white-listed (or MatMul/BatchMatMul), is not
    // itself an Add, and is not the comm node we started from.
    const auto &prim_name = candidate_prim->name();
    bool in_white_list = ALLREDUCE_PULL_DOWN_WHITE_LIST.find(prim_name) != ALLREDUCE_PULL_DOWN_WHITE_LIST.end() ||
                         prim_name == MATMUL || prim_name == BATCH_MATMUL;
    bool keep_climbing = in_white_list && prim_name != ADD && candidate != comm_node;
    return std::make_pair(keep_climbing, 1);
  });
  return bias_add;
}
170 
HandleNodeBiasAdd(const AnfNodePtr & comm_node,const AnfNodePtr & add_node_input)171 void HandleNodeBiasAdd(const AnfNodePtr &comm_node, const AnfNodePtr &add_node_input) {
172   MS_EXCEPTION_IF_NULL(comm_node);
173   MS_EXCEPTION_IF_NULL(add_node_input);
174   auto comm_prim = GetCNodePrimitive(comm_node);
175   MS_EXCEPTION_IF_NULL(comm_prim);
176   if (!comm_prim->HasAttr(GROUP)) {
177     MS_LOG(INFO) << "For matmul comm reduction, cur prim has not attr " << GROUP
178                  << ", skip it, node is: " << comm_node->DebugString();
179     return;
180   }
181   auto comm_group = GetValue<std::string>(comm_prim->GetAttr(GROUP));
182   MS_EXCEPTION_IF_NULL(g_device_manager);
183   auto comm_rank_list = g_device_manager->FindRankListByHashName(comm_group);
184   double rank_size = 1.0 / comm_rank_list.size();
185 
186   auto add_node = FindBiasAdd(comm_node, add_node_input);
187   if (add_node == nullptr || !IsPrimitiveCNode(add_node, prim::kPrimAdd)) {
188     MS_LOG(INFO) << "For matmul comm reduction, cannot find bias add node, find node is: " << add_node->DebugString()
189                  << " start node is " << add_node_input->DebugString();
190     return;
191   }
192   if (!IsAddNodeValid(add_node, comm_node)) {
193     MS_LOG(INFO) << "For matmul comm reduction, strategy of add node mismatched, skip it, add node is: "
194                  << add_node->DebugString();
195     return;
196   }
197   auto add_cnode = add_node->cast<CNodePtr>();
198   MS_EXCEPTION_IF_NULL(add_cnode);
199   // find load node for bias parameter
200   auto bias_side_start_node = add_cnode->input(2);
201   auto bias_node = GetInputNodeWithFilter(bias_side_start_node, [&](const AnfNodePtr &anode) {
202     auto prim = GetCNodePrimitive(anode);
203     if (prim == nullptr) {
204       return std::make_pair(false, 0);
205     }
206     bool filter = ALLREDUCE_PULL_DOWN_WHITE_LIST.find(prim->name()) != ALLREDUCE_PULL_DOWN_WHITE_LIST.end();
207     return std::make_pair(filter, 1);
208   });
209   if (bias_node == nullptr || !IsPrimitiveCNode(bias_node, prim::kPrimLoad)) {
210     MS_LOG(INFO) << "For comm reduction, cannot find load op for bias parameter along current add node, please "
211                     "check whether it exists, cur add node is: "
212                  << add_node->DebugString();
213     return;
214   }
215   // insert mul node
216   auto bias_node_abstract = bias_node->abstract();
217   MS_EXCEPTION_IF_NULL(bias_node_abstract);
218   auto bias_dtype = bias_node_abstract->cast<abstract::AbstractTensorPtr>();
219   MS_EXCEPTION_IF_NULL(bias_dtype);
220   auto bias_dtype_ele = bias_dtype->element();
221   MS_EXCEPTION_IF_NULL(bias_dtype_ele);
222   mindspore::tensor::TensorPtr tensor_ptr =
223     std::make_shared<mindspore::tensor::Tensor>(rank_size, bias_dtype_ele->GetType());
224   auto const_node = NewValueNode(MakeValue(tensor_ptr));
225   const_node->set_abstract(const_node->value()->ToAbstract());
226 
227   auto mul_prim = NewValueNode(prim::kPrimMul);
228   auto cur_prim = GetValueNode<PrimitivePtr>(mul_prim);
229   MS_EXCEPTION_IF_NULL(cur_prim);
230   (void)cur_prim->AddAttr(MATMUL_ADD_COMM_MUL, MakeValue(true));
231   AnfNodePtrList mul_node_inputs = {mul_prim, bias_node, const_node};
232   auto fg = comm_node->func_graph();
233   MS_EXCEPTION_IF_NULL(fg);
234   auto mul_node = fg->NewCNode(mul_node_inputs);
235   mul_node->set_abstract(bias_node->abstract()->Clone());
236 
237   MS_EXCEPTION_IF_NULL(fg);
238   auto manager = fg->manager();
239   MS_EXCEPTION_IF_NULL(manager);
240   (void)manager->Replace(bias_node, mul_node);
241   MS_LOG(INFO) << "for comm reduction, insert new mul node after parameter node";
242 }
243 
HandleNodePullUp(const AnfNodePtr & add_node,const std::vector<AnfNodePtr> & comm_node_list,HashMap<AnfNodePtr,AnfNodePtr> * comm_node_map)244 void HandleNodePullUp(const AnfNodePtr &add_node, const std::vector<AnfNodePtr> &comm_node_list,
245                       HashMap<AnfNodePtr, AnfNodePtr> *comm_node_map) {
246   for (size_t index = 0; index < comm_node_list.size(); ++index) {
247     // Node pull down
248     // Node After AllReduce pull up
249     auto each_node = comm_node_list[index];
250     auto each_cnode = each_node->cast<CNodePtr>();
251     auto pre_node = each_cnode->input(1);
252     auto pre_prim = GetCNodePrimitive(pre_node);
253     if (pre_prim == nullptr || IsPrimitiveAttrValid(pre_prim, MATMUL_ADD_COMM_BEGIN)) {
254       MS_LOG(INFO) << "For comm reduction, its pre node does not marked or marked false, skip it.";
255       continue;
256     }
257     auto graph = each_node->func_graph();
258     MS_EXCEPTION_IF_NULL(graph);
259     auto manager = graph->manager();
260     MS_EXCEPTION_IF_NULL(manager);
261     auto add_cnode = add_node->cast<CNodePtr>();
262     HandleNodeBiasAdd(each_node, add_cnode->input(index + 1));
263     (void)manager->Replace(each_node, pre_node);
264     MS_LOG(INFO) << "For comm reduction, pull up node next to comm node, node is: " << pre_node->DebugString();
265     if ((*comm_node_map).find(add_node) == (*comm_node_map).end()) {
266       (*comm_node_map)[add_node] = each_node;
267     }
268   }
269 }
270 
HandleNodePullDown(const AnfNodePtr & add_node,const AnfNodePtr & comm_node)271 void HandleNodePullDown(const AnfNodePtr &add_node, const AnfNodePtr &comm_node) {
272   auto comm_cnode = comm_node->cast<CNodePtr>();
273   MS_EXCEPTION_IF_NULL(comm_cnode);
274   AnfNodePtrList new_comm_node_inputs = {comm_cnode->input(0), add_node};
275   auto graph = add_node->func_graph();
276   MS_EXCEPTION_IF_NULL(graph);
277   auto new_comm_node = graph->NewCNode(new_comm_node_inputs);
278   new_comm_node->set_abstract(comm_node->abstract());
279   auto prim = GetCNodePrimitive(new_comm_node);
280   (void)prim->AddAttr(MATMUL_ADD_COMM_REDUCTION, MakeValue(true));
281 
282   auto manager = graph->manager();
283   MS_EXCEPTION_IF_NULL(manager);
284   (void)manager->Replace(add_node, new_comm_node);
285   MS_LOG(INFO) << "For comm reduction, pull down comm node, node is: " << new_comm_node->DebugString();
286 }
287 
HandleAddNode(const HashMap<AnfNodePtr,std::vector<AnfNodePtr>> & pull_down_node_map)288 void HandleAddNode(const HashMap<AnfNodePtr, std::vector<AnfNodePtr>> &pull_down_node_map) {
289   HashMap<AnfNodePtr, AnfNodePtr> comm_node_map;
290   for (auto &each_pull_down_node : pull_down_node_map) {
291     if (each_pull_down_node.second.size() < kCommReductionValidCommOpsNum) {
292       MS_LOG(INFO) << "For comm reduction, cur node cannot find match structure, skip it. current node is "
293                    << each_pull_down_node.first->DebugString();
294       continue;
295     }
296     // Handle node pull up
297     HandleNodePullUp(each_pull_down_node.first, each_pull_down_node.second, &comm_node_map);
298     // Handle node pull down
299     HandleNodePullDown(each_pull_down_node.first, comm_node_map[each_pull_down_node.first]);
300   }
301 }
302 
303 }  // namespace
304 
305 // For Structure as following:
306 //  MatMul/BatchMatMul -> AllReduce -> ... -> X -> Add, and MatMul/BatchMatMul -> AllReduce -> ... -> Y -> Add
307 // Change it to MatMul/BatchMatMul -> ... -> X -> Add -> AllReduce and MatMul/BatchMatMul -> ... -> Y -> Add ->
308 // AllReduce thus it can reduce a communication op.
// Entry point of the pass: collects candidate Add nodes across all managed
// func graphs, then rewrites each match (comm ops pulled above/below the Add).
// NOTE(review): `changes` is never set to true, so the pass always reports
// "no change" to the optimizer even when HandleAddNode mutated the graph —
// confirm this is intentional (e.g. to avoid re-running dependent passes).
bool MatmulAddCommReduction(const FuncGraphPtr &graph, const opt::OptimizerPtr &) {
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // assume no change to graph
  bool changes = false;
  HashMap<AnfNodePtr, std::vector<AnfNodePtr>> pull_down_node_map;
  // candidate node to pull down
  for (const auto &each_graph : manager->func_graphs()) {
    FindAllValidAddNode(each_graph, &pull_down_node_map);
  }
  // Node Pull up
  HandleAddNode(pull_down_node_map);
  return changes;
}
324 }  // namespace parallel
325 }  // namespace mindspore
326