1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/optimizer/ascend/enhancer/split_inputs_for_reduce_scatter.h"
18 #include "backend/session/anf_runtime_algorithm.h"
19
20 namespace mindspore {
21 namespace opt {
InsertSplitForInput(const FuncGraphPtr & func_graph,const CNodePtr & node,int64_t rank_size) const22 std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const FuncGraphPtr &func_graph,
23 const CNodePtr &node,
24 int64_t rank_size) const {
25 MS_EXCEPTION_IF_NULL(func_graph);
26 size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
27 std::vector<AnfNodePtr> split_outputs;
28 size_t rank_size_t = LongToSize(rank_size);
29 for (size_t i = 0; i < inputs_size; i++) {
30 std::vector<AnfNodePtr> split_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
31 split_inputs.push_back(AnfAlgo::GetInputNode(node, i));
32 auto split = func_graph->NewCNode(split_inputs);
33 MS_EXCEPTION_IF_NULL(split);
34 std::vector<TypeId> dtypes(rank_size, AnfAlgo::GetPrevNodeOutputInferDataType(node, i));
35 std::vector<std::vector<size_t>> shapes;
36 std::vector<int> size_splits;
37 for (size_t j = 0; j < rank_size_t; j++) {
38 std::vector<size_t> output_node_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, i);
39 output_node_shape[0] /= rank_size_t;
40 shapes.push_back(output_node_shape);
41 size_splits.push_back(output_node_shape[0]);
42 }
43 AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
44 AnfAlgo::SetNodeAttr("split_dim", MakeValue(0L), split);
45 AnfAlgo::SetNodeAttr("num_split", MakeValue(rank_size), split);
46 AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split);
47 kernel_select_->SelectKernel(split);
48 std::vector<AnfNodePtr> new_outputs;
49 CreateMultipleOutputsOfAnfNode(func_graph, split, AnfAlgo::GetOutputTensorNum(split), &new_outputs);
50 for (size_t j = 0; j < new_outputs.size(); j++) {
51 split_outputs.push_back(new_outputs[j]);
52 }
53 }
54 return split_outputs;
55 }
56
// Builds a new ReduceScatter node whose inputs are a reordering of `inputs`.
//
// `inputs` is read with the index i + j * rank_size, where j runs over the tensor inputs of `node` and i over
// the ranks; the outer loop is over i, so the elements picked for i == 0 come first, then those for i == 1, ...
//
// func_graph: graph in which the new node is created; must not be null.
// node:       original ReduceScatter node; its abstract and attributes are copied; must not be null.
// inputs:     nodes to reorder; must hold at least GetInputTensorNum(node) * rank_size elements.
// rank_size:  stride used for the reordering (converted through LongToSize).
//
// Returns the new ReduceScatter node with kAttrFusion set to 1 and a kernel selected.
AnfNodePtr SplitInputsForReduceScatter::RearrangeInputsForReduceScatter(const FuncGraphPtr &func_graph,
                                                                        const AnfNodePtr &node,
                                                                        const std::vector<AnfNodePtr> &inputs,
                                                                        int64_t rank_size) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  // node is dereferenced below (node->abstract()), so reject null up front.
  MS_EXCEPTION_IF_NULL(node);
  size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
  std::vector<AnfNodePtr> reduce_scatter_inputs{
    NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceScatter->name()))};
  size_t rank_size_t = LongToSize(rank_size);
  // One slot for the primitive plus at most every element of `inputs`.
  reduce_scatter_inputs.reserve(inputs.size() + 1);
  for (size_t i = 0; i < rank_size_t; i++) {
    for (size_t j = 0, idx = i; j < inputs_size; j++, idx += rank_size_t) {
      // at(): a too-short `inputs` raises std::out_of_range instead of reading out of bounds.
      reduce_scatter_inputs.push_back(inputs.at(idx));
    }
  }
  auto reduce_scatter = func_graph->NewCNode(reduce_scatter_inputs);
  MS_EXCEPTION_IF_NULL(reduce_scatter);
  reduce_scatter->set_abstract(node->abstract());
  AnfAlgo::CopyNodeAttrs(node, reduce_scatter);
  AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(1L), reduce_scatter);
  kernel_select_->SelectKernel(reduce_scatter);
  return reduce_scatter;
}
79
DefinePattern() const80 const BaseRef SplitInputsForReduceScatter::DefinePattern() const {
81 VarPtr Xs = std::make_shared<SeqVar>();
82 auto prim = std::make_shared<Primitive>(kReduceScatterOpName);
83 return VectorRef({prim, Xs});
84 }
85
// Rewrites a matched ReduceScatter node so that each of its inputs is split and the slices are reordered.
//
// Returns nullptr (no replacement) when:
//   - the node has exactly one tensor input (kAttrFusion is then set to 0 on the node);
//   - the node lacks kAttrFusion or kAttrRankSize;
//   - kAttrFusion is <= 0;
//   - the node already carries the "Fused" attribute.
// Otherwise the node is tagged "Fused" and the node built by RearrangeInputsForReduceScatter is returned.
const AnfNodePtr SplitInputsForReduceScatter::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  // A single input is left as it is; only its fusion attribute is reset.
  if (AnfAlgo::GetInputTensorNum(node) == 1) {
    AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(0L), node);
    return nullptr;
  }

  const bool has_required_attrs =
    AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::HasNodeAttr(kAttrRankSize, cnode);
  if (!has_required_attrs) {
    return nullptr;
  }

  const auto fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion);
  // The "Fused" tag is written below, so a node carrying it has been through this pass before.
  if (fusion_id <= 0 || AnfAlgo::HasNodeAttr("Fused", cnode)) {
    return nullptr;
  }

  AnfAlgo::SetNodeAttr("Fused", MakeValue(true), node);
  const auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
  const std::vector<AnfNodePtr> split_outputs = InsertSplitForInput(func_graph, cnode, rank_size);
  return RearrangeInputsForReduceScatter(func_graph, node, split_outputs, rank_size);
}
112 } // namespace opt
113 } // namespace mindspore
114