1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ 18 #define MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ 19 20 #include <set> 21 #include <string> 22 #include <memory> 23 #include <unordered_map> 24 #include "proto/ps.pb.h" 25 #include "fl/server/common.h" 26 #include "ps/core/server_node.h" 27 #include "ps/core/communicator/tcp_communicator.h" 28 29 namespace mindspore { 30 namespace fl { 31 namespace server { 32 constexpr uint32_t kDefaultCountingServerRank = 0; 33 constexpr auto kModuleDistributedCountService = "DistributedCountService"; 34 // The callbacks for the first count and last count event. 35 typedef struct { 36 MessageCallback first_count_handler; 37 MessageCallback last_count_handler; 38 } CounterHandlers; 39 40 // DistributedCountService is used for counting in the server cluster dimension. It's used for counting of rounds, 41 // aggregation counting, etc. 42 43 // The counting could be called by any server, but only one server has the information 44 // of the cluster count and we mark this server as the counting server. Other servers must communicate with this 45 // counting server to increase/query count number. 46 47 // On the first count or last count event, DistributedCountService on the counting server triggers the event on other 48 // servers by sending counter event commands. This is for the purpose of keeping server cluster's consistency. 49 class DistributedCountService { 50 public: GetInstance()51 static DistributedCountService &GetInstance() { 52 static DistributedCountService instance; 53 return instance; 54 } 55 56 // Initialize counter service with the server node because communication is needed. 57 void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node, uint32_t counting_server_rank); 58 59 // Register message callbacks of the counting server to handle messages sent by the other servers. 60 void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator); 61 62 // Register counter to the counting server for the name with its threshold count in server cluster dimension and 63 // first/last count event callbacks. 64 void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers); 65 66 // Reinitialize counter due to the change of threshold count. 67 bool ReInitCounter(const std::string &name, size_t global_threshold_count); 68 69 // Report a count to the counting server. Parameter 'id' is in case of repeated counting. Parameter 'reason' is the 70 // reason why counting failed. 71 bool Count(const std::string &name, const std::string &id, std::string *reason = nullptr); 72 73 // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count, 74 // this method returns true. 75 bool CountReachThreshold(const std::string &name); 76 77 // Reset the count of the name to 0. 78 void ResetCounter(const std::string &name); 79 80 // Reinitialize counting service after scaling operations are done. 81 bool ReInitForScaling(); 82 83 // Returns the server rank because in some cases the callers use this rank as the 'id' for method 84 // Count. local_rank()85 uint32_t local_rank() { return local_rank_; } 86 87 private: 88 DistributedCountService() = default; 89 ~DistributedCountService() = default; 90 DistributedCountService(const DistributedCountService &) = delete; 91 DistributedCountService &operator=(const DistributedCountService &) = delete; 92 93 // Callback for the reporting count message from other servers. Only counting server will call this method. 94 void HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message); 95 96 // Callback for the querying whether threshold count is reached message from other servers. Only counting 97 // server will call this method. 98 void HandleCountReachThresholdRequest(const std::shared_ptr<ps::core::MessageHandler> &message); 99 100 // Callback for the first/last event message from the counting server. Only other servers will call this 101 // method. 102 void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message); 103 104 // Call the callbacks when the first/last count event is triggered. 105 bool TriggerCounterEvent(const std::string &name, std::string *reason = nullptr); 106 bool TriggerFirstCountEvent(const std::string &name, std::string *reason = nullptr); 107 bool TriggerLastCountEvent(const std::string &name, std::string *reason = nullptr); 108 109 // Members for the communication between counting server and other servers. 110 std::shared_ptr<ps::core::ServerNode> server_node_; 111 std::shared_ptr<ps::core::TcpCommunicator> communicator_; 112 uint32_t local_rank_; 113 uint32_t server_num_; 114 115 // Only one server will be set to do the real counting. 116 uint32_t counting_server_rank_; 117 118 // Key: name, e.g, startFLJob, updateModel, push. 119 // Value: a set of id without repeatation because each work may report multiple times. 120 std::unordered_map<std::string, std::set<std::string>> global_current_count_; 121 122 // Key: name, e.g, StartFLJobCount. 123 // Value: global threshold count in the server cluster dimension for this name. 124 std::unordered_map<std::string, size_t> global_threshold_count_; 125 126 // First/last count event callbacks of the name. 127 std::unordered_map<std::string, CounterHandlers> counter_handlers_; 128 129 // Because the count is increased/queried conccurently, we must ensure the operations are threadsafe. 130 std::unordered_map<std::string, std::mutex> mutex_; 131 }; 132 } // namespace server 133 } // namespace fl 134 } // namespace mindspore 135 #endif // MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ 136