1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ 18 19 #include <atomic> 20 #include <deque> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 #include "minddata/dataset/engine/connector.h" 26 #include "minddata/dataset/engine/cache/cache_client.h" 27 #include "minddata/dataset/engine/datasetops/parallel_op.h" 28 #include "minddata/dataset/engine/datasetops/repeat_op.h" 29 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 30 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" 31 #include "minddata/dataset/util/queue.h" 32 #include "minddata/dataset/util/queue_map.h" 33 #include "minddata/dataset/util/semaphore.h" 34 #include "minddata/dataset/util/wait_post.h" 35 namespace mindspore { 36 namespace dataset { 37 /// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities. 38 /// \see CacheOp 39 /// \see CacheLookupOp 40 class CacheBase : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow> { 41 public: 42 /// \brief Base class constructor 43 /// \param num_workers Number of parallel workers 44 /// \param op_connector_size Connector size 45 /// \param cache_client CacheClient for communication to the CacheServer 46 /// \param sampler Sampler which is mandatory 47 CacheBase(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client, 48 std::shared_ptr<SamplerRT> sampler); 49 /// \brief Destructor 50 ~CacheBase(); 51 52 /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state 53 /// info from it's previous execution and then initializes itself so that it can be executed 54 /// again. 55 /// \return Status The status code returned 56 Status Reset() override; 57 58 /// \brief A print method typically used for debugging 59 /// \param out The output stream to write output to 60 /// \param show_all A bool to control if you want to show all info or just a summary 61 void Print(std::ostream &out, bool show_all) const override; 62 63 /// \brief Gives a name to the class, typically used for debugging Name()64 std::string Name() const override { return kCacheBase; } 65 66 /// \brief << Stream output operator overload 67 /// \notes This allows you to write the debug print info using stream operators 68 /// \param out reference to the output stream being overloaded 69 /// \param mo reference to the CacheOp to display 70 /// \return the output stream must be returned 71 friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) { 72 mo.Print(out, false); 73 return out; 74 } 75 76 /// \brief Getter for the cache client 77 /// \return shared ptr to the cache client GetCacheClient()78 std::shared_ptr<CacheClient> GetCacheClient() { return cache_client_; } 79 /// \brief Setter for the cache client SetCacheClient(std::shared_ptr<CacheClient> cache_client)80 void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); } 81 /// \brief Derived class must implement this method if a cache miss is treated as error 82 virtual bool AllowCacheMiss() = 0; 83 84 protected: 85 constexpr static int32_t eoe_row_id = -1; 86 int64_t row_cnt_; 87 std::atomic<int64_t> num_cache_miss_; 88 std::shared_ptr<CacheClient> cache_client_; 89 std::unique_ptr<Connector<std::vector<row_id_type>>> keys_miss_; 90 91 /// \brief Common function to register resources for interrupt 92 /// \note Derived should override this function for extra resources to be registered 93 virtual Status RegisterResources(); 94 /// \brief This function is called by main thread to send samples to the worker thread. 95 /// \note It is a non-virtual function 96 /// \return Status object 97 Status FetchSamplesToWorkers(); 98 /// \brief This function is called by each worker to fetch rows from the cache server for a given set of 99 /// sample row id's 100 /// \return Status object 101 Status FetchFromCache(int32_t worker_id); 102 /// \brief Get the column map from cache server 103 Status UpdateColumnMapFromCache(); 104 105 private: 106 constexpr static int32_t connector_capacity_ = 1024; 107 int32_t prefetch_size_; 108 int32_t num_prefetchers_; 109 QueueList<std::unique_ptr<IOBlock>> prefetch_queues_; 110 QueueMap<row_id_type, TensorRow> prefetch_; 111 112 /// \brief Prefetcher. It prefetch the rows from cache server 113 /// \return Status object. 114 Status Prefetcher(int32_t worker_id); 115 /// \brief Functions used by prefetcher and WorkerEntry 116 Status PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss); 117 Status GetPrefetchRow(row_id_type row_id, TensorRow *out); 118 }; 119 } // namespace dataset 120 } // namespace mindspore 121 122 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ 123