1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GPU_ITEM_CONNECTOR_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GPU_ITEM_CONNECTOR_H_ 18 19 #include <memory> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 #include "minddata/dataset/engine/connector.h" 24 #include "minddata/dataset/util/status.h" 25 #include "minddata/dataset/include/dataset/constants.h" 26 #include "include/backend/data_queue/blocking_queue.h" 27 28 using mindspore::device::DataQueueItem; 29 30 namespace mindspore { 31 namespace dataset { 32 33 struct GpuConnectorItem { 34 std::vector<device::DataQueueItem> data_item; 35 bool eoe_flag; // flag to indicate an EOE item in the connector 36 }; 37 38 class GpuConnector : public Connector<GpuConnectorItem> { 39 public: GpuConnector(int32_t num_producers,int32_t num_consumers,int32_t queue_capacity)40 GpuConnector(int32_t num_producers, int32_t num_consumers, int32_t queue_capacity) 41 : Connector<GpuConnectorItem>(num_producers, num_consumers, queue_capacity) { 42 for (int i = 0; i < num_producers; i++) { 43 is_queue_finished_.push_back(false); 44 } 45 } 46 47 ~GpuConnector() = default; 48 Add(int32_t worker_d,GpuConnectorItem && element)49 Status Add(int32_t worker_d, GpuConnectorItem &&element) noexcept { 50 return Connector<GpuConnectorItem>::Push(worker_d, std::move(element)); 51 } 52 Pop(int32_t worker_id,GpuConnectorItem * result)53 Status Pop(int32_t worker_id, GpuConnectorItem *result) noexcept override { 54 RETURN_UNEXPECTED_IF_NULL(result); 55 { 56 MS_ASSERT(worker_id < num_consumers_); 57 std::unique_lock<std::mutex> lock(m_); 58 RETURN_IF_NOT_OK(cv_.Wait(&lock, [this, worker_id]() { return expect_consumer_ == worker_id; })); 59 if (is_queue_finished_[pop_from_]) { 60 std::string errMsg = "ERROR: popping from a finished queue in GpuConnector"; 61 RETURN_STATUS_UNEXPECTED(errMsg); 62 } 63 64 RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); 65 // empty data_item and eoe_flag=false is EOF 66 if ((*result).data_item.empty() && !(*result).eoe_flag) { 67 is_queue_finished_[pop_from_] = true; 68 } 69 70 for (int offset = 1; offset <= num_producers_; offset++) { 71 int32_t nextQueueIndex = (pop_from_ + offset) % num_producers_; 72 if (is_queue_finished_[nextQueueIndex] == false) { 73 pop_from_ = nextQueueIndex; 74 break; 75 } 76 } 77 78 expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; 79 } 80 81 cv_.NotifyAll(); 82 return Status::OK(); 83 } 84 85 private: 86 std::vector<bool> is_queue_finished_; 87 }; 88 } // namespace dataset 89 } // namespace mindspore 90 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GPU_ITEM_CONNECTOR_H_ 91