• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
18 #define MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
19 
20 #include <memory>
21 #include <string>
22 #include "ps/core/communicator/communicator_base.h"
23 #include "fl/server/common.h"
24 #include "fl/server/iteration_timer.h"
25 #include "fl/server/distributed_count_service.h"
26 #include "fl/server/kernel/round/round_kernel.h"
27 
28 namespace mindspore {
29 namespace fl {
30 namespace server {
31 // Round lets the server handle the network messages of one round and launch the matching round kernel. One server
32 // iteration consists of multiple rounds such as startFLJob, updateModel, Push, Pull, etc. Some round kernels are stateful
33 // because of counting and timing, so Round registers the counter and timer and the kernels only focus on their logic.
   //
   // Call order stated by the method comments below: Initialize() first, then BindRoundKernel(). Reset() is expected after
   // each finished iteration or when the round's timer expires.
34 class Round {
35  public:
     // name: identifier of this round (exposed through name()).
     // check_timeout / time_window: whether this round uses a timer, and the window duration when it does.
     // NOTE(review): the unit of time_window is not visible in this header (presumably milliseconds) -- confirm.
     // check_count / threshold_count: whether every round message is counted, and the count to be reached when it is.
     // server_num_as_threshold: whether the server number is used as the threshold count.
36   explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000,
37                  bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
38   ~Round() = default;
39 
     // Supplies the communicator together with the timeout and finish-iteration callbacks used by this round.
     // NOTE(review): what exactly gets registered is decided in the implementation file, which is not shown here.
40   void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
41                   const FinishIterCb &finish_iteration_cb);
42 
43   // Reinitialize the count service and the round kernel of this round after scaling operations are done.
     // NOTE(review): the bool result presumably reports success -- confirm against the implementation.
44   bool ReInitForScaling(uint32_t server_num);
45 
46   // After hyper-parameters are updated, some rounds and kernels must be reinitialized with the new threshold count
     // and time window. NOTE(review): the bool result presumably reports success -- confirm.
47   bool ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window);
48 
49   // Bind a round kernel to this Round. This method must be called after Initialize.
50   void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel);
51 
52   // This method is the callback which will be set to the communicator and called after the corresponding round message
53   // is sent to the server.
54   void LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message);
55 
56   // Round needs to be reset after each iteration is finished or after its timer expires.
57   void Reset();
58 
     // Read-only accessors for this round's configuration.
59   const std::string &name() const;
60   size_t threshold_count() const;
61   bool check_timeout() const;
62   size_t time_window() const;
63 
64  private:
65   // The callbacks which will be set to the distributed count service (see fl/server/distributed_count_service.h).
66   void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
67   void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
68 
69   // Judge whether the training service is available.
     // NOTE(review): reason looks like an output parameter describing why the service is unavailable -- confirm.
70   bool IsServerAvailable(std::string *reason);
71 
     // Name of this round; see name().
72   std::string name_;
73 
74   // Whether this round needs to use a timer. Most rounds in the federated-learning-with-mobile-devices scenario need
75   // to set check_timeout_ to true.
76   bool check_timeout_;
77 
78   // The time window duration for this round when check_timeout_ is set to true.
79   size_t time_window_;
80 
81   // If check_count_ is true, the round has to do counting for every round message and the first/last count
82   // events will be triggered.
83   bool check_count_;
84 
85   // The threshold count for this round when check_count_ is set to true. The logic of this round has to check whether
86   // the round message count has reached threshold_count_.
87   size_t threshold_count_;
88 
89   // Whether this round uses the server number as its threshold count.
90   bool server_num_as_threshold_;
91 
     // The communicator passed to this round (Initialize() takes one as a parameter).
92   std::shared_ptr<ps::core::CommunicatorBase> communicator_;
93 
94   // The round kernel for this Round.
95   std::shared_ptr<kernel::RoundKernel> kernel_;
96 
97   // Some rounds may need a timer to eliminate the long tail effect.
98   std::shared_ptr<IterationTimer> iter_timer_;
99 
100   // The callbacks which will be set to the round kernel.
101   StopTimerCb stop_timer_cb_;
102   FinishIterCb finish_iteration_cb_;
103   FinalizeCb finalize_cb_;
104 };
105 }  // namespace server
106 }  // namespace fl
107 }  // namespace mindspore
108 #endif  // MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
109