/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
#define MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_

#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <functional>
#include "proto/ps.pb.h"
#include "ps/ps_context.h"
#include "ps/core/server_node.h"
#include "fl/server/common.h"

namespace mindspore {
namespace fl {
namespace server {
// The timeout for server collective communication in case of network jitter.
constexpr uint32_t kCollectiveCommTimeout = 30;

// CollectiveOpsImpl is the collective communication API of the server.
// For now, it implements two AllReduce algorithms: RingAllReduce and ReduceBroadcastAllReduce. Elastic AllReduce is
// also supported for the elastic scaling feature of the server.
class CollectiveOpsImpl {
 public:
  static CollectiveOpsImpl &GetInstance() {
    static CollectiveOpsImpl instance;
    return instance;
  }

  void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);

  template <typename T>
  bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);

  // Reinitialize the ring for collective communication after scaling operations are done.
  bool ReInitForScaling();

 private:
  CollectiveOpsImpl() : server_node_(nullptr), local_rank_(0), server_num_(0) {}
  ~CollectiveOpsImpl() = default;
  CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
  CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;

  // Implementation of RingAllReduce.
  template <typename T>
  bool RingAllReduce(const void *sendbuff, void *recvbuff, size_t count);

  // Implementation of ReduceBroadcastAllReduce.
  template <typename T>
  bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count);

  std::shared_ptr<ps::core::ServerNode> server_node_;
  uint32_t local_rank_;
  uint32_t server_num_;

  // The mutex to ensure that collective communication is thread-safe.
  std::mutex mtx_;
};
}  // namespace server
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
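
// Usage sketch (illustrative only; not part of this header). It shows how a
// server-side caller might drive a float AllReduce through the singleton
// declared above. `server_node` is assumed to be an already-constructed
// std::shared_ptr<ps::core::ServerNode>; the buffer size and element type
// are hypothetical.
//
//   auto &ops = mindspore::fl::server::CollectiveOpsImpl::GetInstance();
//   ops.Initialize(server_node);
//
//   std::vector<float> send(1024, 1.0f);
//   std::vector<float> recv(1024, 0.0f);
//   // AllReduce is templated on the element type; `count` is the number of
//   // elements, not the byte size.
//   if (!ops.AllReduce<float>(send.data(), recv.data(), send.size())) {
//     // Handle communication failure, e.g. a timeout governed by
//     // kCollectiveCommTimeout.
//   }
//
//   // After an elastic scaling operation completes, the ring must be rebuilt
//   // before further collective calls:
//   if (!ops.ReInitForScaling()) {
//     // Handle reinitialization failure.
//   }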