/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_

#include <algorithm>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/semaphore.h"

namespace mindspore {
namespace dataset {
/// \brief Provides a method to merge two streams (one from CacheLookupOp and one from the cache miss stream) into a
/// single stream.
class CacheMergeOp : public ParallelOp {
 public:
  // Handshake structure shared between CacheMissWorkerEntry and Cleaner
  class TensorRowCacheRequest {
   public:
    enum class State : uint8_t {
      kEmpty = 0,  // Initial state. Row hasn't arrived from the cache miss stream yet.
      kDirty = 1,  // Cleaner hasn't flushed it to the cache server yet.
      kClean = 2   // The row has been flushed already.
    };
    TensorRowCacheRequest() : st_(State::kEmpty) {}
    ~TensorRowCacheRequest() = default;
    /// Getter and setter of the state
    State GetState() const { return st_; }
    void SetState(State newState) { st_ = newState; }
    /// Take a tensor row and send an rpc call to the server asynchronously
    /// \param cc Cache client of the CacheMergeOp
    /// \param row TensorRow to be sent to the server
    /// \return Status object
    /// \note Thread safe
    Status AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc, const TensorRow &row);

    /// \brief We send the row to the server asynchronously so the CacheMissWorkerEntry can continue.
    /// It is the cleaner that will check the result.
    /// \return Status object
    Status CheckCacheResult();

   private:
    std::atomic<State> st_;
    std::shared_ptr<CacheRowRequest> cleaner_copy_;
  };

  constexpr static int kNumChildren = 2;        // CacheMergeOp has 2 children
  constexpr static int kCacheHitChildIdx = 0;   // Cache hit stream
  constexpr static int kCacheMissChildIdx = 1;  // Cache miss stream

  /// \brief Constructor
  /// \param numWorkers Number of parallel workers as a derived class of ParallelOp
  /// \param opConnectorSize Connector size as a derived class of ParallelOp
  /// \param numCleaners Number of cleaners to move cache miss rows into the cache server
  /// \param cache_client CacheClient to communicate with the cache server
  CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
               std::shared_ptr<CacheClient> cache_client);
  ~CacheMergeOp();
  void Print(std::ostream &out, bool show_all) const override;
  std::string Name() const override { return kCacheMergeOp; }

  friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) {
    mo.Print(out, false);
    return out;
  }

  /// \brief Master thread responsible for spawning all the necessary worker threads for the two streams and
  /// the threads for the cleaners.
  /// \return Status object
  Status operator()() override;

  /// \brief Entry function for worker threads that fetch rows from CacheLookupOp
  /// \param workerId
  /// \return Status object
  Status WorkerEntry(int32_t workerId) override;

  /// \brief Perform specific post-operations on CacheMergeOp
  /// \return Status The status code returned
  Status PrepareOperator() override;

  /// \brief Entry function for worker threads that fetch rows from the cache miss stream
  /// \param workerId
  /// \return Status object
  Status CacheMissWorkerEntry(int32_t workerId);

  /// \brief Base-class override for eoe handling
  /// \param worker_id
  /// \return Status object
  Status EoeReceived(int32_t worker_id) override;

  /// \brief Base-class override for handling cases when an eof is received.
  /// \param worker_id - The worker id
  /// \return Status The status code returned
  Status EofReceived(int32_t worker_id) override;

 protected:
  Status ComputeColMap() override;

 private:
  std::mutex mux_;
  QueueMap<row_id_type, TensorRow> cache_miss_;  // Rows from the cache miss stream, keyed by row id
  // Per-row cache requests shared between the cache miss workers and the cleaners
  std::map<row_id_type, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>> io_request_;
  std::unique_ptr<Queue<row_id_type>> io_que_;   // Row ids queued for the cleaners
  int32_t num_cleaners_;                         // Number of cleaner threads
  std::shared_ptr<CacheClient> cache_client_;    // Client used to communicate with the cache server
  std::atomic<bool> cache_missing_rows_;

  /// \brief Locate the cache request from the io_request_ map
  /// \param row_id
  /// \param out pointer to the cache request
  /// \return Status object
  Status GetRq(row_id_type row_id, TensorRowCacheRequest **out);

  /// \brief Entry function for the cleaner threads. Each cleaner is responsible for
  /// moving cache miss TensorRows into the CacheServer.
  /// \return Status object
  Status Cleaner();
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_