1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_MGR_H 18 #define MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_MGR_H 19 20 #include <iostream> 21 #include <functional> 22 #include <map> 23 #include <vector> 24 #include <string> 25 #include <memory> 26 #include <mutex> 27 #include <condition_variable> 28 #include "utils/callback_handler.h" 29 #include "include/backend/visible.h" 30 #include "include/backend/data_queue/data_queue.h" 31 #ifndef BUILD_LITE 32 #include "ir/anf.h" 33 #include "kernel/kernel.h" 34 #endif 35 36 namespace mindspore { 37 namespace device { 38 constexpr unsigned int MAX_WAIT_TIME_IN_SEC = 60; 39 class BlockingQueue; 40 41 // channel_name, dynamic_shape, capacity, addr, shape 42 using DataQueueCreator = 43 std::function<std::shared_ptr<DataQueue>(const std::string &, bool, size_t, const std::vector<size_t> &)>; 44 class Semaphore { 45 public: count_(count)46 explicit Semaphore(int count = 0) : count_(count) {} 47 Signal()48 inline void Signal() { 49 std::unique_lock<std::mutex> lock(mutex_); 50 ++count_; 51 cv_.notify_one(); 52 } 53 Wait()54 inline bool Wait() { 55 std::unique_lock<std::mutex> lock(mutex_); 56 while (count_ == 0) { 57 if (cv_.wait_for(lock, std::chrono::seconds(MAX_WAIT_TIME_IN_SEC)) == std::cv_status::timeout) { 58 return false; 59 } 60 } 61 --count_; 62 return true; 63 } 64 65 private: 66 std::mutex mutex_; 67 std::condition_variable cv_; 68 int count_; 69 }; 70 71 class BACKEND_EXPORT DataQueueMgr { 72 public: DataQueueMgr()73 DataQueueMgr() : init_(false), closed_(false), open_by_dataset_(0) {} 74 75 virtual ~DataQueueMgr() = default; 76 77 static DataQueueMgr &GetInstance() noexcept; 78 void RegisterDataQueueCreator(const std::string &device_name, DataQueueCreator &&creator); 79 void Clear(); 80 std::shared_ptr<DataQueue> CreateDataQueue(const std::string &device_name, const std::string &channel_name, 81 bool dynamic_shape, size_t capacity = 0, 82 const std::vector<size_t> &shape = {}); 83 84 DataQueueStatus Create(const std::string &channel_name, const std::vector<size_t> &shape, const size_t capacity); 85 86 // call for Push thread 87 DataQueueStatus Open(const std::string &channel_name, std::function<void(void *, int32_t)> func); 88 89 // call for Front/Pop thread 90 DataQueueStatus Open(const std::string &channel_name) const; 91 DataQueueStatus Push(const std::string &channel_name, const std::vector<DataQueueItem> &data, 92 unsigned int timeout_in_sec); 93 DataQueueStatus Front(const std::string &channel_name, std::vector<DataQueueItem> *data); 94 DataQueueStatus Pop(const std::string &channel_name); 95 DataQueueStatus FrontAsync(const std::string &channel_name, std::vector<DataQueueItem> *data); 96 void Free(const std::string &channel_name); 97 DataQueueStatus Clear(const std::string &channel_name); 98 void Release(); 99 DataQueueStatus CreateDynamicBufQueue(const std::string &channel_name, const size_t &capacity); 100 std::shared_ptr<BlockingQueue> GetDataQueue(const std::string &channel_name) const; 101 DataQueueStatus SetThreadDevice(const std::string &channel_name) const; 102 103 void Close(const std::string &channel_name) const noexcept; 104 105 bool IsInit() const; 106 107 bool IsClosed() const; 108 109 bool IsCreated(const std::string &channel_name) const; 110 111 bool Destroy(); 112 113 // call for Release GPU Resources 114 bool CloseNotify(); 115 116 // call for dataset send thread 117 void CloseConfirm(); 118 119 size_t Size(const std::string &channel_name); 120 121 size_t Capacity(const std::string &channel_name); 122 123 void Manage(const std::string &channel_name, const std::shared_ptr<BlockingQueue> &queue); 124 125 private: 126 DataQueueMgr(const DataQueueMgr &) = delete; 127 DataQueueMgr &operator=(const DataQueueMgr &) = delete; 128 129 bool init_; 130 bool closed_; 131 std::mutex close_mutex_; 132 std::condition_variable cv_; 133 // how many queues opened by dataset 134 int open_by_dataset_; 135 Semaphore sema; 136 bool dynamic_shape_{false}; 137 size_t default_capacity_{2}; 138 139 std::map<std::string, std::shared_ptr<BlockingQueue>> name_queue_map_; 140 // key: device name, value: DataQueueCreator 141 std::map<std::string, DataQueueCreator> data_queue_creator_map_ = {}; 142 143 HANDLER_DEFINE(bool, DestoryTdtHandle); 144 }; 145 #ifndef BUILD_LITE 146 BACKEND_EXPORT void UpdateGetNextNode(const AnfNodePtr &data_kernel); 147 148 BACKEND_EXPORT void UpdateGetNextNode(const PrimitivePtr &primitive, const std::vector<kernel::KernelTensor *> &inputs, 149 const std::vector<kernel::KernelTensor *> &outputs, 150 std::vector<size_t> *output_size_list); 151 152 BACKEND_EXPORT void UpdateGetNextWithDataQueueItems(const AnfNodePtr &data_kernel, 153 const std::vector<device::DataQueueItem> &data); 154 155 BACKEND_EXPORT void UpdateGetNextWithDataQueueItems(const std::vector<kernel::KernelTensor *> &inputs, 156 const std::vector<kernel::KernelTensor *> &outputs, 157 const std::vector<device::DataQueueItem> &data, 158 std::vector<size_t> *output_size_list); 159 160 BACKEND_EXPORT void RetryPeakItemFromDataQueue(const AnfNodePtr &data_kernel, 161 const std::shared_ptr<BlockingQueue> &data_queue, 162 std::vector<device::DataQueueItem> *data); 163 #endif 164 #define REGISTER_DATA_QUEUE_CREATOR(device_name, creator) \ 165 struct device_name##DataQueueCreatorClass { \ 166 device_name##DataQueueCreatorClass() { \ 167 DataQueueMgr::GetInstance().RegisterDataQueueCreator(device_name, creator); \ 168 } \ 169 } g_##device_name##_data_queue_creator; 170 } // namespace device 171 } // namespace mindspore 172 173 #endif // MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_MGR_H 174