/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_

#include <atomic>
#include <chrono>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "fl/server/common.h"
#include "fl/server/local_meta_store.h"
#include "fl/server/distributed_count_service.h"
#include "fl/server/distributed_metadata_store.h"

namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
constexpr uint64_t kReleaseDuration = 100;
// RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round
// kernels to represent the process. They receive and parse messages from the server communication module. After
// handling these messages, round kernels allocate response data and send it back.

// For example, the main process of federated learning is:
// startFLJob round->updateModel round->getModel round.
48 class RoundKernel : virtual public CPUKernel { 49 public: 50 RoundKernel(); 51 virtual ~RoundKernel(); 52 53 // RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this 54 // inherited method is empty. InitKernel(const CNodePtr & kernel_node)55 void InitKernel(const CNodePtr &kernel_node) override {} 56 57 // Initialize RoundKernel with threshold_count which means that for every iteration, this round needs threshold_count 58 // messages. 59 virtual void InitKernel(size_t threshold_count) = 0; 60 61 // Launch the round kernel logic to handle the message passed by the communication module. 62 virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, 63 const std::vector<AddressPtr> &outputs) = 0; 64 65 // Some rounds could be stateful in a iteration. Reset method resets the status of this round. 66 virtual bool Reset() = 0; 67 68 // The counter event handlers for DistributedCountService. 69 // The callbacks when first message and last message for this round kernel is received. 70 // These methods is called by class DistributedCountService and triggered by counting server. 71 virtual void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message); 72 virtual void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message); 73 74 // Called when this round is finished. This round timer's Stop method will be called. 75 void StopTimer() const; 76 77 // Called after this iteration(including all rounds) is finished. All rounds' Reset method will 78 // be called. 79 void FinishIteration() const; 80 81 // Release the response data allocated inside the round kernel. 82 // Server framework must call this after the response data is sent back. 83 void Release(const AddressPtr &addr_ptr); 84 85 // Set round kernel name, which could be used in round kernel's methods. 
86 void set_name(const std::string &name); 87 88 // Set callbacks to be called under certain triggered conditions. 89 void set_stop_timer_cb(const StopTimerCb &timer_stopper); 90 void set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb); 91 92 protected: 93 // Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent 94 // back to worker. 95 void GenerateOutput(const std::vector<AddressPtr> &outputs, const void *data, size_t len); 96 97 // Round kernel's name. 98 std::string name_; 99 100 // The current received message count for this round in this iteration. 101 size_t current_count_; 102 103 // The required received message count for this round in one iteration. 104 size_t required_count_; 105 106 // The reason causes the error in this round kernel. 107 std::string error_reason_; 108 109 StopTimerCb stop_timer_cb_; 110 FinishIterCb finish_iteration_cb_; 111 112 // Members below are used for allocating and releasing response data on the heap. 113 114 // To ensure the performance, we use another thread to release data on the heap. So the operation on the data should 115 // be threadsafe. 116 std::atomic_bool running_; 117 std::thread release_thread_; 118 119 // Data needs to be released and its mutex; 120 std::mutex release_mtx_; 121 std::queue<AddressPtr> heap_data_to_release_; 122 std::mutex heap_data_mtx_; 123 std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_; 124 }; 125 } // namespace kernel 126 } // namespace server 127 } // namespace fl 128 } // namespace mindspore 129 #endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ 130