1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DATA_QUEUE_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_DATA_QUEUE_H_ 19 20 #include <unistd.h> 21 #include <memory> 22 #include <vector> 23 #include <string> 24 #include <map> 25 #include <functional> 26 #include <queue> 27 #include "runtime/hardware/device_context_manager.h" 28 #include "include/backend/data_queue/data_queue.h" 29 #include "include/backend/data_queue/blocking_queue.h" 30 #include "acl/acl_tdt.h" 31 32 namespace mindspore { 33 namespace device { 34 class AscendDataQueueDynamic : public DataQueue { 35 public: 36 explicit AscendDataQueueDynamic(const std::string &channel_name, const size_t capacity); 37 ~AscendDataQueueDynamic() override = default; 38 39 DataQueueStatus Push(std::vector<DataQueueItem> data) override; 40 DataQueueStatus Front(std::vector<DataQueueItem> *data) const override; 41 DataQueueStatus Pop() override; 42 43 private: 44 struct NodeInfo { 45 std::vector<DataQueueItem> data_; 46 }; 47 aclrtStream stream_; 48 std::unique_ptr<NodeInfo[]> node_info_; 49 }; 50 51 namespace tdt_handle { 52 void AddHandle(acltdtChannelHandle **handle, std::thread *use_thread); 53 bool DestroyHandle(); 54 void DelHandle(acltdtChannelHandle **handle); 55 bool IsClosed(); 56 } // namespace tdt_handle 57 58 class WingmanQueue : public DataQueue { 59 public: WingmanQueue(const std::string & channel_name)60 explicit WingmanQueue(const std::string &channel_name) : DataQueue(channel_name, 0) {} 61 ~WingmanQueue() override = default; 62 void Close() override; 63 DataQueueStatus Push(std::vector<DataQueueItem> data) override; 64 DataQueueStatus Front(std::vector<DataQueueItem> *data) const override; 65 DataQueueStatus FrontAsync(std::vector<DataQueueItem> *data) const override; 66 DataQueueStatus Pop() override; IsEmpty()67 bool IsEmpty() const override { return queue_.empty(); } IsFull()68 bool IsFull() const override { return false; } Size()69 size_t Size() const override { return queue_.size(); } 70 71 private: 72 std::queue<std::vector<DataQueueItem>> queue_; 73 }; 74 75 class AscendTdtQueue : public DataQueue { 76 public: 77 explicit AscendTdtQueue(const std::string &channel_name); 78 ~AscendTdtQueue() override; 79 80 bool IsOpen() const override; 81 DataQueueStatus Push(std::vector<DataQueueItem> data) override; Front(std::vector<DataQueueItem> * data)82 DataQueueStatus Front(std::vector<DataQueueItem> *data) const override { return DataQueueStatus::SUCCESS; } Pop()83 DataQueueStatus Pop() override { return DataQueueStatus::SUCCESS; } 84 size_t QueryQueueSize() const override; QueueType()85 std::string QueueType() const override { return queue_type_; } 86 87 private: 88 void DestroyAclDataset(acltdtDataset *acl_dataset, bool include_data_item = true) const; 89 bool AssembleTensor2AclDataset(const std::vector<DataQueueItem> &data, acltdtDataset *acl_dataset) const; 90 void ParseType(aclDataType acl_data_type, std::string *data_type) const; 91 bool Translate(const std::vector<DataQueueItem> &data, acltdtDataset **output_acl_dataset) const; 92 93 acltdtChannelHandle *acl_handle_; 94 uint32_t device_id_; 95 std::string queue_type_; 96 }; 97 std::shared_ptr<BlockingQueue> GetTdtWingManQueue(const PrimitivePtr &prim); 98 std::shared_ptr<BlockingQueue> GetTdtWingManQueue(const std::shared_ptr<AnfNode> &node); 99 void CloseTdtWingManQueue(const PrimitivePtr &prim); 100 void CloseTdtWingManQueue(const std::shared_ptr<AnfNode> &node); 101 } // namespace device 102 } // namespace mindspore 103 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_BLOCKING_QUEUE_H_ 104