1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_
18 
19 #include <atomic>
20 #include <iostream>
21 #include <limits>
22 #include <memory>
23 #include <map>
24 #include <mutex>
25 #include <set>
26 #include <string>
27 #include <unordered_map>
28 #include <utility>
29 #include <vector>
30 
31 #include "minddata/dataset/core/config_manager.h"
32 #ifdef ENABLE_CACHE
33 #include "minddata/dataset/engine/cache/cache_grpc_client.h"
34 #else
35 #include "minddata/dataset/engine/cache/stub/cache_grpc_client.h"
36 #endif
37 
38 #include "minddata/dataset/util/lock.h"
39 #include "minddata/dataset/util/cond_var.h"
40 #include "minddata/dataset/util/queue_map.h"
41 #include "minddata/dataset/util/task_manager.h"
42 #include "minddata/dataset/util/wait_post.h"
43 
44 namespace mindspore {
45 namespace dataset {
46 /// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through
47 /// a CacheClient. Typical tasks including like creating a cache service, cache a data buffer, restore a previously
48 /// rows, etc.
49 class CacheClient {
50  public:
51   friend class CacheMergeOp;
52   friend class CreateCacheRequest;
53   friend class CacheRowRequest;
54   friend class BatchFetchRequest;
55   friend class BatchCacheRowsRequest;
56 
57   /// \brief A builder to help creating a CacheClient object
58   class Builder {
59    public:
60     Builder();
61 
62     ~Builder() = default;
63 
64     /// Setter function to set the session id
65     /// \param session_id
66     /// \return Builder object itself.
SetSessionId(session_id_type session_id)67     Builder &SetSessionId(session_id_type session_id) {
68       session_id_ = session_id;
69       return *this;
70     }
71 
72     /// Setter function to set the cache memory size
73     /// \param cache_mem_sz
74     /// \return Builder object itself
SetCacheMemSz(uint64_t cache_mem_sz)75     Builder &SetCacheMemSz(uint64_t cache_mem_sz) {
76       cache_mem_sz_ = cache_mem_sz;
77       return *this;
78     }
79 
80     /// Setter function to spill attribute
81     /// \param spill
82     /// Builder object itself
SetSpill(bool spill)83     Builder &SetSpill(bool spill) {
84       spill_ = spill;
85       return *this;
86     }
87 
88     /// Setter function to set rpc hostname
89     /// \param host
90     /// \return Builder object itself
SetHostname(std::string host)91     Builder &SetHostname(std::string host) {
92       hostname_ = std::move(host);
93       return *this;
94     }
95 
96     /// Setter function to set tcpip port
97     /// \param port
98     /// \return Builder object itself.
SetPort(int32_t port)99     Builder &SetPort(int32_t port) {
100       port_ = port;
101       return *this;
102     }
103 
104     /// Setter function to set number of async rpc workers
105     /// \param num_connections
106     /// \return Builder object itself
SetNumConnections(int32_t num_connections)107     Builder &SetNumConnections(int32_t num_connections) {
108       num_connections_ = num_connections;
109       return *this;
110     }
111 
112     /// Setter function to set prefetch amount for fetching rows from cache server
113     /// \param prefetch_sz
114     /// \return Builder object itself
SetPrefetchSize(int32_t prefetch_sz)115     Builder &SetPrefetchSize(int32_t prefetch_sz) {
116       prefetch_size_ = prefetch_sz;
117       return *this;
118     }
119 
120     /// Getter functions
GetSessionId()121     session_id_type GetSessionId() const { return session_id_; }
GetCacheMemSz()122     uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
isSpill()123     bool isSpill() const { return spill_; }
GetHostname()124     const std::string &GetHostname() const { return hostname_; }
GetPort()125     int32_t GetPort() const { return port_; }
GetNumConnections()126     int32_t GetNumConnections() const { return num_connections_; }
GetPrefetchSize()127     int32_t GetPrefetchSize() const { return prefetch_size_; }
128 
129     Status SanityCheck();
130 
131     Status Build(std::shared_ptr<CacheClient> *out);
132 
133    private:
134     session_id_type session_id_;
135     uint64_t cache_mem_sz_;
136     bool spill_;
137     std::string hostname_;
138     int32_t port_;
139     int32_t num_connections_;
140     int32_t prefetch_size_;
141   };
142 
143   /// \brief Constructor
144   /// \param session_id A user assigned session id for the current pipeline
145   /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
146   /// \param spill Spill to disk if out of memory
147   CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
148               int32_t num_connections, int32_t prefetch_size);
149 
150   /// \brief Destructor
151   ~CacheClient();
152 
153   /// \brief Send a TensorRow to the cache server
154   /// \param[in] row
155   /// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset
156   /// \return return code
157   Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const;
158 
159   /// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is
160   /// any cache miss
161   /// \param row_id A vector of row id's
162   /// \param out A TensorTable of TensorRows.
163   /// \return return code
164   Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const;
165 
166   /// \brief Create a cache.
167   /// \param tree_crc  A crc that was generated during tree prepare phase
168   /// \param generate_id Let the cache service generate row id
169   /// \return Status object
170   Status CreateCache(uint32_t tree_crc, bool generate_id);
171 
172   /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
173   /// \return Status object
174   Status DestroyCache();
175 
176   /// \brief Get the statistics from a cache.
177   /// \param[in/out] Pointer to a pre-allocated ServiceStat object
178   /// \return Status object
179   Status GetStat(CacheServiceStat *);
180 
181   /// \brief Get the state of a cache server
182   /// \param[in/out] Pointer to a int8_t
183   /// \return Status object
184   Status GetState(int8_t *);
185 
186   /// \brief Cache the schema at the cache server
187   /// \param map The unordered map of the schema
188   /// \return Status object
189   Status CacheSchema(const std::unordered_map<std::string, int32_t> &map);
190 
191   /// \brief Fetch the schema from the cache server
192   /// \param map Pointer to pre-allocated map object
193   /// \return Status object.
194   Status FetchSchema(std::unordered_map<std::string, int32_t> *map);
195 
196   /// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache
197   /// client that holds cookie can be allowed to make this request
198   /// \return Status object
199   Status BuildPhaseDone() const;
200 
201   /// \brief A print method typically used for debugging
202   /// \param out The output stream to write output to
203   void Print(std::ostream &out) const;
204 
205   /// \brief Stream output operator overload
206   /// \return the output stream must be returned
207   friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) {
208     cc.Print(out);
209     return out;
210   }
211 
212   /// \brief Every cache server has a cookie which uniquely identifies the CacheClient that creates it.
213   /// \return Cookie
cookie()214   std::string cookie() const { return cookie_; }
215 
216   /// \brief Send a request async to the server
217   /// \param rq BaseRequest
218   /// \return Status object
219   Status PushRequest(std::shared_ptr<BaseRequest> rq) const;
220 
221   /// \brief If the remote server supports local bypass using shared memory
222   /// \return boolean value
SupportLocalClient()223   bool SupportLocalClient() const { return local_bypass_; }
224 
225   /// \brief Return the base memory address if we attach to any shared memory.
SharedMemoryBaseAddr()226   auto SharedMemoryBaseAddr() const { return comm_->SharedMemoryBaseAddr(); }
227 
228   /// Getter functions
session_id()229   session_id_type session_id() const { return cinfo_.session_id(); }
GetCacheMemSz()230   uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
isSpill()231   bool isSpill() const { return spill_; }
GetNumConnections()232   int32_t GetNumConnections() const { return num_connections_; }
GetPrefetchSize()233   int32_t GetPrefetchSize() const { return prefetch_size_; }
GetClientId()234   int32_t GetClientId() const { return client_id_; }
235   std::string GetHostname() const;
236   int32_t GetPort() const;
237 
238   /// MergeOp will notify us when the server can't cache any more rows.
239   /// We will stop any attempt to fetch any rows that are most likely
240   /// not present at the server.
241   void ServerRunningOutOfResources();
242 
243   /// \brief Check if a row is 100% cache miss at the server by checking the local information
244   /// \param key row id to be test
245   /// \return true if not at the server
KeyIsCacheMiss(row_id_type key)246   bool KeyIsCacheMiss(row_id_type key) {
247     if (cache_miss_keys_) {
248       // Make sure it is fully built even though the pointer is not null
249       Status rc = cache_miss_keys_wp_.Wait();
250       if (rc.IsOk()) {
251         return cache_miss_keys_->KeyIsCacheMiss(key);
252       }
253     }
254     return false;
255   }
256 
257   /// \brief Serialize a Tensor into the async buffer.
258   Status AsyncWriteRow(const TensorRow &row);
259 
260   // Default size of the async write buffer
261   constexpr static int64_t kAsyncBufferSize = 16 * 1048576L;  // 16M
262   constexpr static int32_t kNumAsyncBuffer = 3;
263 
264   /// Force a final flush to the cache server. Must be called when receiving eoe.
FlushAsyncWriteBuffer()265   Status FlushAsyncWriteBuffer() {
266     if (async_buffer_stream_) {
267       return async_buffer_stream_->SyncFlush(AsyncBufferStream::AsyncFlushFlag::kFlushBlocking);
268     }
269     return Status::OK();
270   }
271 
272  private:
273   mutable RWLock mux_;
274   uint64_t cache_mem_sz_;
275   bool spill_;
276   // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
277   // sharing of the cache.
278   CacheClientInfo cinfo_;
279   // The server_connection_id_ is the actual id we use for operations after the cache is built
280   connection_id_type server_connection_id_;
281   // Some magic cookie/id returned from the cache server.
282   std::string cookie_;
283   int32_t client_id_;
284   std::vector<int32_t> cpu_list_;
285   // Comm layer
286   bool local_bypass_;
287   int32_t num_connections_;
288   int32_t prefetch_size_;
289   mutable std::shared_ptr<CacheClientGreeter> comm_;
290   std::atomic<bool> fetch_all_keys_;
291   WaitPost cache_miss_keys_wp_;
292   /// A structure shared by all the prefetchers to know what keys are missing at the server.
293   class CacheMissKeys {
294    public:
295     explicit CacheMissKeys(const std::vector<row_id_type> &v);
296     ~CacheMissKeys() = default;
297     /// This checks if a key is missing.
298     /// \param key
299     /// \return true if definitely a key miss
300     bool KeyIsCacheMiss(row_id_type key);
301 
302    private:
303     row_id_type min_;
304     row_id_type max_;
305     std::set<row_id_type> gap_;
306   };
307   std::unique_ptr<CacheMissKeys> cache_miss_keys_;
308 
309   /// A data stream of back-to-back serialized tensor rows.
310   class AsyncBufferStream {
311    public:
312     AsyncBufferStream();
313     ~AsyncBufferStream();
314 
315     /// \brief Initialize an Ascyn write buffer
316     Status Init(CacheClient *cc);
317 
318     /// A worker will call the API AsyncWrite to put a TensorRow into the data stream.
319     /// A background thread will stream the data to the cache server.
320     /// The result of calling AsyncWrite is not immediate known or it can be the last
321     /// result of some previous flush.
322     /// \note Need to call SyncFlush to do the final flush.
323     Status AsyncWrite(const TensorRow &row);
324     enum class AsyncFlushFlag : int8_t { kFlushNone = 0, kFlushBlocking = 1, kCallerHasXLock = 1u << 2 };
325     Status SyncFlush(AsyncFlushFlag flag);
326 
327     /// This maps a physical shared memory to the data stream.
328     class AsyncWriter {
329      public:
330       friend class AsyncBufferStream;
331       Status Write(int64_t sz, const std::vector<ReadableSlice> &v);
332 
333      private:
334       std::shared_ptr<BatchCacheRowsRequest> rq;
335       void *buffer_;
336       int32_t num_ele_;      // How many tensor rows in this buffer
337       int64_t bytes_avail_;  // Number of bytes remain
338     };
339 
340     /// \brief Release the shared memory during shutdown
341     /// /note but needs comm layer to be alive.
342     Status ReleaseBuffer();
343     /// \brief Reset the AsyncBufferStream into its initial state
344     /// \return Status object
345     Status Reset();
346 
347    private:
348     Status flush_rc_;
349     std::mutex mux_;
350     TaskGroup vg_;
351     CacheClient *cc_;
352     int64_t offset_addr_;
353     AsyncWriter buf_arr_[kNumAsyncBuffer];
354     int32_t cur_;
355   };
356   std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
357 };
358 }  // namespace dataset
359 }  // namespace mindspore
360 
361 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_
362