1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DB_CONNECTOR_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DB_CONNECTOR_H_ 18 19 #include <memory> 20 #include <utility> 21 #include "minddata/dataset/core/tensor_row.h" 22 #include "minddata/dataset/engine/connector.h" 23 24 #include "minddata/dataset/include/dataset/constants.h" 25 26 namespace mindspore { 27 namespace dataset { 28 // DbConnector is a derived class from Connector with added logic to handle EOE and EOF. 29 // The Connector class itself is responsible to ensure deterministic order on every run. 30 class DbConnector : public Connector<TensorRow> { 31 public: 32 // Constructor of DbConnector 33 // @note DbConnector will create internal N number of blocking queues, where N = nProducers. 34 // See Connector.h for more details. 35 // @param n_producers The number of threads producing data into this DbConnector. 36 // @param n_consumers The number of thread consuming data from this DbConnector. 37 // @param queue_capacity The number of element (TensorRows) for each internal queue. DbConnector(int32_t n_producers,int32_t n_consumers,int32_t queue_capacity)38 DbConnector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity) 39 : Connector<TensorRow>(n_producers, n_consumers, queue_capacity), end_of_file_(false) {} 40 41 // Destructor of DbConnector 42 ~DbConnector() = default; 43 44 // Add a TensorRow into the DbConnector. 45 // @note The caller of this add method should use std::move to pass the ownership to DbConnector. 46 // @param worker_id The id of a worker thread calling this method. 47 // @param el A rvalue reference to an element to be passed/added/pushed. 48 Status Add(TensorRow &&el, int32_t worker_id = 0) noexcept { 49 return (Connector<TensorRow>::Push(worker_id, std::move(el))); 50 } 51 52 Status SendEOE(int32_t worker_id = 0) noexcept { 53 TensorRow eoe = TensorRow(TensorRow::kFlagEOE); 54 return Add(std::move(eoe), worker_id); 55 } 56 57 Status SendEOF(int32_t worker_id = 0) noexcept { 58 TensorRow eof = TensorRow(TensorRow::kFlagEOF); 59 return Add(std::move(eof), worker_id); 60 } 61 // Get a TensorRow from the DbConnector. 62 // @note After the first EOF row is encountered, subsequent pop()s will return EOF row. 63 // This will provide/propagate the EOF to all consumer threads of this Connector. 64 // Thus, When the num_consumers < num_producers, there will be extra EOF messages in some of the internal queues 65 // and reset() must be called before reusing DbConnector. 66 // @param worker_id The id of a worker thread calling this method. 67 // @param result The address of a TensorRow where the popped element will be placed. 68 // @param retry_if_eoe A flag to allow the same thread invoke pop() again if the current pop returns eoe buffer. 69 Status PopWithRetry(int32_t worker_id, TensorRow *result, bool retry_if_eoe = false) noexcept { 70 if (result == nullptr) { 71 return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, 72 "[ERROR] nullptr detected when getting data from db connector"); 73 } else { 74 std::unique_lock<std::mutex> lk(m_); 75 RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return (expect_consumer_ == worker_id) || end_of_file_; })); 76 // Once an EOF message is encountered this flag will be set and we can return early. 77 if (end_of_file_) { 78 *result = TensorRow(TensorRow::kFlagEOF); 79 } else { 80 RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); 81 // Setting the internal flag once the first EOF is encountered. 82 if (result->eof()) { 83 end_of_file_ = true; 84 } 85 pop_from_ = (pop_from_ + 1) % num_producers_; 86 } 87 // Do not increment expect_consumer_ when result is eoe and retry_if_eoe is set. 88 if (!(result->eoe() && retry_if_eoe)) { 89 expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; 90 } 91 } 92 out_buffers_count_++; 93 cv_.NotifyAll(); 94 return Status::OK(); 95 } 96 97 private: 98 // A flag to indicate the end of stream has been encountered. 99 bool end_of_file_; 100 }; 101 } // namespace dataset 102 } // namespace mindspore 103 104 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DB_CONNECTOR_H_ 105