• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/ascend/enhancer/split_inputs_for_reduce_scatter.h"
18 #include "backend/session/anf_runtime_algorithm.h"
19 
20 namespace mindspore {
21 namespace opt {
InsertSplitForInput(const FuncGraphPtr & func_graph,const CNodePtr & node,int64_t rank_size) const22 std::vector<AnfNodePtr> SplitInputsForReduceScatter::InsertSplitForInput(const FuncGraphPtr &func_graph,
23                                                                          const CNodePtr &node,
24                                                                          int64_t rank_size) const {
25   MS_EXCEPTION_IF_NULL(func_graph);
26   size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
27   std::vector<AnfNodePtr> split_outputs;
28   size_t rank_size_t = LongToSize(rank_size);
29   for (size_t i = 0; i < inputs_size; i++) {
30     std::vector<AnfNodePtr> split_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimSplitV->name()))};
31     split_inputs.push_back(AnfAlgo::GetInputNode(node, i));
32     auto split = func_graph->NewCNode(split_inputs);
33     MS_EXCEPTION_IF_NULL(split);
34     std::vector<TypeId> dtypes(rank_size, AnfAlgo::GetPrevNodeOutputInferDataType(node, i));
35     std::vector<std::vector<size_t>> shapes;
36     std::vector<int> size_splits;
37     for (size_t j = 0; j < rank_size_t; j++) {
38       std::vector<size_t> output_node_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, i);
39       output_node_shape[0] /= rank_size_t;
40       shapes.push_back(output_node_shape);
41       size_splits.push_back(output_node_shape[0]);
42     }
43     AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
44     AnfAlgo::SetNodeAttr("split_dim", MakeValue(0L), split);
45     AnfAlgo::SetNodeAttr("num_split", MakeValue(rank_size), split);
46     AnfAlgo::SetNodeAttr("size_splits", MakeValue(size_splits), split);
47     kernel_select_->SelectKernel(split);
48     std::vector<AnfNodePtr> new_outputs;
49     CreateMultipleOutputsOfAnfNode(func_graph, split, AnfAlgo::GetOutputTensorNum(split), &new_outputs);
50     for (size_t j = 0; j < new_outputs.size(); j++) {
51       split_outputs.push_back(new_outputs[j]);
52     }
53   }
54   return split_outputs;
55 }
56 
/**
 * Build the replacement ReduceScatter whose inputs are the split slices reordered rank-major.
 *
 * `inputs` is grouped per original input: slice j of original input i is at index i * rank_size + j. The new node
 * takes, for each r in [0, rank_size), slice r of every original input in input order.
 *
 * @param func_graph Graph in which the new ReduceScatter is created.
 * @param node The original ReduceScatter; its abstract and attributes are copied onto the new node.
 * @param inputs Split outputs; must contain exactly GetInputTensorNum(node) * rank_size nodes.
 * @param rank_size Number of slices per original input.
 * @return The new ReduceScatter CNode, with its fusion attribute set to 1.
 * Raises (MS_LOG(EXCEPTION)) when `inputs` does not have the expected size.
 */
AnfNodePtr SplitInputsForReduceScatter::RearrangeInputsForReduceScatter(const FuncGraphPtr &func_graph,
                                                                        const AnfNodePtr &node,
                                                                        const std::vector<AnfNodePtr> &inputs,
                                                                        int64_t rank_size) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  const size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
  const size_t rank_size_t = LongToSize(rank_size);
  // The strided indexing below reads inputs[i + j * rank_size]; make sure every such index is in range.
  if (inputs.size() != inputs_size * rank_size_t) {
    MS_LOG(EXCEPTION) << "Expected " << inputs_size * rank_size_t << " split outputs for " << node->DebugString()
                      << ", but got " << inputs.size();
  }

  std::vector<AnfNodePtr> reduce_scatter_inputs;
  reduce_scatter_inputs.reserve(inputs.size() + 1);
  reduce_scatter_inputs.push_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceScatter->name())));
  for (size_t i = 0; i < rank_size_t; i++) {
    for (size_t j = 0, idx = i; j < inputs_size; j++, idx += rank_size_t) {
      reduce_scatter_inputs.push_back(inputs[idx]);
    }
  }

  auto reduce_scatter = func_graph->NewCNode(reduce_scatter_inputs);
  MS_EXCEPTION_IF_NULL(reduce_scatter);
  reduce_scatter->set_abstract(node->abstract());
  AnfAlgo::CopyNodeAttrs(node, reduce_scatter);
  AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(1L), reduce_scatter);
  kernel_select_->SelectKernel(reduce_scatter);
  return reduce_scatter;
}
79 
DefinePattern() const80 const BaseRef SplitInputsForReduceScatter::DefinePattern() const {
81   VarPtr Xs = std::make_shared<SeqVar>();
82   auto prim = std::make_shared<Primitive>(kReduceScatterOpName);
83   return VectorRef({prim, Xs});
84 }
85 
/**
 * Rewrite a fusible multi-input ReduceScatter into per-input SplitV nodes feeding one rank-major ReduceScatter.
 *
 * Returns nullptr (no replacement) in these cases:
 * - the node has a single input (its fusion attribute is then reset to 0);
 * - the fusion or rank_size attribute is missing;
 * - fusion is not positive;
 * - the node already carries the "Fused" marker.
 *
 * Otherwise returns the new ReduceScatter node.
 */
const AnfNodePtr SplitInputsForReduceScatter::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                      const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  auto reduce_scatter = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(reduce_scatter);

  // A single input has nothing to split: just mark the node as not fused.
  if (AnfAlgo::GetInputTensorNum(node) == 1) {
    AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(0L), node);
    return nullptr;
  }

  const bool has_fusion_attrs =
    AnfAlgo::HasNodeAttr(kAttrFusion, reduce_scatter) && AnfAlgo::HasNodeAttr(kAttrRankSize, reduce_scatter);
  if (!has_fusion_attrs) {
    return nullptr;
  }
  const bool fusion_requested = AnfAlgo::GetNodeAttr<int64_t>(reduce_scatter, kAttrFusion) > 0;
  if (!fusion_requested || AnfAlgo::HasNodeAttr("Fused", reduce_scatter)) {
    return nullptr;
  }

  // Tag the node before rewriting. The "Fused" check above then prevents it from being processed a second time.
  AnfAlgo::SetNodeAttr("Fused", MakeValue(true), node);
  const auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrRankSize);
  const std::vector<AnfNodePtr> sliced_inputs = InsertSplitForInput(func_graph, reduce_scatter, rank_size);
  return RearrangeInputsForReduceScatter(func_graph, node, sliced_inputs, rank_size);
}
112 }  // namespace opt
113 }  // namespace mindspore
114