/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/pass/merge_comm.h"
#include <algorithm>
#include <string>
#include <vector>
#include <unordered_map>
#include <utility>
#include <memory>
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "frontend/optimizer/optimizer.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
namespace {
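// Extracts the target shape from a Reshape node whose second input is a MakeTuple:
// constant Int64 elements are recorded as-is, and any non-constant (dynamic)
// element is recorded as -1.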
static Shape GetMakeTupleValue(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_CHECK_FAIL(cnode->inputs().size() == kSizeThree, "Input size of Reshape is not 3.");
  auto make_tuple = cnode->input(kIndex2);
  auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
  Shape ret;
  for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
    auto input_node = make_tuple_cnode->input(i);
    MS_EXCEPTION_IF_NULL(input_node);
    auto value_node = GetValueNode(input_node);
    if (value_node != nullptr && value_node->isa<Int64Imm>()) {
      auto shape_ele = GetValue<int64_t>(value_node);
      ret.push_back(shape_ele);
    } else {
      ret.push_back(-1);
    }
  }
  return ret;
}

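// Compares the dynamic target shapes of two Reshape nodes. The shapes are only
// considered equal when each contains at most one dynamic (-1) dimension and
// all recorded dimensions match.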
static bool IsSameTargetDynamicShape(const CNodePtr &reshape_node_a, const CNodePtr &reshape_node_b) {
  MS_EXCEPTION_IF_NULL(reshape_node_a);
  MS_EXCEPTION_IF_NULL(reshape_node_b);
  MS_EXCEPTION_IF_CHECK_FAIL(reshape_node_a->inputs().size() == kSizeThree, "Input size of Reshape is not 3.");
  MS_EXCEPTION_IF_CHECK_FAIL(reshape_node_b->inputs().size() == kSizeThree, "Input size of Reshape is not 3.");
  if (!IsPrimitiveCNode(reshape_node_a->input(kIndex2), prim::kPrimMakeTuple)) {
    MS_LOG(WARNING) << "The dst shape of reshape node a is not a MakeTuple for dynamic shape.";
    return false;
  }

  if (!IsPrimitiveCNode(reshape_node_b->input(kIndex2), prim::kPrimMakeTuple)) {
    MS_LOG(WARNING) << "The dst shape of reshape node b is not a MakeTuple for dynamic shape.";
    return false;
  }

  Shape node_a_shape = GetMakeTupleValue(reshape_node_a);
  Shape node_b_shape = GetMakeTupleValue(reshape_node_b);
  MS_LOG(INFO) << "The node a shape is " << node_a_shape << ", the node b shape is " << node_b_shape;
  // With more than one dynamic (-1) dimension, equality cannot be decided, so treat the shapes as different.
  if (std::count(node_a_shape.cbegin(), node_a_shape.cend(), -1) > 1) {
    return false;
  }
  if (std::count(node_b_shape.cbegin(), node_b_shape.cend(), -1) > 1) {
    return false;
  }

  return (node_a_shape == node_b_shape);
}

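// Compares the target shapes of two Reshape nodes, falling back to the
// dynamic-shape comparison when either target shape is not a constant ValueNode.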
bool IsSameTargetShape(const CNodePtr &reshape_node_a, const CNodePtr &reshape_node_b) {
  MS_EXCEPTION_IF_NULL(reshape_node_a);
  MS_EXCEPTION_IF_NULL(reshape_node_b);
  MS_EXCEPTION_IF_CHECK_FAIL(reshape_node_a->inputs().size() == kSizeThree, "Input size of Reshape is not 3.");
  MS_EXCEPTION_IF_CHECK_FAIL(reshape_node_b->inputs().size() == kSizeThree, "Input size of Reshape is not 3.");
  if (!reshape_node_a->input(kIndex2)->isa<ValueNode>() || !reshape_node_b->input(kIndex2)->isa<ValueNode>()) {
    return IsSameTargetDynamicShape(reshape_node_a, reshape_node_b);
  }
  auto value_ptr_a = reshape_node_a->input(kIndex2)->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  auto value_ptr_b = reshape_node_b->input(kIndex2)->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
  if (value_ptr_a.size() != value_ptr_b.size()) {
    return false;
  }
  for (size_t i = 0; i < value_ptr_a.size(); i++) {
    int64_t cur_shape_a = GetValue<int64_t>(value_ptr_a.at(i));
    int64_t cur_shape_b = GetValue<int64_t>(value_ptr_b.at(i));
    if (cur_shape_a != cur_shape_b) {
      return false;
    }
  }
  return true;
}

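// Groups AllGather nodes by the real producer of their input (looking through
// Reshape nodes) and, when all AllGathers in a group are equivalent, replaces
// the duplicates with a single node.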
void MergeAllGather(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
  std::unordered_map<CNodePtr, std::vector<CNodePtr>> allgather_input_map;
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimAllGather)) {
      continue;
    }
    auto allgather_cnode = node->cast<CNodePtr>();
    // Look through any chain of Reshape nodes to find the real producer of the AllGather input.
    auto pre_node = GetInputNodeWithFilter(allgather_cnode->input(kIndex1), [&](const CNodePtr &cnode) {
      bool filter = IsPrimitiveCNode(cnode, prim::kPrimReshape);
      return std::make_pair(filter, 1);
    });
    if (!IsPrimitiveCNode(pre_node)) {
      continue;
    }
    auto pre_cnode = pre_node->cast<CNodePtr>();
    allgather_input_map[pre_cnode].push_back(allgather_cnode);
  }
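  // For each group of AllGather nodes sharing a producer, merge them when they are equivalent.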
  for (const auto &allgather_pairs : allgather_input_map) {
    if (allgather_pairs.second.size() <= 1) {
      continue;
    }
    auto allgather_list = allgather_pairs.second;
    auto allgather_cnode1 = allgather_list.front();
    auto is_same_allgather =
      std::all_of(allgather_list.begin(), allgather_list.end(), [&allgather_cnode1](const CNodePtr &allgather_cnode2) {
        auto ag1_prim = GetCNodePrimitive(allgather_cnode1);
        auto ag2_prim = GetCNodePrimitive(allgather_cnode2);
        auto group1 = ag1_prim->GetAttr(GROUP);
        auto group2 = ag2_prim->GetAttr(GROUP);
        // Both AllGathers must communicate over the same device group.
        if (!group1 || !group2) {
          return false;
        }
        if (GetValue<std::string>(group1) != GetValue<std::string>(group2)) {
          return false;
        }
        // Either both inputs go through a Reshape or neither does.
        if (IsPrimitiveCNode(allgather_cnode1->input(kIndex1), prim::kPrimReshape) !=
            IsPrimitiveCNode(allgather_cnode2->input(kIndex1), prim::kPrimReshape)) {
          return false;
        }
        // If both inputs are Reshapes, their target shapes must match.
        if (IsPrimitiveCNode(allgather_cnode1->input(kIndex1), prim::kPrimReshape) &&
            IsPrimitiveCNode(allgather_cnode2->input(kIndex1), prim::kPrimReshape)) {
          if (!IsSameTargetShape(allgather_cnode1->input(kIndex1)->cast<CNodePtr>(),
                                 allgather_cnode2->input(kIndex1)->cast<CNodePtr>())) {
            return false;
          }
        }
        if (allgather_cnode1->func_graph() != allgather_cnode2->func_graph()) {
          return false;
        }
        return true;
      });
    if (!is_same_allgather) {
      MS_LOG(INFO) << "AllGather nodes sharing the input node " << allgather_pairs.first->DebugString()
                   << " are not equivalent; skip merging.";
      continue;
    }
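    // All AllGathers in this group are equivalent; redirect users of the duplicates to the first one.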
    auto ag0 = allgather_list.front();
    for (const auto &ag : allgather_list) {
      manager->Replace(ag, ag0);
    }
  }
}
}  // namespace

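// Entry point of the pass: merges redundant communication operators (currently
// AllGather) in the forward graph. Guarded by a flag so it runs only once per graph.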
bool MergeComm(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  auto graph_set = ForwardGraph(root);
  // Assume no change to the graph.
  bool changes = false;
  // Only run for graphs under auto/semi-auto parallel, and only once per graph.
  if (!IsAutoParallelCareGraph(root) || (root->has_flag(MERGE_COMM_RUN_ONCE_ONLY)) || graph_set.size() < 1) {
    return changes;
  }
  FuncGraphManagerPtr manager;
  pipeline::ResourceBasePtr res;
  if (optimizer == nullptr) {
    manager = root->manager();
    res = std::make_shared<pipeline::Resource>();
    res->set_manager(manager);
  } else {
    res = optimizer->resource();
    MS_EXCEPTION_IF_NULL(res);
    manager = res->manager();
  }

  MS_EXCEPTION_IF_NULL(manager);
  CNodePtr ret = root->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
  MergeAllGather(all_nodes, manager);
  DumpGraph(root, std::string("merge_comm"));

  // The merge-comm pass only runs once.
  root->set_flag(MERGE_COMM_RUN_ONCE_ONLY, true);
  // Keep all func graphs for parallel before saving the result.
  SetReserved(root);
  res->SetResult(pipeline::kStepParallelGraph, root);
  return changes;
}
}  // namespace parallel
}  // namespace mindspore