1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ 18 #include <utility> 19 #include <vector> 20 #include <string> 21 #include "include/backend/visible.h" 22 #include "include/backend/optimizer/pass.h" 23 #include "ir/func_graph.h" 24 #include "ir/anf.h" 25 #include "include/common/utils/utils.h" 26 #include "ops/array_op_name.h" 27 #include "ops/ascend_op_name.h" 28 #include "ops/framework_op_name.h" 29 #include "ops/other_op_name.h" 30 31 namespace mindspore { 32 namespace opt { 33 struct CommunicationOpInfo { 34 std::vector<CNodePtr> communication_op_nodes; 35 std::vector<float> input_grad_size; 36 std::vector<float> input_grad_time; 37 std::string group_name; 38 }; 39 40 class BACKEND_EXPORT CommunicationOpFusion : public Pass { 41 public: 42 explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1) Pass(name)43 : Pass(name), op_name_(std::move(op_name)), groups_(groups) {} 44 ~CommunicationOpFusion() override = default; 45 bool Run(const FuncGraphPtr &func_graph) override; 46 47 private: 48 bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, 49 const std::vector<size_t> &segment_index) const; 50 void GetAllReduceSplitSegment(const std::vector<CNodePtr> &nodes, int64_t threshold, 51 std::vector<size_t> *segment_index) const; 52 AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, 53 const CommunicationOpInfo &communication_op_info, size_t start_index, 54 size_t end_index) const; 55 bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, std::vector<size_t> *segment_index, 56 const std::string &group) const; 57 std::string op_name_; 58 size_t groups_ = 1; 59 }; 60 61 class SendFusion : public CommunicationOpFusion { 62 public: 63 explicit SendFusion(size_t groups = 1) : CommunicationOpFusion("send_fusion", kSendOpName, groups) {} 64 ~SendFusion() override = default; 65 }; 66 67 class RecvFusion : public CommunicationOpFusion { 68 public: 69 explicit RecvFusion(size_t groups = 1) : CommunicationOpFusion("recv_fusion", kReceiveOpName, groups) {} 70 ~RecvFusion() override = default; 71 }; 72 73 class AllReduceFusion : public CommunicationOpFusion { 74 public: 75 explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {} 76 ~AllReduceFusion() override = default; 77 }; 78 79 class AllGatherFusion : public CommunicationOpFusion { 80 public: 81 explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {} 82 ~AllGatherFusion() override = default; 83 }; 84 85 class BroadcastFusion : public CommunicationOpFusion { 86 public: 87 explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {} 88 ~BroadcastFusion() override = default; 89 }; 90 91 class ReduceScatterFusion : public CommunicationOpFusion { 92 public: 93 explicit ReduceScatterFusion(size_t groups = 1) 94 : CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {} 95 ~ReduceScatterFusion() override = default; 96 }; 97 } // namespace opt 98 } // namespace mindspore 99 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_COMMUNICATION_OP_FUSION_H_ 100