1 /** 2 * Copyright 2022-2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ALL_TO_ALL_UNIFY_MINDIR_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ALL_TO_ALL_UNIFY_MINDIR_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 #include "include/backend/optimizer/optimizer.h" 23 24 namespace mindspore { 25 namespace opt { 26 class NeighborExchangeUnifyMindIR : public PatternProcessPass { 27 public: 28 explicit NeighborExchangeUnifyMindIR(bool multigraph = true) 29 : PatternProcessPass("neighbor_exchange_unify_mindir", multigraph) {} 30 ~NeighborExchangeUnifyMindIR() override = default; 31 const BaseRef DefinePattern() const override; 32 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 33 34 private: 35 CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange) const; 36 std::vector<std::string> MustExistPrimitiveName() const override; 37 }; 38 39 /* AllToAllUnifyMindIR 40 * let rank size is 4, for ge: 41 * Input 42 * | 43 * Input [Split(split_dim)] 44 * | / | | \ 45 * [AlltoAll] -> [AllToAllv] 46 * | \ | | / 47 * Output [Concat(concat_dim)] 48 * | 49 * Output 50 * for kbk: 51 * Input 52 * | 53 * [Split(split_dim)] 54 * / | | \ 55 * Input [Concat(dim 0)] 56 * | | 57 * [AlltoAll] -> [AllToAll] 58 * | | 59 * Output [Split(dim 0)] 60 * \ | | / 61 * [Concat(concat_dim)] 62 * | 63 * Output 64 */ 65 class AllToAllUnifyMindIR : public PatternProcessPass { 66 public: 67 explicit AllToAllUnifyMindIR(bool multigraph = true) : PatternProcessPass("all_to_all_unify_mindir", multigraph) {} 68 ~AllToAllUnifyMindIR() override = default; 69 const BaseRef DefinePattern() const override; 70 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 71 72 private: 73 CNodePtr CreateSplitNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all, const AnfNodePtr &input_node, 74 int64_t split_count, int64_t split_dim) const; 75 CNodePtr CreateSplitNodeWithSplitDim(const KernelGraphPtr &graph, const CNodePtr &all_to_all) const; 76 CNodePtr CreateSplitNodeWithDim0(const KernelGraphPtr &graph, const CNodePtr &all_to_all, 77 const CNodePtr &input_node) const; 78 CNodePtr CreateAllToAllvNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) const; 79 CNodePtr CreateAllToAllNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &concat) const; 80 CNodePtr CreateConcatNode(const KernelGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &input_node, 81 int64_t split_count, int64_t concat_dim) const; 82 CNodePtr CreateConcatNodeWithConcatDim(const KernelGraphPtr &graph, const CNodePtr &all_to_all, 83 const CNodePtr &input_node) const; 84 CNodePtr CreateConcatNodeWithDim0(const KernelGraphPtr &graph, const CNodePtr &all_to_all, 85 const CNodePtr &input_node) const; 86 std::vector<std::string> MustExistPrimitiveName() const override; 87 }; 88 } // namespace opt 89 } // namespace mindspore 90 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ALL_TO_ALL_UNIFY_MINDIR_H_ 91