/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
#define MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_

#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <functional>
#include "proto/ps.pb.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "ps/core/server_node.h"

namespace mindspore {
namespace fl {
namespace server {
// The timeout for server collective communication, in case of network jitter.
constexpr uint32_t kCollectiveCommTimeout = 30;
// The maximum timeout for server collective communication, used in disaster recovery to prevent network flapping.
constexpr uint32_t kCollectiveCommMaxTimeout = 300;

// A collective communication group composed of multiple processes. Refer to MPI_Group.
struct CommunicationGroupInfo {
  // This group's rank size.
  uint32_t size;

  // This process's global rank id.
  uint32_t global_rank;

  // The group ranks, consisting of the global ranks of the member processes.
  std::vector<uint32_t> group_ranks;

  // The mappings between global ranks and group ranks.
  std::map<uint32_t, uint32_t> global_to_group_ranks;
  std::map<uint32_t, uint32_t> group_to_global_ranks;
};

// CollectiveOpsImpl is the collective communication API of the server.
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
// supported for the elastic scaling feature of the server.
class CollectiveOpsImpl {
 public:
  static CollectiveOpsImpl &GetInstance() {
    static CollectiveOpsImpl instance;
    return instance;
  }

  void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);

  template <typename T>
  bool AllReduce(const std::string &data_name, void *sendbuff, void *recvbuff, size_t count);

  template <typename T>
  bool AllGather(const void *sendbuff, void *recvbuff, size_t send_count, const ps::core::AbstractNodePtr &node);

  // Collective broadcast within the specified group. The parameter "root" is the group rank of the root process,
  // which is normally 0.
  template <typename T>
  bool Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                 const ps::core::AbstractNodePtr &node, const CommunicationGroupInfo &group_info);

 private:
  CollectiveOpsImpl()
      : server_node_(nullptr),
        rank_id_(0),
        server_num_(0),
        node_(nullptr),
        node_role_(ps::core::NodeRole::WORKER),
        rank_size_(0) {}
  ~CollectiveOpsImpl() = default;
  CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
  CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;
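
  // For orientation, a minimal sketch of how the chunking parameters taken by RunRingAllReduce below could be
  // derived. This is an illustration under common ring-all-reduce assumptions, not necessarily the exact scheme
  // implemented here: with rank size N and element count C, rank i would own
  //   chunk_sizes[i]  = C / N + (i < C % N ? 1 : 0);               // near-equal split of the buffer
  //   chunk_offset[i] = chunk_offset[i - 1] + chunk_sizes[i - 1];  // prefix sum, with chunk_offset[0] = 0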

  // The chunk exchange loop used by RingAllReduce: in each step, this rank sends a chunk to send_to_rank and
  // receives a chunk from recv_from_rank.
  template <typename T>
  bool RunRingAllReduce(const std::string &data_name, uint32_t send_to_rank, uint32_t recv_from_rank,
                        const std::vector<size_t> &chunk_sizes, const std::vector<size_t> &chunk_offset,
                        T *output_buff);

  // Implementation of RingAllReduce.
  template <typename T>
  bool RingAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count);

  // Implementation of BroadcastAllReduce.
  template <typename T>
  bool ReduceBroadcastAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count);

  // Implementation of RingAllGather.
  template <typename T>
  bool RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count);

  // Implementation of Broadcast. The parameter "root" is the group rank of the root process, which is normally 0.
  template <typename T>
  bool Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                 const CommunicationGroupInfo &group_info);

  std::shared_ptr<ps::core::ServerNode> server_node_;
  uint32_t rank_id_;
  uint32_t server_num_;

  // The mutex that keeps collective communication thread-safe.
  std::mutex mtx_;

  // The abstract node could be a worker or a server. Only nodes with the same role can take part in the same
  // collective communication.
  ps::core::AbstractNodePtr node_;
  ps::core::NodeRole node_role_;
  uint32_t rank_size_;
};
}  // namespace server
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
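
// A minimal usage sketch of this singleton, shown for reference. This is hypothetical caller code: "server_node"
// and the buffer/data names are assumptions for illustration, not part of this header.
//
//   auto &ops = mindspore::fl::server::CollectiveOpsImpl::GetInstance();
//   ops.Initialize(server_node);  // server_node: a std::shared_ptr<ps::core::ServerNode> owned by the caller.
//   std::vector<float> send(1024, 1.0f);
//   std::vector<float> recv(1024, 0.0f);
//   if (!ops.AllReduce<float>("demo.weights", send.data(), recv.data(), send.size())) {
//     // Handle the collective communication failure reported by the boolean return value.
//   }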