• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
18 #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 #include <mutex>
25 #include <queue>
26 #include <utility>
27 #include <chrono>
28 #include <thread>
29 #include <unordered_map>
30 #include "backend/kernel_compiler/common_utils.h"
31 #include "backend/kernel_compiler/cpu/cpu_kernel.h"
32 #include "fl/server/common.h"
33 #include "fl/server/local_meta_store.h"
34 #include "fl/server/distributed_count_service.h"
35 #include "fl/server/distributed_metadata_store.h"
36 
37 namespace mindspore {
38 namespace fl {
39 namespace server {
40 namespace kernel {
// NOTE(review): presumably the polling interval used when releasing response data asynchronously; neither the unit
// nor the use site is visible in this header — confirm against the implementation file.
constexpr uint64_t kReleaseDuration{100};
42 // RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round
43 // kernels to represent the process. They receive and parse messages from the server communication module. After
44 // handling these messages, round kernels allocate response data and send it back.
45 
46 // For example, the main process of federated learning is:
47 // startFLJob round->updateModel round->getModel round.
48 class RoundKernel : virtual public CPUKernel {
49  public:
50   RoundKernel();
51   virtual ~RoundKernel();
52 
53   // RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this
54   // inherited method is empty.
InitKernel(const CNodePtr & kernel_node)55   void InitKernel(const CNodePtr &kernel_node) override {}
56 
57   // Initialize RoundKernel with threshold_count which means that for every iteration, this round needs threshold_count
58   // messages.
59   virtual void InitKernel(size_t threshold_count) = 0;
60 
61   // Launch the round kernel logic to handle the message passed by the communication module.
62   virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
63                       const std::vector<AddressPtr> &outputs) = 0;
64 
65   // Some rounds could be stateful in a iteration. Reset method resets the status of this round.
66   virtual bool Reset() = 0;
67 
68   // The counter event handlers for DistributedCountService.
69   // The callbacks when first message and last message for this round kernel is received.
70   // These methods is called by class DistributedCountService and triggered by counting server.
71   virtual void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
72   virtual void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
73 
74   // Called when this round is finished. This round timer's Stop method will be called.
75   void StopTimer() const;
76 
77   // Called after this iteration(including all rounds) is finished. All rounds' Reset method will
78   // be called.
79   void FinishIteration() const;
80 
81   // Release the response data allocated inside the round kernel.
82   // Server framework must call this after the response data is sent back.
83   void Release(const AddressPtr &addr_ptr);
84 
85   // Set round kernel name, which could be used in round kernel's methods.
86   void set_name(const std::string &name);
87 
88   // Set callbacks to be called under certain triggered conditions.
89   void set_stop_timer_cb(const StopTimerCb &timer_stopper);
90   void set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb);
91 
92  protected:
93   // Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent
94   // back to worker.
95   void GenerateOutput(const std::vector<AddressPtr> &outputs, const void *data, size_t len);
96 
97   // Round kernel's name.
98   std::string name_;
99 
100   // The current received message count for this round in this iteration.
101   size_t current_count_;
102 
103   // The required received message count for this round in one iteration.
104   size_t required_count_;
105 
106   // The reason causes the error in this round kernel.
107   std::string error_reason_;
108 
109   StopTimerCb stop_timer_cb_;
110   FinishIterCb finish_iteration_cb_;
111 
112   // Members below are used for allocating and releasing response data on the heap.
113 
114   // To ensure the performance, we use another thread to release data on the heap. So the operation on the data should
115   // be threadsafe.
116   std::atomic_bool running_;
117   std::thread release_thread_;
118 
119   // Data needs to be released and its mutex;
120   std::mutex release_mtx_;
121   std::queue<AddressPtr> heap_data_to_release_;
122   std::mutex heap_data_mtx_;
123   std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
124 };
125 }  // namespace kernel
126 }  // namespace server
127 }  // namespace fl
128 }  // namespace mindspore
129 #endif  // MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
130