• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/pass/split_matmul_comm_elementwise_fp.h"
18 
19 #include <memory>
20 
21 #include "mindspore/core/ops/other_ops.h"
22 #include "mindspore/core/ops/array_ops.h"
23 #include "mindspore/core/ops/math_ops.h"
24 #include "frontend/parallel/step_parallel.h"
25 #include "frontend/parallel/graph_util/graph_info.h"
26 #include "include/common/utils/utils.h"
27 #include "ir/pattern_matcher.h"
28 
29 namespace mindspore {
30 namespace parallel {
namespace {
// File-local integer constants used when constructing the Split /
// TupleGetItem / Concat nodes below.
constexpr int64_t kInt64Num0 = 0;  // index of the first split output
constexpr int64_t kInt64Num1 = 1;  // index of the second split output
constexpr int64_t kInt64Num2 = 2;  // number of interleaved slices
}  // namespace
36 
IsForwardCNode(const CNodePtr & cnode)37 static bool IsForwardCNode(const CNodePtr &cnode) {
38   MS_EXCEPTION_IF_NULL(cnode);
39   return !(cnode->HasPrimalAttr(kPrimalAttrForwardUniqueId) || cnode->HasAttr(kAttrDuplicated));
40 }
41 
42 // MatMul -> AllReduce -> Add
PatternFilter(const AnfNodePtr & node)43 static bool PatternFilter(const AnfNodePtr &node) {
44   auto cnode = node->cast<CNodePtr>();
45   if (cnode == nullptr || !IsForwardCNode(cnode)) {
46     return true;
47   }
48 
49   static PrimitiveSet expect_prim_type = {prim::kPrimAllReduce};
50   if (!IsOneOfPrimitiveCNode(cnode, expect_prim_type)) {
51     return true;
52   }
53   auto input_node = cnode->input(kIndex1);
54   if (input_node == nullptr || !IsPrimitiveCNode(input_node, prim::kPrimMatMul)) {
55     return true;
56   }
57   const auto &input_node_set = cnode->func_graph()->manager()->node_users()[input_node];
58   if (input_node_set.size() != kSizeOne) {
59     return true;
60   }
61   const auto &output_node_set = cnode->func_graph()->manager()->node_users()[cnode];
62   if (output_node_set.size() != kSizeOne) {
63     return true;
64   }
65 
66   auto output_node = output_node_set.front().first;
67   auto index = output_node_set.front().second;
68   if (!IsPrimitiveCNode(output_node, prim::kPrimAdd) || index != kIndex1) {
69     return true;
70   }
71   return false;
72 }
73 
CopyAllAttrs(const CNodePtr & dst_cnode,const CNodePtr & src_cnode)74 static void CopyAllAttrs(const CNodePtr &dst_cnode, const CNodePtr &src_cnode) {
75   MS_EXCEPTION_IF_NULL(dst_cnode);
76   MS_EXCEPTION_IF_NULL(src_cnode);
77   dst_cnode->set_attrs(src_cnode->attrs());
78   auto dst_prim_node = GetCNodePrimitive(dst_cnode);
79   auto src_prim_node = GetCNodePrimitive(src_cnode);
80   auto src_attrs = src_prim_node->attrs();
81   for (const auto &attr : src_attrs) {
82     dst_prim_node->set_attr(attr.first, attr.second);
83   }
84 }
85 
// Rewrites one matched pattern (MatMul -> AllReduce -> Add) into two
// interleaved half-sized branches so the second MatMul can overlap with the
// first AllReduce:
//   Split(input, 2) --> TGI(0) -> MatMul_a -> AllReduce_a -> Add_a
//                   --> TGI(1) -> MatMul_b -> Depend(AllReduce_a)
//                                          -> AllReduce_b -> Add_b
//   -> MakeTuple(Add_a, Add_b) -> Concat
// Preconditions (guaranteed by PatternFilter): `comm_node` is an AllReduce
// whose only input is a MatMul and whose only user is an Add.
static void SplitIntoInterleaved(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager,
                                 const AnfNodePtr &comm_node) {
  auto comm_cnode = comm_node->cast<CNodePtr>();
  auto matmul_cnode = comm_cnode->input(kIndex1)->cast<CNodePtr>();
  // Single user of the AllReduce is the Add (PatternFilter checked the size).
  auto add_cnode = manager->node_users()[comm_cnode].front().first->cast<CNodePtr>();

  bool transpose_a = GetValue<bool>(GetCNodePrimitive(matmul_cnode)->GetAttr("transpose_a"));
  bool transpose_b = GetValue<bool>(GetCNodePrimitive(matmul_cnode)->GetAttr("transpose_b"));
  // Axis of input a that holds the M (row) dimension, and axis of input b
  // that holds the N (column) dimension, after accounting for transposition.
  const int64_t axis_a_0 = transpose_a ? 1 : 0;
  const int64_t axis_b_1 = transpose_b ? 0 : 1;

  // NOTE(review): "comm_primtive" is a typo for "comm_primitive" — worth
  // renaming in a follow-up.
  auto comm_primtive = GetCNodePrimitive(comm_cnode);
  auto matmul_input1 = matmul_cnode->input(kIndex1);
  auto matmul_input1_shape = BaseShapeToShape(AnfAlgo::GetOutputDetailShape(matmul_input1, 0));
  // Bail out unless the M dimension divides evenly into two slices. A dynamic
  // dim (-1) also fails this test (-1 % 2 != 0), so dynamic shapes are skipped.
  if (matmul_input1_shape[axis_a_0] % kInt64Num2 != 0) {
    return;
  }
  auto matmul_input2 = matmul_cnode->input(kIndex2);
  auto matmul_input2_shape = BaseShapeToShape(AnfAlgo::GetOutputDetailShape(matmul_input2, 0));
  // Second operand of the Add; reused by both branches (relies on the Add
  // output shape matching the MatMul output shape — e.g. a broadcast bias).
  auto add_input2 = add_cnode->input(kIndex2);

  // Create const value
  auto value0 = NewValueNode(MakeValue(kInt64Num0));
  auto value1 = NewValueNode(MakeValue(kInt64Num1));

  // New split(matmul_input1, axis_a_0, 2)
  auto split_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimSplit->Clone()), matmul_input1,
                                           NewValueNode<int64_t>(axis_a_0), NewValueNode<int64_t>(kInt64Num2)});
  int64_t slice_size = matmul_input1_shape[axis_a_0] / kInt64Num2;
  AddCNodePrimAttr(split_cnode, kAttrSizeSplits, MakeValue(ShapeVector{slice_size, slice_size}));
  AddCNodePrimAttr(split_cnode, kAttrNumSplit, MakeValue(kInt64Num2));

  // branch_a: split_cnode->TupleGetItem(0)->Matmul_a->AllReduce_a->Add_a
  auto tuple_get_item_a = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem->Clone()), split_cnode, value0});
  auto matmul_a = func_graph->NewCNode({NewValueNode(prim::kPrimMatMul->Clone()), tuple_get_item_a, matmul_input2});
  CopyAllAttrs(matmul_a, matmul_cnode);
  auto comm_a = func_graph->NewCNode({NewValueNode(comm_primtive->Clone()), matmul_a});
  CopyAllAttrs(comm_a, comm_cnode);
  CNodePtr add_a = func_graph->NewCNode({NewValueNode(prim::kPrimAdd->Clone()), comm_a, add_input2});

  // branch_b: split_cnode->TupleGetItem(1)->Matmul_b->AllReduce_b(depend AllReduce_a) ->Add_b
  auto tuple_get_item_b = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem->Clone()), split_cnode, value1});
  auto matmul_b = func_graph->NewCNode({NewValueNode(prim::kPrimMatMul->Clone()), tuple_get_item_b, matmul_input2});
  CopyAllAttrs(matmul_b, matmul_cnode);
  // The Depend edge serializes AllReduce_b after AllReduce_a so the two
  // communications interleave with the two MatMuls instead of colliding.
  auto depend = func_graph->NewCNode({NewValueNode(prim::kPrimDepend->Clone()), matmul_b, comm_a});
  auto comm_b = func_graph->NewCNode({NewValueNode(comm_primtive->Clone()), depend});
  CopyAllAttrs(comm_b, comm_cnode);
  auto add_b = func_graph->NewCNode({NewValueNode(prim::kPrimAdd->Clone()), comm_b, add_input2});

  // New concat(MakeTuple(add_a, add_b))
  auto make_tuple_cnode = func_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple->Clone()), add_a, add_b});
  // Concat along the same axis the input was split on, restoring the
  // original output shape.
  auto concat_cnode = func_graph->NewCNode(
    {NewValueNode(prim::kPrimConcat->Clone()), make_tuple_cnode, NewValueNode(MakeValue(axis_a_0))});

  // Infer path_a abstract
  auto dtype = common::AnfAlgo::GetOutputInferDataType(matmul_input1, 0);
  ShapeVector split_single_shape = matmul_input1_shape;
  split_single_shape[axis_a_0] /= kInt64Num2;
  auto split_shape_abstract = std::make_shared<abstract::Shape>(split_single_shape);
  common::AnfAlgo::SetOutputTypeAndDetailShape({dtype, dtype}, {split_shape_abstract, split_shape_abstract},
                                               split_cnode.get());
  common::AnfAlgo::SetOutputTypeAndDetailShape({dtype}, {split_shape_abstract}, tuple_get_item_a.get());
  // NOTE(review): matmul output shape is built as [M_half, N] — assumes the
  // MatMul output is 2-D; confirm batched MatMul cannot reach this pass.
  ShapeVector matmul_ab_shape{split_single_shape[axis_a_0], matmul_input2_shape[axis_b_1]};
  auto matmul_ab_abstract = std::make_shared<abstract::Shape>(matmul_ab_shape);
  common::AnfAlgo::SetOutputTypeAndDetailShape({dtype}, {matmul_ab_abstract}, matmul_a.get());
  // AllReduce and Add are shape/type-preserving here, so reuse the MatMul
  // abstract (shared pointer) for the rest of the branch.
  comm_a->set_abstract(matmul_a->abstract());
  add_a->set_abstract(matmul_a->abstract());

  // set path_b from path_a
  tuple_get_item_b->set_abstract(tuple_get_item_a->abstract());
  matmul_b->set_abstract(matmul_a->abstract());
  comm_b->set_abstract(matmul_b->abstract());
  depend->set_abstract(comm_b->abstract());
  add_b->set_abstract(add_a->abstract());

  // set abstract for make_tuple and concat
  common::AnfAlgo::SetOutputTypeAndDetailShape({dtype, dtype}, {matmul_ab_abstract, matmul_ab_abstract},
                                               make_tuple_cnode.get());
  // Concat restores the original Add output, so its abstract is reused as-is.
  concat_cnode->set_abstract(add_cnode->abstract());

  // Replace graph
  auto prev_cnode = matmul_cnode->input(kIndex1);
  manager->SetEdge(split_cnode, kIndex1, prev_cnode);
  // Redirect every user of the old Add to the new Concat; the old
  // MatMul/AllReduce/Add chain becomes dead and is dropped by the manager.
  auto next_cnode_users = manager->node_users()[add_cnode];
  for (const auto &next_cnode_pair : next_cnode_users) {
    manager->SetEdge(next_cnode_pair.first, next_cnode_pair.second, concat_cnode);
  }
}
174 
175 // From:
176 // MatMul -> AllReduce -> Add
177 // To:
178 //        --> TupleGetItem(0) -> MatMul_a ->                        AllReduce_a -> Add_a
179 // Split                                                                                 -> Concat
180 //        --> TupleGetItem(1) -> MatMul_b -> Depend(AllReduce_a) -> AllReduce_b -> Add_b
SplitMatmulCommElementwiseFp(const FuncGraphPtr & func_graph)181 void SplitMatmulCommElementwiseFp(const FuncGraphPtr &func_graph) {
182   if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
183       parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
184     MS_LOG(INFO) << "SplitMatmulCommElementwiseFp is only support under [semi_]auto_parallel, skip it.";
185     return;
186   }
187 
188   auto ms_context = MsContext::GetInstance();
189   MS_EXCEPTION_IF_NULL(ms_context);
190   auto is_enable = ms_context->get_param<bool>(MS_CTX_INTERLEAVED_MATMUL_COMM);
191   if (!is_enable) {
192     return;
193   }
194 
195   MS_EXCEPTION_IF_NULL(func_graph);
196   auto manager = func_graph->manager();
197   MS_EXCEPTION_IF_NULL(manager);
198   auto todo = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, PatternFilter);
199   for (const auto &node : todo) {
200     SplitIntoInterleaved(func_graph, manager, node);
201   }
202 }
203 }  // namespace parallel
204 }  // namespace mindspore
205