/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_COMMON_H_
#define MINDSPORE_CCSRC_FL_SERVER_COMMON_H_

#include <map>
#include <string>
#include <numeric>
#include <climits>
#include <memory>
#include <functional>
#include "proto/ps.pb.h"
#include "proto/fl.pb.h"
#include "ir/anf.h"
#include "utils/utils.h"
#include "ir/dtype/type_id.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "ps/ps_context.h"
#include "ps/core/communicator/http_message_handler.h"
#include "ps/core/communicator/tcp_server.h"
#include "ps/core/communicator/message_handler.h"

namespace mindspore {
namespace fl {
namespace server {
// Definitions for the server framework.
enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
enum CommType { HTTP = 0, TCP };
enum AggregationType { FedAvg = 0, FedAdam, FedAdagrad, FedMeta, qffl, DenseGradAccum, SparseGradAccum };

struct RoundConfig {
  // The name of the round. Please refer to the round kernel *.cc files.
  std::string name;
  // Whether this round has a time window limit.
  bool check_timeout = false;
  // The length of the time window in milliseconds. Only used when check_timeout is set to true.
  size_t time_window = 3000;
  // Whether this round has to check whether the request count has reached the threshold.
  bool check_count = false;
  // This round's request count threshold. Only used when check_count is set to true.
  size_t threshold_count = 0;
  // Whether this round uses the server number as its threshold count. This is vital for some rounds in the elastic
  // scaling scenario.
  bool server_num_as_threshold = false;
};
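
// A minimal usage sketch (illustrative, not part of the framework): a hypothetical "updateModel" round that accepts
// requests for at most 3 seconds and finishes early once 100 requests arrive:
//   RoundConfig update_model_cfg = {"updateModel", true, 3000, true, 100, false};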

struct CipherConfig {
  float share_secrets_ratio = 1.0;
  uint64_t cipher_time_window = 300000;
  size_t exchange_keys_threshold = 0;
  size_t get_keys_threshold = 0;
  size_t share_secrets_threshold = 0;
  size_t get_secrets_threshold = 0;
  size_t client_list_threshold = 0;
  size_t reconstruct_secrets_threshold = 0;
};
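
// A minimal usage sketch (illustrative; the threshold values are assumptions, not defaults): per-round thresholds
// for the secure-aggregation rounds with 100 participating clients and a 5-minute cipher time window:
//   CipherConfig cipher_cfg = {1.0f, 300000, 100, 100, 100, 100, 100, 100};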

// Every instance is one training loop that runs fl_iteration_num iterations of federated learning.
// During every instance, the server's training process can be controlled by the scheduler, which changes the state of
// this instance.
enum class InstanceState {
  // If this instance is in the kRunning state, the server can communicate with clients/workers and the training
  // process moves on.
  kRunning = 0,
  // The server is not available for clients/workers if in the kDisable state.
  kDisable,
  // The server is not available for clients/workers if in the kFinish state. This state also means one instance has
  // finished, in other words, fl_iteration_num iterations are completed.
  kFinish
};

using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;
using FBBuilder = flatbuffers::FlatBufferBuilder;
using TimeOutCb = std::function<void(bool, const std::string &)>;
using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(bool, const std::string &)>;
using FinalizeCb = std::function<void(void)>;
using MessageCallback = std::function<void(const std::shared_ptr<ps::core::MessageHandler> &)>;

// Information about whether the server kernel will reuse kernel node memory from the front end.
// The key is the server kernel's parameter name, e.g., "weights", "grad", "learning_rate".
// The value is the kernel node's parameter index.
using ReuseKernelNodeInfo = std::map<std::string, size_t>;

// UploadData refers to the data uploaded by workers.
// The key is the data name, e.g., "weights", "grad", "learning_rate". It is set by the worker.
// The value is the data for that key.

// We use Address instead of AddressPtr because:
// 1. Address doesn't need a make_shared<T> call, so it has better performance.
// 2. The data uploaded by workers is normally parsed from FlatBuffers or Protobuf messages, e.g., the learning rate
// or new weights. Address is enough to store such data.

// Pay attention that Address only stores the void * pointer to the data, so the data must not be released before the
// related logic is done.
using UploadData = std::map<std::string, Address>;
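
// A minimal usage sketch (illustrative, with a hypothetical buffer `buf` of `buf_len` bytes parsed from a worker
// request). Since Address stores only the raw pointer, `buf` must outlive every use of `upload`:
//   UploadData upload;
//   upload[kWeight] = {const_cast<void *>(buf), buf_len};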

constexpr auto kWeight = "weight";
constexpr auto kNewWeight = "new_weight";
constexpr auto kAccumulation = "accum";
constexpr auto kLearningRate = "lr";
constexpr auto kGradient = "grad";
constexpr auto kNewGradient = "new_grad";
constexpr auto kMomentum = "momentum";
constexpr auto kIndices = "indices";
constexpr auto kAdamM = "m";
constexpr auto kAdamV = "v";
constexpr auto kAdamBeta1Power = "beta1_power";
constexpr auto kAdamBeta2Power = "beta2_power";
constexpr auto kAdamBeta1 = "beta1";
constexpr auto kAdamBeta2 = "beta2";
constexpr auto kAdamEps = "eps";
constexpr auto kFtrlLinear = "linear";
constexpr auto kDataSize = "data_size";
constexpr auto kNewDataSize = "new_data_size";
constexpr auto kStat = "stat";

// OptimParamNameToIndex maps each input/workspace/output parameter name to its offset when an optimizer kernel is
// launched.
using OptimParamNameToIndex = std::map<std::string, std::map<std::string, size_t>>;
const OptimParamNameToIndex kMomentumNameToIdx = {
  {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kLearningRate, 2}, {kGradient, 3}, {kMomentum, 4}}}, {"outputs", {}}};
const OptimParamNameToIndex kAdamNameToIdx = {{"inputs",
                                               {{kWeight, 0},
                                                {kAdamM, 1},
                                                {kAdamV, 2},
                                                {kAdamBeta1Power, 3},
                                                {kAdamBeta2Power, 4},
                                                {kLearningRate, 5},
                                                {kAdamBeta1, 6},
                                                {kAdamBeta2, 7},
                                                {kAdamEps, 8},
                                                {kGradient, 9}}},
                                              {"outputs", {}}};
const OptimParamNameToIndex kSparseAdamNameToIdx = {{"inputs",
                                                     {{kWeight, 0},
                                                      {kAdamM, 1},
                                                      {kAdamV, 2},
                                                      {kAdamBeta1Power, 3},
                                                      {kAdamBeta2Power, 4},
                                                      {kLearningRate, 5},
                                                      {kAdamBeta1, 6},
                                                      {kAdamBeta2, 7},
                                                      {kAdamEps, 8},
                                                      {kGradient, 9},
                                                      {kIndices, 10}}},
                                                    {"outputs", {}}};
const OptimParamNameToIndex kSparseFtrlNameToIdx = {
  {"inputs", {{kWeight, 0}, {kAccumulation, 1}, {kFtrlLinear, 2}, {kGradient, 3}, {kIndices, 4}}}, {"outputs", {}}};
const OptimParamNameToIndex kAdamWeightDecayNameToIdx = {{"inputs",
                                                          {{kWeight, 0},
                                                           {kAdamM, 1},
                                                           {kAdamV, 2},
                                                           {kLearningRate, 3},
                                                           {kAdamBeta1, 4},
                                                           {kAdamBeta2, 5},
                                                           {kAdamEps, 6},
                                                           {"weight_decay", 7},
                                                           {kGradient, 8}}},
                                                         {"outputs", {}}};
const OptimParamNameToIndex kSGDNameToIdx = {
  {"inputs", {{kWeight, 0}, {kGradient, 1}, {kLearningRate, 2}, {kAccumulation, 3}, {kMomentum, 4}, {kStat, 5}}},
  {"outputs", {}}};

const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {
  {kApplyMomentumOpName, kMomentumNameToIdx},     {kFusedSparseAdamName, kSparseAdamNameToIdx},
  {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx}, {kApplyAdamOpName, kAdamNameToIdx},
  {"AdamWeightDecay", kAdamWeightDecayNameToIdx}, {kSGDName, kSGDNameToIdx}};

constexpr uint32_t kLeaderServerRank = 0;
constexpr size_t kWorkerMgrThreadPoolSize = 32;
constexpr size_t kWorkerMgrMaxTaskNum = 64;
constexpr size_t kCipherMgrThreadPoolSize = 32;
constexpr size_t kCipherMgrMaxTaskNum = 64;
constexpr size_t kExecutorThreadPoolSize = 32;
constexpr size_t kExecutorMaxTaskNum = 32;
constexpr int kHttpSuccess = 200;
constexpr uint32_t kThreadSleepTime = 50;
constexpr auto kPBProtocol = "PB";
constexpr auto kFBSProtocol = "FBS";
constexpr auto kSuccess = "Success";
constexpr auto kFedAvg = "FedAvg";
constexpr auto kAggregationKernelType = "Aggregation";
constexpr auto kOptimizerKernelType = "Optimizer";
constexpr auto kCtxFuncGraph = "FuncGraph";
constexpr auto kCtxIterNum = "iteration";
constexpr auto kCtxDeviceMetas = "device_metas";
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp";
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
constexpr auto kCtxUpdateModelThld = "update_model_threshold";
constexpr auto kCtxClientsKeys = "clients_keys";
constexpr auto kCtxClientNoises = "clients_noises";
constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares";
constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares";
constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list";
constexpr auto kCtxGetSecretsClientList = "get_secrets_client_list";
constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list";
constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
constexpr auto kCtxCipherPrimer = "cipher_primer";

// This macro returns the current timestamp in milliseconds.
#define CURRENT_TIME_MILLI \
  std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())
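
// A minimal usage sketch (illustrative): timing a block of work. CURRENT_TIME_MILLI expands to a
// std::chrono::milliseconds value, so subtracting two readings yields an elapsed duration directly:
//   auto start = CURRENT_TIME_MILLI;
//   // ... do work ...
//   std::chrono::milliseconds elapsed = CURRENT_TIME_MILLI - start;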

// This function returns the size in bytes of the given TypeId.
inline size_t GetTypeIdByte(const TypeId &type) {
  switch (type) {
    case kNumberTypeFloat16:
      return 2;
    case kNumberTypeUInt32:
    case kNumberTypeFloat32:
      return 4;
    case kNumberTypeUInt64:
      return 8;
    default:
      MS_LOG(EXCEPTION) << "TypeId " << type << " not supported.";
      return 0;
  }
}
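
// A minimal usage sketch (illustrative, with a hypothetical element count `elem_num`): computing a tensor's byte
// size from its element count and dtype:
//   size_t nbytes = elem_num * GetTypeIdByte(kNumberTypeFloat32);  // 4 bytes per float32 element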

// Generates an Address pointing into the default tensor of the param_idx-th input (a Parameter node) of kernel_node.
// Returns nullptr if any node along the way is null.
inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size_t param_idx) {
  MS_ERROR_IF_NULL_W_RET_VAL(kernel_node, nullptr);
  auto param_node =
    AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, param_idx), 0).first->cast<ParameterPtr>();
  MS_ERROR_IF_NULL_W_RET_VAL(param_node, nullptr);
  auto param_tensor = param_node->default_param()->cast<tensor::TensorPtr>();
  MS_ERROR_IF_NULL_W_RET_VAL(param_tensor, nullptr);
  AddressPtr addr = std::make_shared<kernel::Address>();
  addr->addr = param_tensor->data_c();
  addr->size = param_tensor->data().nbytes();
  return addr;
}
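
// A minimal usage sketch (illustrative, with a hypothetical optimizer kernel node `adam_node`): fetching the address
// of its weight input via the offset tables above:
//   size_t weight_idx = kNameToIdxMap.at(kApplyAdamOpName).at("inputs").at(kWeight);
//   AddressPtr weight_addr = GenerateParameterNodeAddrPtr(adam_node, weight_idx);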

// Definitions for Federated Learning.

constexpr auto kNetworkError = "Cluster networking failed.";

// The result code used by round kernels.
enum class ResultCode {
  // If the method is successfully called and the round kernel's remaining methods should be called, return kSuccess.
  kSuccess = 0,
  // If an error occurred in the method and the remaining methods should not be called, but this iteration should
  // continue, return kSuccessAndReturn so that the framework won't drop this iteration.
  kSuccessAndReturn,
  // If an error occurred and this iteration should be dropped, return kFail.
  kFail
};

// Converts a ResultCode to a bool that tells whether the current iteration should be kept. Only kFail drops the
// iteration; both kSuccess and kSuccessAndReturn keep it.
inline bool ConvertResultCode(ResultCode result_code) {
  switch (result_code) {
    case ResultCode::kSuccess:
      return true;
    case ResultCode::kSuccessAndReturn:
      return true;
    case ResultCode::kFail:
      return false;
    default:
      return true;
  }
}
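
// A minimal usage sketch (illustrative, with a hypothetical round-kernel step `VerifyRequest` returning a
// ResultCode): the framework drops the iteration only when the code converts to false:
//   ResultCode code = VerifyRequest(req);
//   if (!ConvertResultCode(code)) {
//     // drop this iteration
//   }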

// Definitions for Parameter Server.

}  // namespace server
}  // namespace fl
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FL_SERVER_COMMON_H_