• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "fl/server/kernel/round/get_secrets_kernel.h"
#include <chrono>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "fl/armour/cipher/cipher_shares.h"
24 
25 namespace mindspore {
26 namespace fl {
27 namespace server {
28 namespace kernel {
InitKernel(size_t)29 void GetSecretsKernel::InitKernel(size_t) {
30   if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
31     iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
32   }
33 
34   executor_ = &Executor::GetInstance();
35   MS_EXCEPTION_IF_NULL(executor_);
36   if (!executor_->initialized()) {
37     MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
38     return;
39   }
40 
41   cipher_share_ = &armour::CipherShares::GetInstance();
42 }
43 
CountForGetSecrets(const std::shared_ptr<FBBuilder> & fbb,const schema::GetShareSecrets * get_secrets_req,const int iter_num)44 bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
45                                           const schema::GetShareSecrets *get_secrets_req, const int iter_num) {
46   MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false);
47   if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) {
48     std::string reason = "Counting for get secrets kernel request failed. Please retry later.";
49     cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, IntToSize(iter_num),
50                                       std::to_string(CURRENT_TIME_MILLI.count()), nullptr);
51     MS_LOG(ERROR) << reason;
52     return false;
53   }
54   return true;
55 }
56 
Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> &,const std::vector<AddressPtr> & outputs)57 bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
58                               const std::vector<AddressPtr> &outputs) {
59   bool response = false;
60   size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
61   MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
62   std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
63   size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
64   MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
65                << total_duration;
66   clock_t start_time = clock();
67 
68   if (inputs.size() != 1 || outputs.size() != 1) {
69     std::string reason = "inputs or outputs size is invalid.";
70     MS_LOG(ERROR) << reason;
71     return false;
72   }
73 
74   std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
75   void *req_data = inputs[0]->addr;
76   if (fbb == nullptr || req_data == nullptr) {
77     std::string reason = "FBBuilder builder or req_data is nullptr.";
78     MS_LOG(ERROR) << reason;
79     return false;
80   }
81 
82   const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
83   size_t iter_client = IntToSize(get_secrets_req->iteration());
84   if (iter_num != iter_client) {
85     MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
86                   << ". client request iteration is " << iter_client;
87     cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
88     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
89     return true;
90   }
91 
92   if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
93     MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
94   }
95 
96   response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
97   if (!response) {
98     MS_LOG(WARNING) << "get secret shares is failed.";
99     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
100     return true;
101   }
102   if (!CountForGetSecrets(fbb, get_secrets_req, SizeToInt(iter_num))) {
103     GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
104     return true;
105   }
106   GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
107   clock_t end_time = clock();
108   double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
109   MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
110   return true;
111 }
112 
Reset()113 bool GetSecretsKernel::Reset() {
114   MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
115   MS_LOG(INFO) << "GetSecretsKernel reset!";
116   cipher_share_->ClearShareSecrets();
117   DistributedCountService::GetInstance().ResetCounter(name_);
118   StopTimer();
119   return true;
120 }
121 
// Register GetSecretsKernel through REG_ROUND_KERNEL under the `getSecrets` round name.
REG_ROUND_KERNEL(getSecrets, GetSecretsKernel)
123 }  // namespace kernel
124 }  // namespace server
125 }  // namespace fl
126 }  // namespace mindspore
127