/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
#define MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_

#include <memory>
#include <string>
#include <vector>
#include "proto/comm.pb.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "ps/ps_context.h"
#include "ps/core/worker_node.h"
#include "ps/core/cluster_metadata.h"
#include "ps/core/communicator/tcp_communicator.h"

namespace mindspore {
namespace fl {
using FBBuilder = flatbuffers::FlatBufferBuilder;

// The step numbers used by the worker to judge whether to communicate with the server.
constexpr uint32_t kTrainBeginStepNum = 1;
constexpr uint32_t kTrainEndStepNum = 0;
constexpr uint32_t kOneStepPerIteration = 1;
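// A minimal sketch (an assumption drawn from the names above, not code in this
// file) of how these constants are typically consulted: the worker takes its
// local step counter modulo the steps per iteration and compares the result.
//
//   uint64_t step = local_step % steps_per_iteration;          // hypothetical counters
//   bool communicate_at_begin = (step == kTrainBeginStepNum);  // first step of an iteration
//   bool communicate_at_end = (step == kTrainEndStepNum);      // counter wrapped past the last step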

// The sleep duration of the worker thread before the networking is completed.
constexpr uint32_t kWorkerSleepTimeForNetworking = 1000;

// The duration between retries when the server is in safemode.
constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;

// The rank of the leader server.
constexpr uint32_t kLeaderServerRank = 0;

// The timeout for the worker to send a message to the server, guarding against network jitter.
constexpr uint32_t kWorkerTimeout = 30;

enum class IterationState {
  // This iteration is still in progress.
  kRunning,
  // This iteration is completed and the next iteration has not started yet.
  kCompleted
};

namespace worker {
// This class is used for hybrid training mode for now. In later versions, parameter server mode will also use this
// class as the worker.
class FLWorker {
 public:
  static FLWorker &GetInstance() {
    static FLWorker instance;
    return instance;
  }
  void Run();
  void Finalize();
  bool SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
                    std::shared_ptr<std::vector<unsigned char>> *output = nullptr);
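  // A hypothetical usage sketch (an assumption, not part of this header):
  // sending a FlatBuffers-encoded request to the leader server and reading
  // the reply, where `cmd` stands for some ps::core::TcpUserCommand value.
  //
  //   FBBuilder fbb;
  //   // ... build the request into fbb ...
  //   std::shared_ptr<std::vector<unsigned char>> response;
  //   if (!FLWorker::GetInstance().SendToServer(kLeaderServerRank, fbb.GetBufferPointer(),
  //                                             fbb.GetSize(), cmd, &response)) {
  //     // Handle failure, e.g. the server is in safemode or the send timed out.
  //   }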

  uint32_t server_num() const;
  uint32_t worker_num() const;
  uint32_t rank_id() const;
  uint64_t worker_step_num_per_iteration() const;

  // Check whether the worker is still running, i.e. has not exited.
  bool running() const;

  // These methods set the worker's iteration state.
  void SetIterationRunning();
  void SetIterationCompleted();

  void set_fl_iteration_num(uint64_t iteration_num);
  uint64_t fl_iteration_num() const;

  void set_data_size(int data_size);
  int data_size() const;

  std::string fl_name() const;
  std::string fl_id() const;

 private:
  FLWorker()
      : running_(false),
        server_num_(0),
        worker_num_(0),
        scheduler_ip_(""),
        scheduler_port_(0),
        worker_node_(nullptr),
        rank_id_(UINT32_MAX),
        iteration_num_(0),
        data_size_(0),
        worker_step_num_per_iteration_(1),
        server_iteration_state_(IterationState::kCompleted),
        worker_iteration_state_(IterationState::kCompleted),
        safemode_(false) {}
  ~FLWorker() = default;
  FLWorker(const FLWorker &) = delete;
  FLWorker &operator=(const FLWorker &) = delete;

  // Initialize the follower scaler for the worker.
  void InitializeFollowerScaler();

  // The handlers for the iteration state events.
  void HandleIterationRunningEvent();
  void HandleIterationCompletedEvent();

  // The barriers before scaling operations.
  void ProcessBeforeScalingOut();
  void ProcessBeforeScalingIn();

  // The handlers after the scheduler's scaling operations are done.
  void ProcessAfterScalingOut();
  void ProcessAfterScalingIn();

  std::atomic_bool running_;
  uint32_t server_num_;
  uint32_t worker_num_;
  std::string scheduler_ip_;
  uint16_t scheduler_port_;
  std::shared_ptr<ps::core::WorkerNode> worker_node_;
  uint32_t rank_id_;

  // The federated learning iteration number.
  std::atomic<uint64_t> iteration_num_;

  // Data size for this federated learning job.
  int data_size_;

  // The number of standalone training steps the worker runs before communicating with the server. This is used in
  // hybrid training mode.
  uint64_t worker_step_num_per_iteration_;

  // The iteration state is either running or completed.
  // This variable represents the server's iteration state and should only be changed by the
  // kIterationRunning/kIterationCompleted events triggered by the server.
  std::atomic<IterationState> server_iteration_state_;

  // This variable represents the worker's iteration state and should only be changed by the worker training process.
  std::atomic<IterationState> worker_iteration_state_;

  // The flag that represents whether the worker is in safemode, which is decided by both the worker and server
  // iteration states.
  std::atomic_bool safemode_;
};
}  // namespace worker
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
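
Below is a minimal usage sketch, not part of the header above. It assumes the
include path "fl/worker/fl_worker.h" and that Run(), Finalize(), and the
iteration-state setters behave as their declarations and comments suggest.

#include "fl/worker/fl_worker.h"

// Drives one hybrid-training job through the FLWorker singleton (hypothetical driver).
void TrainFLJob(uint64_t iterations, int local_data_size) {
  auto &worker = mindspore::fl::worker::FLWorker::GetInstance();
  worker.Run();                      // Join the cluster via the scheduler.
  worker.set_data_size(local_data_size);
  for (uint64_t i = 0; i < iterations; ++i) {
    worker.set_fl_iteration_num(i);
    worker.SetIterationRunning();    // Mark the local iteration as started.
    // ... run worker_step_num_per_iteration() local training steps, then
    // exchange model weights with the server via SendToServer() ...
    worker.SetIterationCompleted();  // Mark the local iteration as finished.
  }
  worker.Finalize();                 // Leave the cluster.
}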