1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "fl/armour/cipher/cipher_meta_storage.h"

#include <chrono>
#include <thread>
#include <utility>
18
19 namespace mindspore {
20 namespace armour {
GetClientSharesFromServer(const char * list_name,std::map<std::string,std::vector<clientshare_str>> * clients_shares_list)21 void CipherMetaStorage::GetClientSharesFromServer(
22 const char *list_name, std::map<std::string, std::vector<clientshare_str>> *clients_shares_list) {
23 if (clients_shares_list == nullptr) {
24 MS_LOG(ERROR) << "input clients_shares_list is nullptr";
25 return;
26 }
27 const fl::PBMetadata &clients_shares_pb_out =
28 fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
29 const fl::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares();
30 auto iter = clients_shares_pb.client_secret_shares().begin();
31 for (; iter != clients_shares_pb.client_secret_shares().end(); ++iter) {
32 std::string fl_id = iter->first;
33 const fl::SharesPb &shares_pb = iter->second;
34 std::vector<clientshare_str> encrpted_shares_new;
35 size_t client_share_num = IntToSize(shares_pb.clientsharestrs_size());
36 for (size_t index_shares = 0; index_shares < client_share_num; ++index_shares) {
37 const fl::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares);
38 clientshare_str new_clientshare;
39 new_clientshare.fl_id = client_share_str_pb.fl_id();
40 new_clientshare.index = client_share_str_pb.index();
41 new_clientshare.share.assign(client_share_str_pb.share().begin(), client_share_str_pb.share().end());
42 encrpted_shares_new.push_back(new_clientshare);
43 }
44 clients_shares_list->insert(std::pair<std::string, std::vector<clientshare_str>>(fl_id, encrpted_shares_new));
45 }
46 }
47
GetClientListFromServer(const char * list_name,std::vector<std::string> * clients_list)48 void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list) {
49 if (clients_list == nullptr) {
50 MS_LOG(ERROR) << "input clients_list is nullptr";
51 return;
52 }
53 const fl::PBMetadata &client_list_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
54 const fl::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
55 size_t client_list_num = IntToSize(client_list_pb.fl_id_size());
56 for (size_t i = 0; i < client_list_num; ++i) {
57 std::string fl_id = client_list_pb.fl_id(SizeToInt(i));
58 clients_list->push_back(fl_id);
59 }
60 }
61
GetClientKeysFromServer(const char * list_name,std::map<std::string,std::vector<std::vector<uint8_t>>> * clients_keys_list)62 void CipherMetaStorage::GetClientKeysFromServer(
63 const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list) {
64 if (clients_keys_list == nullptr) {
65 MS_LOG(ERROR) << "input clients_keys_list is nullptr";
66 return;
67 }
68 const fl::PBMetadata &clients_keys_pb_out =
69 fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
70 const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
71
72 for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
73 std::string fl_id = iter->first;
74 fl::KeysPb keys_pb = iter->second;
75 std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
76 std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
77 std::vector<std::vector<uint8_t>> cur_keys;
78 cur_keys.push_back(cpk);
79 cur_keys.push_back(spk);
80 (void)clients_keys_list->emplace(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys));
81 }
82 }
83
GetClientIVsFromServer(const char * list_name,std::map<std::string,std::vector<std::vector<uint8_t>>> * clients_ivs_list)84 void CipherMetaStorage::GetClientIVsFromServer(
85 const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list) {
86 if (clients_ivs_list == nullptr) {
87 MS_LOG(ERROR) << "input clients_ivs_list is nullptr";
88 return;
89 }
90 const fl::PBMetadata &clients_keys_pb_out =
91 fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
92 const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
93
94 for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
95 std::string fl_id = iter->first;
96 fl::KeysPb keys_pb = iter->second;
97 std::vector<uint8_t> ind_iv(keys_pb.ind_iv().begin(), keys_pb.ind_iv().end());
98 std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
99 std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
100
101 std::vector<std::vector<uint8_t>> cur_ivs;
102 cur_ivs.push_back(ind_iv);
103 cur_ivs.push_back(pw_iv);
104 cur_ivs.push_back(pw_salt);
105 (void)clients_ivs_list->emplace(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_ivs));
106 }
107 }
108
GetClientNoisesFromServer(const char * list_name,std::vector<float> * cur_public_noise)109 bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise) {
110 if (cur_public_noise == nullptr) {
111 MS_LOG(ERROR) << "input cur_public_noise is nullptr";
112 return false;
113 }
114 const fl::PBMetadata &clients_noises_pb_out =
115 fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
116 const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
117 int count = 0;
118 const int count_thld = 1000;
119 while (clients_noises_pb.has_one_client_noises() == false) {
120 const int register_time = 500;
121 std::this_thread::sleep_for(std::chrono::milliseconds(register_time));
122 count++;
123 if (count >= count_thld) break;
124 }
125 if (clients_noises_pb.has_one_client_noises() == false) {
126 MS_LOG(WARNING) << "GetClientNoisesFromServer Count: " << count;
127 return false;
128 }
129 cur_public_noise->assign(clients_noises_pb.one_client_noises().noise().begin(),
130 clients_noises_pb.one_client_noises().noise().end());
131 return true;
132 }
133
GetPrimeFromServer(const char * prime_name,uint8_t * prime)134 bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, uint8_t *prime) {
135 if (prime == nullptr) {
136 MS_LOG(ERROR) << "input prime is nullptr";
137 return false;
138 }
139 const fl::PBMetadata &prime_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
140 fl::Prime prime_pb(prime_pb_out.prime());
141 std::string str = *(prime_pb.mutable_prime());
142 if (str.size() != PRIME_MAX_LEN) {
143 MS_LOG(ERROR) << "get prime size is :" << str.size();
144 return false;
145 } else {
146 if (memcpy_s(prime, PRIME_MAX_LEN, str.data(), str.size()) != 0) {
147 MS_LOG(ERROR) << "Memcpy_s error";
148 return false;
149 }
150 return true;
151 }
152 }
153
UpdateClientToServer(const char * list_name,const std::string & fl_id)154 bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
155 fl::FLId fl_id_pb;
156 fl_id_pb.set_fl_id(fl_id);
157 fl::PBMetadata client_pb;
158 client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
159 bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
160 return retcode;
161 }
162
RegisterPrime(const char * list_name,const std::string & prime)163 void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
164 MS_LOG(INFO) << "register prime: " << prime;
165 fl::Prime prime_id_pb;
166 prime_id_pb.set_prime(prime);
167 fl::PBMetadata prime_pb;
168 prime_pb.mutable_prime()->MergeFrom(prime_id_pb);
169 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb);
170 uint32_t time = 1;
171 (void)sleep(time);
172 }
173
UpdateClientKeyToServer(const char * list_name,const std::string & fl_id,const std::vector<std::vector<uint8_t>> & cur_public_key)174 bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
175 const std::vector<std::vector<uint8_t>> &cur_public_key) {
176 const size_t correct_size = 2;
177 if (cur_public_key.size() < correct_size) {
178 MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size();
179 return false;
180 }
181 // update new item to memory server.
182 fl::KeysPb keys;
183 keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
184 keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
185 fl::PairClientKeys pair_client_keys_pb;
186 pair_client_keys_pb.set_fl_id(fl_id);
187 pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
188 fl::PBMetadata client_and_keys_pb;
189 client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
190 bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
191 return retcode;
192 }
193
UpdateClientKeyToServer(const char * list_name,const schema::RequestExchangeKeys * exchange_keys_req)194 bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name,
195 const schema::RequestExchangeKeys *exchange_keys_req) {
196 std::string fl_id = exchange_keys_req->fl_id()->str();
197 auto fbs_cpk = exchange_keys_req->c_pk();
198 auto fbs_spk = exchange_keys_req->s_pk();
199 if (fbs_cpk == nullptr || fbs_spk == nullptr) {
200 MS_LOG(ERROR) << "public key from exchange_keys_req is null";
201 return false;
202 }
203
204 size_t spk_len = fbs_spk->size();
205 size_t cpk_len = fbs_cpk->size();
206
207 // transform fbs (fbs_cpk & fbs_spk) to a vector: public_key
208 std::vector<std::vector<uint8_t>> cur_public_key;
209 std::vector<uint8_t> cpk(cpk_len);
210 std::vector<uint8_t> spk(spk_len);
211 bool ret_create_code_cpk = CreateArray<uint8_t>(&cpk, *fbs_cpk);
212 bool ret_create_code_spk = CreateArray<uint8_t>(&spk, *fbs_spk);
213 if (!(ret_create_code_cpk && ret_create_code_spk)) {
214 MS_LOG(ERROR) << "create array for public keys failed";
215 return false;
216 }
217 cur_public_key.push_back(cpk);
218 cur_public_key.push_back(spk);
219
220 auto fbs_ind_iv = exchange_keys_req->ind_iv();
221 std::vector<char> ind_iv;
222 if (fbs_ind_iv == nullptr) {
223 MS_LOG(WARNING) << "ind_iv in exchange_keys_req is nullptr";
224 } else {
225 ind_iv.assign(fbs_ind_iv->begin(), fbs_ind_iv->end());
226 }
227
228 auto fbs_pw_iv = exchange_keys_req->pw_iv();
229 std::vector<char> pw_iv;
230 if (fbs_pw_iv == nullptr) {
231 MS_LOG(WARNING) << "pw_iv in exchange_keys_req is nullptr";
232 } else {
233 pw_iv.assign(fbs_pw_iv->begin(), fbs_pw_iv->end());
234 }
235
236 auto fbs_pw_salt = exchange_keys_req->pw_salt();
237 std::vector<char> pw_salt;
238 if (fbs_pw_salt == nullptr) {
239 MS_LOG(WARNING) << "pw_salt in exchange_keys_req is nullptr";
240 } else {
241 pw_salt.assign(fbs_pw_salt->begin(), fbs_pw_salt->end());
242 }
243
244 // update new item to memory server.
245 fl::KeysPb keys;
246 keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
247 keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
248 keys.set_ind_iv(ind_iv.data(), ind_iv.size());
249 keys.set_pw_iv(pw_iv.data(), pw_iv.size());
250 keys.set_pw_salt(pw_salt.data(), pw_salt.size());
251 fl::PairClientKeys pair_client_keys_pb;
252 pair_client_keys_pb.set_fl_id(fl_id);
253 pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
254 fl::PBMetadata client_and_keys_pb;
255 client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
256 bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
257 return retcode;
258 }
259
UpdateClientNoiseToServer(const char * list_name,const std::vector<float> & cur_public_noise)260 bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise) {
261 // update new item to memory server.
262 fl::OneClientNoises noises_pb;
263 *noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
264 fl::PBMetadata client_noises_pb;
265 client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
266 bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
267 return ret;
268 }
269
UpdateClientShareToServer(const char * list_name,const std::string & fl_id,const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> * shares)270 bool CipherMetaStorage::UpdateClientShareToServer(
271 const char *list_name, const std::string &fl_id,
272 const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
273 if (shares == nullptr) {
274 return false;
275 }
276 size_t size_shares = shares->size();
277 fl::SharesPb shares_pb;
278 for (size_t index = 0; index < size_shares; ++index) {
279 // new item
280 fl::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs();
281 std::string fl_id_new = (*shares)[SizeToInt(index)]->fl_id()->str();
282 int index_new = (*shares)[SizeToInt(index)]->index();
283 auto share = (*shares)[SizeToInt(index)]->share();
284 if (share == nullptr) return false;
285 client_share_str_new_p->set_share(reinterpret_cast<const char *>(share->data()), share->size());
286 client_share_str_new_p->set_fl_id(fl_id_new);
287 client_share_str_new_p->set_index(index_new);
288 }
289 fl::PairClientShares pair_client_shares_pb;
290 pair_client_shares_pb.set_fl_id(fl_id);
291 pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
292 fl::PBMetadata client_and_shares_pb;
293 client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
294 bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
295 return retcode;
296 }
297
RegisterClass()298 void CipherMetaStorage::RegisterClass() {
299 fl::PBMetadata exchange_keys_client_list;
300 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
301 exchange_keys_client_list);
302 fl::PBMetadata get_keys_client_list;
303 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetKeysClientList,
304 get_keys_client_list);
305 fl::PBMetadata clients_keys;
306 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
307 fl::PBMetadata reconstruct_client_list;
308 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxReconstructClientList,
309 reconstruct_client_list);
310 fl::PBMetadata clients_reconstruct_shares;
311 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsReconstructShares,
312 clients_reconstruct_shares);
313 fl::PBMetadata share_secretes_client_list;
314 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
315 share_secretes_client_list);
316 fl::PBMetadata get_secretes_client_list;
317 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetSecretsClientList,
318 get_secretes_client_list);
319 fl::PBMetadata clients_encrypt_shares;
320 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
321 clients_encrypt_shares);
322 fl::PBMetadata get_update_clients_list;
323 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetUpdateModelClientList,
324 get_update_clients_list);
325 fl::PBMetadata client_noises;
326 fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientNoises, client_noises);
327 }
328 } // namespace armour
329 } // namespace mindspore
330