1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ 18 #include <utility> 19 #include <vector> 20 #include <string> 21 22 #include "backend/optimizer/common/pass.h" 23 #include "ir/func_graph.h" 24 #include "ir/anf.h" 25 #include "utils/utils.h" 26 27 namespace mindspore { 28 namespace opt { 29 struct CommunicationOpInfo { 30 std::vector<CNodePtr> communication_op_nodes; 31 std::vector<float> input_grad_size; 32 std::vector<float> input_grad_time; 33 }; 34 35 class CommunicationOpFusion : public Pass { 36 public: 37 explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1) Pass(name)38 : Pass(name), op_name_(std::move(op_name)), groups_(groups) {} 39 ~CommunicationOpFusion() override = default; 40 bool Run(const FuncGraphPtr &graph) override; 41 42 private: 43 bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, size_t segment_num, 44 const std::vector<size_t> &segment_index) const; 45 AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, 46 const CommunicationOpInfo &communication_op_info, size_t start_index, 47 size_t end_index) const; 48 bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, 49 std::vector<size_t> *segment_index, const std::string &group) const; 50 std::string op_name_; 51 size_t groups_ = 1; 52 }; 53 54 class SendFusion : public CommunicationOpFusion { 55 public: 56 explicit SendFusion(size_t groups = 1) : CommunicationOpFusion("send_fusion", kHcomSendOpName, groups) {} 57 ~SendFusion() override = default; 58 }; 59 60 class RecvFusion : public CommunicationOpFusion { 61 public: 62 explicit RecvFusion(size_t groups = 1) : CommunicationOpFusion("recv_fusion", kReceiveOpName, groups) {} 63 ~RecvFusion() override = default; 64 }; 65 66 class AllReduceFusion : public CommunicationOpFusion { 67 public: 68 explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {} 69 ~AllReduceFusion() override = default; 70 }; 71 72 class AllGatherFusion : public CommunicationOpFusion { 73 public: 74 explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {} 75 ~AllGatherFusion() override = default; 76 }; 77 78 class BroadcastFusion : public CommunicationOpFusion { 79 public: 80 explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {} 81 ~BroadcastFusion() override = default; 82 }; 83 84 class ReduceScatterFusion : public CommunicationOpFusion { 85 public: 86 explicit ReduceScatterFusion(size_t groups = 1) 87 : CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {} 88 ~ReduceScatterFusion() override = default; 89 }; 90 } // namespace opt 91 } // namespace mindspore 92 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ 93