/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_

#include <algorithm>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/semaphore.h"

namespace mindspore {
namespace dataset {
/// \brief Provides a method to merge two streams (one from the CacheLookupOp and one from the cache miss stream)
/// into a single stream.
class CacheMergeOp : public ParallelOp {
 public:
  // A handshake structure between CacheMissWorkerEntry and the Cleaner threads
  class TensorRowCacheRequest {
   public:
    enum class State : uint8_t {
      kEmpty = 0,  // Initial state. Row hasn't arrived from cache miss stream yet.
      kDirty = 1,  // Cleaner hasn't flushed it to the cache server yet.
      kClean = 2   // The row has been flushed already.
    };
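    // Sketch of the intended state lifecycle, inferred from the comments above rather
    // than stated explicitly in this header: kEmpty -> kDirty when the cache miss worker
    // fires AsyncSendCacheRequest for the row, then kDirty -> kClean once a cleaner has
    // confirmed via CheckCacheResult that the row reached the cache server.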
    TensorRowCacheRequest() : st_(State::kEmpty) {}
    ~TensorRowCacheRequest() = default;
    /// Getter and Setter of the state
    State GetState() const { return st_; }
    void SetState(State newState) { st_ = newState; }
    /// Take a tensor row and send an rpc call to the server asynchronously
    /// \param cc Cache client of the CacheMergeOp
    /// \param row TensorRow to be sent to the server
    /// \return Status object
    /// \note Thread safe
    Status AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc, const TensorRow &row);

    /// \brief The row is sent to the server asynchronously so the CacheMissWorkerEntry can continue.
    /// It is the cleaner that will check the result.
    /// \return Status object
    Status CheckCacheResult();

   private:
    std::atomic<State> st_;
    std::shared_ptr<CacheRowRequest> cleaner_copy_;
  };
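  // Illustrative sketch (not part of this header) of how a cache miss worker and a
  // cleaner are expected to hand a row off through TensorRowCacheRequest; the exact
  // control flow lives in the implementation file and may differ:
  //
  //   // In CacheMissWorkerEntry: fire the write asynchronously and queue the id for a cleaner.
  //   TensorRowCacheRequest *rq = nullptr;
  //   RETURN_IF_NOT_OK(GetRq(row_id, &rq));
  //   RETURN_IF_NOT_OK(rq->AsyncSendCacheRequest(cache_client_, row));
  //   RETURN_IF_NOT_OK(io_que_->Add(row_id));
  //
  //   // In Cleaner: pop a row id from io_que_ and confirm the server accepted the row.
  //   RETURN_IF_NOT_OK(rq->CheckCacheResult());  // kDirty -> kClean on success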

  constexpr static int kNumChildren = 2;        // CacheMergeOp has 2 children
  constexpr static int kCacheHitChildIdx = 0;   // Cache hit stream
  constexpr static int kCacheMissChildIdx = 1;  // Cache miss stream

  /// \brief Constructor
  /// \param numWorkers Number of parallel workers as a derived class of ParallelOp
  /// \param opConnectorSize Connector size as a derived class of ParallelOp
  /// \param numCleaners Number of cleaners to move cache miss rows into the cache server
  /// \param cache_client CacheClient to communicate with the Cache server
  CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
               std::shared_ptr<CacheClient> cache_client);
  ~CacheMergeOp();
  void Print(std::ostream &out, bool show_all) const override;
  std::string Name() const override { return kCacheMergeOp; }

  friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) {
    mo.Print(out, false);
    return out;
  }
  /// \brief Master thread responsible for spawning all the necessary worker threads for the two streams and
  ///     the threads for the cleaners.
  /// \return Status object
  Status operator()() override;

  /// \brief Entry function for a worker thread that fetches rows from the CacheLookupOp
  /// \param workerId
  /// \return Status object
  Status WorkerEntry(int32_t workerId) override;

  /// \brief Perform specific post-operations on CacheMergeOp
  /// \return Status The status code returned
  Status PrepareOperator() override;

  /// \brief Entry function for a worker thread that fetches rows from the cache miss stream
  /// \param workerId
  /// \return Status object
  Status CacheMissWorkerEntry(int32_t workerId);

  /// \brief Base-class override for eoe handling
  /// \param worker_id
  /// \return Status object
  Status EoeReceived(int32_t worker_id) override;

  /// \brief Base-class override for handling cases when an eof is received.
  /// \param worker_id - The worker id
  /// \return Status The status code returned
  Status EofReceived(int32_t worker_id) override;

 protected:
  Status ComputeColMap() override;

 private:
  std::mutex mux_;
  QueueMap<row_id_type, TensorRow> cache_miss_;
  std::map<row_id_type, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>> io_request_;
  std::unique_ptr<Queue<row_id_type>> io_que_;
  int32_t num_cleaners_;
  std::shared_ptr<CacheClient> cache_client_;
  std::atomic<bool> cache_missing_rows_;

  /// \brief Locate the cache request in the io_request_ map
  /// \param row_id
  /// \param out pointer to the cache request
  /// \return Status object
  Status GetRq(row_id_type row_id, TensorRowCacheRequest **out);

  /// \brief Entry function for the cleaner threads. Each cleaner is responsible for
  /// moving cache miss TensorRows into the CacheServer.
  /// \return Status object
  Status Cleaner();
};
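
// A minimal construction sketch, assuming a CacheClient has already been built elsewhere
// (the worker/cleaner counts here are illustrative, not recommended values):
//
//   std::shared_ptr<CacheClient> cache_client = /* obtained via CacheClient::Builder */;
//   auto merge_op = std::make_shared<CacheMergeOp>(/*numWorkers=*/4, /*opConnectorSize=*/16,
//                                                  /*numCleaners=*/2, cache_client);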
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_