1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_H 18 #define MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_H 19 20 #include <string> 21 #include <memory> 22 #include <vector> 23 #include <functional> 24 #include "include/backend/visible.h" 25 26 namespace mindspore { 27 namespace device { 28 class DeviceContext; 29 30 enum class DataQueueStatus : int { SUCCESS = 0, QUEUE_EXIST, QUEUE_NOT_EXIST, ERROR_INPUT, INTERNAL_ERROR, TIMEOUT }; 31 32 struct DataQueueItem { 33 int32_t worker_id{0}; 34 std::string data_type; 35 size_t data_len{0}; 36 void *data_ptr{nullptr}; 37 std::vector<int64_t> shapes; 38 void *device_addr{nullptr}; 39 // add tensor type when tdt need more types than data and end-of-sequence 40 }; 41 42 class BACKEND_EXPORT DataQueue { 43 public: 44 explicit DataQueue(const std::string &channel_name, const size_t capacity); 45 virtual ~DataQueue() = default; 46 RegisterRelease(const std::function<void (void *,int32_t)> & func)47 virtual void RegisterRelease(const std::function<void(void *, int32_t)> &func) { host_release_ = func; } IsOpen()48 virtual bool IsOpen() const { return !closed_; } Close()49 virtual void Close() { closed_ = true; } IsEmpty()50 virtual bool IsEmpty() const { return size_ == 0; } IsFull()51 virtual bool IsFull() const { return size_ == capacity_; } FrontAsync(std::vector<DataQueueItem> * data)52 virtual DataQueueStatus FrontAsync(std::vector<DataQueueItem> *data) const { return DataQueueStatus::SUCCESS; } 53 virtual DataQueueStatus Push(std::vector<DataQueueItem> data) = 0; 54 virtual DataQueueStatus Front(std::vector<DataQueueItem> *data) const = 0; 55 virtual DataQueueStatus Pop() = 0; SetThreadDevice()56 virtual void SetThreadDevice() {} Size()57 virtual size_t Size() const { return size_; } Capacity()58 virtual size_t Capacity() const { return capacity_; } QueryQueueSize()59 virtual size_t QueryQueueSize() const { return 0; } QueueType()60 virtual std::string QueueType() const { return "Unknown"; } 61 62 protected: 63 const std::string channel_name_; 64 size_t head_; 65 size_t tail_; 66 size_t size_; 67 size_t capacity_; 68 bool closed_{false}; 69 std::function<void(void *, int32_t)> host_release_; 70 DeviceContext *device_context_; 71 72 private: 73 DataQueue(const DataQueue &) = delete; 74 DataQueue &operator=(const DataQueue &) = delete; 75 }; 76 } // namespace device 77 } // namespace mindspore 78 #endif // MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_H 79