• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ALL_TO_ALL_UNIFY_MINDIR_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ALL_TO_ALL_UNIFY_MINDIR_H_
18 
19 #include <memory>
20 #include <string>
21 #include <vector>
22 #include "include/backend/optimizer/optimizer.h"
23 
24 namespace mindspore {
25 namespace opt {
26 class NeighborExchangeUnifyMindIR : public PatternProcessPass {
27  public:
28   explicit NeighborExchangeUnifyMindIR(bool multigraph = true)
29       : PatternProcessPass("neighbor_exchange_unify_mindir", multigraph) {}
30   ~NeighborExchangeUnifyMindIR() override = default;
31   const BaseRef DefinePattern() const override;
32   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
33 
34  private:
35   CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange) const;
36   std::vector<std::string> MustExistPrimitiveName() const override;
37 };
38 
39 /* AllToAllUnifyMindIR
 * Assume a rank size of 4. For GE:
41  *                        Input
42  *                          |
43  *     Input         [Split(split_dim)]
44  *       |               / | | \
 *   [AllToAll]  ->     [AllToAllv]
46  *       |               \  |  |  /
47  *     Output        [Concat(concat_dim)]
48  *                           |
49  *                         Output
 * For KBK:
51  *                        Input
52  *                          |
53  *                   [Split(split_dim)]
54  *                       / | | \
55  *     Input          [Concat(dim 0)]
56  *       |                  |
 *   [AllToAll]  ->     [AllToAll]
58  *       |                  |
59  *     Output          [Split(dim 0)]
60  *                       \  |  |  /
61  *                   [Concat(concat_dim)]
62  *                           |
63  *                         Output
64  */
65 class AllToAllUnifyMindIR : public PatternProcessPass {
66  public:
67   explicit AllToAllUnifyMindIR(bool multigraph = true) : PatternProcessPass("all_to_all_unify_mindir", multigraph) {}
68   ~AllToAllUnifyMindIR() override = default;
69   const BaseRef DefinePattern() const override;
70   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
71 
72  private:
73   CNodePtr CreateSplitNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all, const AnfNodePtr &input_node,
74                            int64_t split_count, int64_t split_dim) const;
75   CNodePtr CreateSplitNodeWithSplitDim(const KernelGraphPtr &graph, const CNodePtr &all_to_all) const;
76   CNodePtr CreateSplitNodeWithDim0(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
77                                    const CNodePtr &input_node) const;
78   CNodePtr CreateAllToAllvNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) const;
79   CNodePtr CreateAllToAllNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &concat) const;
80   CNodePtr CreateConcatNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &input_node,
81                             int64_t split_count, int64_t concat_dim) const;
82   CNodePtr CreateConcatNodeWithConcatDim(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
83                                          const CNodePtr &input_node) const;
84   CNodePtr CreateConcatNodeWithDim0(const KernelGraphPtr &graph, const CNodePtr &all_to_all,
85                                     const CNodePtr &input_node) const;
86   std::vector<std::string> MustExistPrimitiveName() const override;
87 };
88 }  // namespace opt
89 }  // namespace mindspore
90 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ALL_TO_ALL_UNIFY_MINDIR_H_
91