1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FL_SERVER_ROUND_H_ 18 #define MINDSPORE_CCSRC_FL_SERVER_ROUND_H_ 19 20 #include <memory> 21 #include <string> 22 #include "ps/core/communicator/communicator_base.h" 23 #include "fl/server/common.h" 24 #include "fl/server/iteration_timer.h" 25 #include "fl/server/distributed_count_service.h" 26 #include "fl/server/kernel/round/round_kernel.h" 27 28 namespace mindspore { 29 namespace fl { 30 namespace server { 31 // Round helps server to handle network round messages and launch round kernels. One iteration in server consists of 32 // multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting 33 // and timing. So Round helps register counter and timer so that the round kernels only need to focus on the logic. 34 class Round { 35 public: 36 explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000, 37 bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false); 38 ~Round() = default; 39 40 void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb, 41 const FinishIterCb &finish_iteration_cb); 42 43 // Reinitialize count service and round kernel of this round after scaling operations are done. 44 bool ReInitForScaling(uint32_t server_num); 45 46 // After hyper-parameters are updated, some rounds and kernels should be reinitialized. 47 bool ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window); 48 49 // Bind a round kernel to this Round. This method should be called after Initialize. 50 void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel); 51 52 // This method is the callback which will be set to the communicator and called after the corresponding round message 53 // is sent to the server. 54 void LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message); 55 56 // Round needs to be reset after each iteration is finished or its timer expires. 57 void Reset(); 58 59 const std::string &name() const; 60 size_t threshold_count() const; 61 bool check_timeout() const; 62 size_t time_window() const; 63 64 private: 65 // The callbacks which will be set to DistributedCounterService. 66 void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message); 67 void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message); 68 69 // Judge whether the training service is available. 70 bool IsServerAvailable(std::string *reason); 71 72 std::string name_; 73 74 // Whether this round needs to use timer. Most rounds in federated learning with mobile devices scenario need to set 75 // check_timeout_ to true. 76 bool check_timeout_; 77 78 // The time window duration for this round when check_timeout_ is set to true. 79 size_t time_window_; 80 81 // If check_count_ is true, it means the round has to do counting for every round message and the first/last count 82 // event will be triggered. 83 bool check_count_; 84 85 // The threshold count for this round when check_count_ is set to true. The logic of this round has to check whether 86 // the round message count has reached threshold_count_. 87 size_t threshold_count_; 88 89 // Whether this round uses the server number as its threshold count. 90 bool server_num_as_threshold_; 91 92 std::shared_ptr<ps::core::CommunicatorBase> communicator_; 93 94 // The round kernel for this Round. 95 std::shared_ptr<kernel::RoundKernel> kernel_; 96 97 // Some rounds may need timer to eliminate the long tail effect. 98 std::shared_ptr<IterationTimer> iter_timer_; 99 100 // The callbacks which will be set to the round kernel. 101 StopTimerCb stop_timer_cb_; 102 FinishIterCb finish_iteration_cb_; 103 FinalizeCb finalize_cb_; 104 }; 105 } // namespace server 106 } // namespace fl 107 } // namespace mindspore 108 #endif // MINDSPORE_CCSRC_FL_SERVER_ROUND_H_ 109