/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
#define MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_

#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <functional>
#include "proto/ps.pb.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "ps/core/server_node.h"

namespace mindspore {
namespace fl {
namespace server {
// The timeout for server collective communication in case of network jitter.
constexpr uint32_t kCollectiveCommTimeout = 30;
// The maximum timeout for server collective communication, used in disaster recovery to tolerate network flapping.
constexpr uint32_t kCollectiveCommMaxTimeout = 300;

// The collective communication group, which is composed of multiple processes. Refer to MPI_Group.
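// For example (hypothetical ranks, shown only to illustrate the fields below): a group built from the
// processes with global ranks {1, 3, 4} has size = 3, group_ranks = {1, 3, 4},
// global_to_group_ranks = {{1, 0}, {3, 1}, {4, 2}} and group_to_global_ranks = {{0, 1}, {1, 3}, {2, 4}};
// for the process whose global rank is 3, global_rank = 3 and its group rank is 1.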
struct CommunicationGroupInfo {
  // This group's rank size.
  uint32_t size;

  // This process's global rank id.
  uint32_t global_rank;

  // The group ranks consist of the global ranks of the processes.
  std::vector<uint32_t> group_ranks;

  // The mappings between global ranks and group ranks.
  std::map<uint32_t, uint32_t> global_to_group_ranks;
  std::map<uint32_t, uint32_t> group_to_global_ranks;
};

// CollectiveOpsImpl is the collective communication API of the server.
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
// supported for the elastic scaling feature of the server.
class CollectiveOpsImpl {
 public:
  static CollectiveOpsImpl &GetInstance() {
    static CollectiveOpsImpl instance;
    return instance;
  }
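  // A typical call (a usage sketch; Initialize() must already have been called with a valid server node, and
  // "grads" as well as the buffer names are placeholders, not names from this header):
  //   CollectiveOpsImpl::GetInstance().AllReduce<float>("grads", send_buf, recv_buf, count);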

  void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);

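  // Collective AllReduce among the participating server processes: the reduced result of sendbuff across all
  // ranks is written to recvbuff on every rank.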
  template <typename T>
  bool AllReduce(const std::string &data_name, void *sendbuff, void *recvbuff, size_t count);

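  // Collective AllGather: the send_count elements contributed by each process are gathered so that every
  // process ends up with the data of all ranks in recvbuff.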
  template <typename T>
  bool AllGather(const void *sendbuff, void *recvbuff, size_t send_count, const ps::core::AbstractNodePtr &node);

  // Collective broadcast within the specified group. The parameter "root" is the group rank of the root process.
  // Normally 0.
  template <typename T>
  bool Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                 const ps::core::AbstractNodePtr &node, const CommunicationGroupInfo &group_info);
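  // For example (a sketch with placeholder buffers and a previously built group_info), broadcasting "count"
  // floats from the process whose group rank is 0:
  //   CollectiveOpsImpl::GetInstance().Broadcast<float>(send_buf, recv_buf, count, 0, node, group_info);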

 private:
  CollectiveOpsImpl()
      : server_node_(nullptr),
        rank_id_(0),
        server_num_(0),
        node_(nullptr),
        node_role_(ps::core::NodeRole::WORKER),
        rank_size_(0) {}
  ~CollectiveOpsImpl() = default;
  CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
  CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;

  // The ring send/receive routine used by RingAllReduce to exchange the chunks described by chunk_sizes and
  // chunk_offset with the neighbouring ranks.
  template <typename T>
  bool RunRingAllReduce(const std::string &data_name, uint32_t send_to_rank, uint32_t recv_from_rank,
                        const std::vector<size_t> &chunk_sizes, const std::vector<size_t> &chunk_offset,
                        T *output_buff);

  // Implementation of RingAllReduce.
  template <typename T>
  bool RingAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count);
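  // In the standard ring algorithm the data is split into rank_size chunks; each rank first performs
  // rank_size - 1 scatter-reduce steps (send one chunk to the next rank, receive and accumulate one from the
  // previous rank), then rank_size - 1 allgather steps to circulate the fully reduced chunks.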

  // Implementation of ReduceBroadcastAllReduce.
  template <typename T>
  bool ReduceBroadcastAllReduce(const std::string &data_name, const void *sendbuff, void *recvbuff, size_t count);
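  // The usual reduce-broadcast scheme: all ranks reduce their data to rank 0, which then broadcasts the
  // reduced result back to every other rank.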

  // Implementation of RingAllGather.
  template <typename T>
  bool RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count);
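  // In the standard ring allgather, each rank forwards the chunk it most recently received to the next rank
  // for rank_size - 1 steps, so every rank ends up with the concatenation of all contributions.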

  // Implementation of Broadcast. The parameter "root" is the group rank of the root process. Normally 0.
  template <typename T>
  bool Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                 const CommunicationGroupInfo &group_info);

  std::shared_ptr<ps::core::ServerNode> server_node_;
  uint32_t rank_id_;
  uint32_t server_num_;

  // The mutex that ensures collective communication is thread-safe.
  std::mutex mtx_;

  // The abstract node can be either a worker or a server. Only nodes with the same role can use collective
  // communication with each other.
  ps::core::AbstractNodePtr node_;
  ps::core::NodeRole node_role_;
  uint32_t rank_size_;
};
}  // namespace server
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_