• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/comm_op_reuse_tag.h"
18 #include <memory>
19 #include <vector>
20 #include <string>
21 #include <algorithm>
22 #include "mindspore/core/ops/other_ops.h"
23 #include "ir/func_graph.h"
24 #include "frontend/parallel/ops_info/ops_utils.h"
25 #include "frontend/parallel/device_manager.h"
26 #include "include/common/utils/parallel_context.h"
27 #include "frontend/parallel/step_parallel_utils.h"
28 #include "include/common/utils/utils.h"
29 #include "include/common/utils/comm_manager.h"
30 
31 namespace mindspore {
32 namespace opt {
33 namespace {
is_comm_ops(const AnfNodePtr & node)34 inline bool is_comm_ops(const AnfNodePtr &node) {
35   static const std::vector<PrimitivePtr> kCommunicationOpsPrim = {prim::kPrimAllReduce,
36                                                                   prim::kPrimReduce,
37                                                                   prim::kPrimAllGather,
38                                                                   prim::kPrimReduceScatter,
39                                                                   prim::kPrimCollectiveScatter,
40                                                                   prim::kPrimCollectiveGather,
41                                                                   prim::kPrimAlltoAll,
42                                                                   prim::kPrimAllSwap,
43                                                                   prim::kPrimAllToAllv,
44                                                                   prim::kPrimNeighborExchange,
45                                                                   prim::kPrimNeighborExchangeV2,
46                                                                   prim::kPrimNeighborExchangeV2Grad,
47                                                                   prim::kPrimBarrier,
48                                                                   prim::kPrimBatchISendIRecv,
49                                                                   prim::kPrimAlltoAllV};
50 
51   for (const auto &prim : kCommunicationOpsPrim) {
52     if (IsPrimitiveCNode(node, prim)) {
53       return true;
54     }
55   }
56 
57   return false;
58 }
59 }  // namespace
60 
AddCommOpReuseTag(const FuncGraphPtr & graph)61 void AddCommOpReuseTag(const FuncGraphPtr &graph) {
62   if (parallel::g_device_manager == nullptr) {
63     MS_LOG(INFO) << "parallel::g_device_manager is not initialized.";
64     return;
65   }
66   MS_EXCEPTION_IF_NULL(graph);
67 
68   if (!parallel::IsAutoParallelCareGraph(graph)) {
69     return;
70   }
71   auto manager = graph->manager();
72   MS_EXCEPTION_IF_NULL(manager);
73   const auto &all_nodes = manager->all_nodes();
74   for (auto &node : all_nodes) {
75     MS_EXCEPTION_IF_NULL(node);
76     if (!is_comm_ops(node)) {
77       continue;
78     }
79     auto comm_prim = GetCNodePrimitive(node);
80     MS_EXCEPTION_IF_NULL(comm_prim);
81     if (comm_prim->HasAttr(parallel::FUSION) && GetValue<int64_t>(comm_prim->GetAttr(parallel::FUSION)) != 0) {
82       continue;
83     }
84     (void)comm_prim->AddAttr(parallel::COMM_REUSE, MakeValue(true));
85 
86     std::string group_name = "";
87     if (comm_prim->HasAttr(parallel::GROUP)) {
88       group_name = GetValue<std::string>(comm_prim->GetAttr(parallel::GROUP));
89     }
90     std::vector<unsigned int> rank_list = {};
91     auto long_rank_list = parallel::g_device_manager->FindRankListByHashName(group_name);
92     (void)std::transform(long_rank_list.begin(), long_rank_list.end(), std::back_inserter(rank_list),
93                          [](int64_t d) -> unsigned int { return IntToUint(LongToInt(d)); });
94     (void)comm_prim->AddAttr(kAttrRankList, MakeValue<std::vector<unsigned int>>(rank_list));
95   }
96 }
97 }  // namespace opt
98 }  // namespace mindspore
99