• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
18 
19 #include <atomic>
20 #include <deque>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "minddata/dataset/engine/connector.h"
26 #include "minddata/dataset/engine/cache/cache_client.h"
27 #include "minddata/dataset/engine/datasetops/parallel_op.h"
28 #include "minddata/dataset/engine/datasetops/repeat_op.h"
29 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
30 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
31 #include "minddata/dataset/util/queue.h"
32 #include "minddata/dataset/util/queue_map.h"
33 #include "minddata/dataset/util/semaphore.h"
34 #include "minddata/dataset/util/wait_post.h"
35 namespace mindspore {
36 namespace dataset {
37 /// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
38 /// \see CacheOp
39 /// \see CacheLookupOp
40 class CacheBase : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow> {
41  public:
42   /// \brief Base class constructor
43   /// \param num_workers Number of parallel workers
44   /// \param op_connector_size Connector size
45   /// \param cache_client CacheClient for communication to the CacheServer
46   /// \param sampler Sampler which is mandatory
47   CacheBase(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
48             std::shared_ptr<SamplerRT> sampler);
49   /// \brief Destructor
50   ~CacheBase();
51 
52   /// \brief Overrides base class reset method.  When an operator does a reset, it cleans up any state
53   /// info from it's previous execution and then initializes itself so that it can be executed
54   /// again.
55   /// \return Status The status code returned
56   Status Reset() override;
57 
58   /// \brief A print method typically used for debugging
59   /// \param out The output stream to write output to
60   /// \param show_all A bool to control if you want to show all info or just a summary
61   void Print(std::ostream &out, bool show_all) const override;
62 
63   /// \brief Gives a name to the class, typically used for debugging
Name()64   std::string Name() const override { return kCacheBase; }
65 
66   /// \brief << Stream output operator overload
67   /// \notes This allows you to write the debug print info using stream operators
68   /// \param out reference to the output stream being overloaded
69   /// \param mo reference to the CacheOp to display
70   /// \return the output stream must be returned
71   friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) {
72     mo.Print(out, false);
73     return out;
74   }
75 
76   /// \brief Getter for the cache client
77   /// \return shared ptr to the cache client
GetCacheClient()78   std::shared_ptr<CacheClient> GetCacheClient() { return cache_client_; }
79   /// \brief Setter for the cache client
SetCacheClient(std::shared_ptr<CacheClient> cache_client)80   void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
81   /// \brief Derived class must implement this method if a cache miss is treated as error
82   virtual bool AllowCacheMiss() = 0;
83 
84  protected:
85   constexpr static int32_t eoe_row_id = -1;
86   int64_t row_cnt_;
87   std::atomic<int64_t> num_cache_miss_;
88   std::shared_ptr<CacheClient> cache_client_;
89   std::unique_ptr<Connector<std::vector<row_id_type>>> keys_miss_;
90 
91   /// \brief Common function to register resources for interrupt
92   /// \note Derived should override this function for extra resources to be registered
93   virtual Status RegisterResources();
94   /// \brief This function is called by main thread to send samples to the worker thread.
95   /// \note It is a non-virtual function
96   /// \return Status object
97   Status FetchSamplesToWorkers();
98   /// \brief This function is called by each worker to fetch rows from the cache server for a given set of
99   /// sample row id's
100   /// \return Status object
101   Status FetchFromCache(int32_t worker_id);
102   /// \brief Get the column map from cache server
103   Status UpdateColumnMapFromCache();
104 
105  private:
106   constexpr static int32_t connector_capacity_ = 1024;
107   int32_t prefetch_size_;
108   int32_t num_prefetchers_;
109   QueueList<std::unique_ptr<IOBlock>> prefetch_queues_;
110   QueueMap<row_id_type, TensorRow> prefetch_;
111 
112   /// \brief Prefetcher. It prefetch the rows from cache server
113   /// \return Status object.
114   Status Prefetcher(int32_t worker_id);
115   /// \brief Functions used by prefetcher and WorkerEntry
116   Status PrefetchRows(const std::vector<row_id_type> &keys, std::vector<row_id_type> *cache_miss);
117   Status GetPrefetchRow(row_id_type row_id, TensorRow *out);
118 };
119 }  // namespace dataset
120 }  // namespace mindspore
121 
122 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
123