1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "fl/server/kernel/round/get_secrets_kernel.h"
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "fl/armour/cipher/cipher_shares.h"
24
25 namespace mindspore {
26 namespace fl {
27 namespace server {
28 namespace kernel {
InitKernel(size_t)29 void GetSecretsKernel::InitKernel(size_t) {
30 if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
31 iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
32 }
33
34 executor_ = &Executor::GetInstance();
35 MS_EXCEPTION_IF_NULL(executor_);
36 if (!executor_->initialized()) {
37 MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
38 return;
39 }
40
41 cipher_share_ = &armour::CipherShares::GetInstance();
42 }
43
CountForGetSecrets(const std::shared_ptr<FBBuilder> & fbb,const schema::GetShareSecrets * get_secrets_req,const int iter_num)44 bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
45 const schema::GetShareSecrets *get_secrets_req, const int iter_num) {
46 MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false);
47 if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) {
48 std::string reason = "Counting for get secrets kernel request failed. Please retry later.";
49 cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, IntToSize(iter_num),
50 std::to_string(CURRENT_TIME_MILLI.count()), nullptr);
51 MS_LOG(ERROR) << reason;
52 return false;
53 }
54 return true;
55 }
56
Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> &,const std::vector<AddressPtr> & outputs)57 bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
58 const std::vector<AddressPtr> &outputs) {
59 bool response = false;
60 size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
61 MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
62 std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
63 size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
64 MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
65 << total_duration;
66 clock_t start_time = clock();
67
68 if (inputs.size() != 1 || outputs.size() != 1) {
69 std::string reason = "inputs or outputs size is invalid.";
70 MS_LOG(ERROR) << reason;
71 return false;
72 }
73
74 std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
75 void *req_data = inputs[0]->addr;
76 if (fbb == nullptr || req_data == nullptr) {
77 std::string reason = "FBBuilder builder or req_data is nullptr.";
78 MS_LOG(ERROR) << reason;
79 return false;
80 }
81
82 const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
83 size_t iter_client = IntToSize(get_secrets_req->iteration());
84 if (iter_num != iter_client) {
85 MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
86 << ". client request iteration is " << iter_client;
87 cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
88 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
89 return true;
90 }
91
92 if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
93 MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
94 }
95
96 response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
97 if (!response) {
98 MS_LOG(WARNING) << "get secret shares is failed.";
99 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
100 return true;
101 }
102 if (!CountForGetSecrets(fbb, get_secrets_req, SizeToInt(iter_num))) {
103 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
104 return true;
105 }
106 GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
107 clock_t end_time = clock();
108 double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
109 MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
110 return true;
111 }
112
Reset()113 bool GetSecretsKernel::Reset() {
114 MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
115 MS_LOG(INFO) << "GetSecretsKernel reset!";
116 cipher_share_->ClearShareSecrets();
117 DistributedCountService::GetInstance().ResetCounter(name_);
118 StopTimer();
119 return true;
120 }
121
122 REG_ROUND_KERNEL(getSecrets, GetSecretsKernel)
123 } // namespace kernel
124 } // namespace server
125 } // namespace fl
126 } // namespace mindspore
127