/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_

#include <atomic>
#include <iostream>
#include <limits>
#include <memory>
#include <map>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "minddata/dataset/core/config_manager.h"
#ifdef ENABLE_CACHE
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#else
#include "minddata/dataset/engine/cache/stub/cache_grpc_client.h"
#endif

#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/cond_var.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {
namespace dataset {
/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through
/// a CacheClient. Typical tasks include creating a cache service, caching a data buffer, restoring
/// previously cached rows, etc.
class CacheClient {
 public:
  friend class CacheMergeOp;
  friend class CreateCacheRequest;
  friend class CacheRowRequest;
  friend class BatchFetchRequest;
  friend class BatchCacheRowsRequest;

  /// \brief A builder to help creating a CacheClient object
  class Builder {
   public:
    Builder();

    ~Builder() = default;

    /// Setter function to set the session id
    /// \param session_id
    /// \return Builder object itself.
    Builder &SetSessionId(session_id_type session_id) {
      session_id_ = session_id;
      return *this;
    }

    /// Setter function to set the cache memory size
    /// \param cache_mem_sz
    /// \return Builder object itself
    Builder &SetCacheMemSz(uint64_t cache_mem_sz) {
      cache_mem_sz_ = cache_mem_sz;
      return *this;
    }

    /// Setter function to set the spill attribute
    /// \param spill
    /// \return Builder object itself
    Builder &SetSpill(bool spill) {
      spill_ = spill;
      return *this;
    }

    /// Setter function to set rpc hostname
    /// \param host
    /// \return Builder object itself
    Builder &SetHostname(std::string host) {
      hostname_ = std::move(host);
      return *this;
    }

    /// Setter function to set tcpip port
    /// \param port
    /// \return Builder object itself.
    Builder &SetPort(int32_t port) {
      port_ = port;
      return *this;
    }

    /// Setter function to set number of async rpc workers
    /// \param num_connections
    /// \return Builder object itself
    Builder &SetNumConnections(int32_t num_connections) {
      num_connections_ = num_connections;
      return *this;
    }

    /// Setter function to set prefetch amount for fetching rows from cache server
    /// \param prefetch_sz
    /// \return Builder object itself
    Builder &SetPrefetchSize(int32_t prefetch_sz) {
      prefetch_size_ = prefetch_sz;
      return *this;
    }

    /// Getter functions
    session_id_type GetSessionId() const { return session_id_; }
    uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
    bool isSpill() const { return spill_; }
    const std::string &GetHostname() const { return hostname_; }
    int32_t GetPort() const { return port_; }
    int32_t GetNumConnections() const { return num_connections_; }
    int32_t GetPrefetchSize() const { return prefetch_size_; }

    /// \brief Validate the builder parameters before construction
    /// \return Status object
    Status SanityCheck();

    /// \brief Construct the CacheClient from the accumulated settings
    /// \param[out] out The constructed CacheClient
    /// \return Status object
    Status Build(std::shared_ptr<CacheClient> *out);

   private:
    session_id_type session_id_;
    uint64_t cache_mem_sz_;
    bool spill_;
    std::string hostname_;
    int32_t port_;
    int32_t num_connections_;
    int32_t prefetch_size_;
  };

  /// \brief Constructor
  /// \param session_id A user assigned session id for the current pipeline
  /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
  /// \param spill Spill to disk if out of memory
  /// \param hostname rpc hostname of the cache server
  /// \param port tcpip port of the cache server
  /// \param num_connections Number of async rpc workers
  /// \param prefetch_size Prefetch amount for fetching rows from cache server
  CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
              int32_t num_connections, int32_t prefetch_size);

  /// \brief Destructor
  ~CacheClient();

  /// \brief Send a TensorRow to the cache server
  /// \param[in] row
  /// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset
  /// \return return code
  Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const;

  /// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is
  /// any cache miss
  /// \param row_id A vector of row id's
  /// \param out A TensorTable of TensorRows.
  /// \return return code
  Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const;

  /// \brief Create a cache.
  /// \param tree_crc A crc that was generated during tree prepare phase
  /// \param generate_id Let the cache service generate row id
  /// \return Status object
  Status CreateCache(uint32_t tree_crc, bool generate_id);

  /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
  /// \return Status object
  Status DestroyCache();

  /// \brief Get the statistics from a cache.
  /// \param[in/out] Pointer to a pre-allocated ServiceStat object
  /// \return Status object
  Status GetStat(CacheServiceStat *);

  /// \brief Get the state of a cache server
  /// \param[in/out] Pointer to a int8_t
  /// \return Status object
  Status GetState(int8_t *);

  /// \brief Cache the schema at the cache server
  /// \param map The unordered map of the schema
  /// \return Status object
  Status CacheSchema(const std::unordered_map<std::string, int32_t> &map);

  /// \brief Fetch the schema from the cache server
  /// \param map Pointer to pre-allocated map object
  /// \return Status object.
  Status FetchSchema(std::unordered_map<std::string, int32_t> *map);

  /// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache
  /// client that holds cookie can be allowed to make this request
  /// \return Status object
  Status BuildPhaseDone() const;

  /// \brief A print method typically used for debugging
  /// \param out The output stream to write output to
  void Print(std::ostream &out) const;

  /// \brief Stream output operator overload
  /// \return the output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) {
    cc.Print(out);
    return out;
  }

  /// \brief Every cache server has a cookie which uniquely identifies the CacheClient that creates it.
  /// \return Cookie
  std::string cookie() const { return cookie_; }

  /// \brief Send a request async to the server
  /// \param rq BaseRequest
  /// \return Status object
  Status PushRequest(std::shared_ptr<BaseRequest> rq) const;

  /// \brief If the remote server supports local bypass using shared memory
  /// \return boolean value
  bool SupportLocalClient() const { return local_bypass_; }

  /// \brief Return the base memory address if we attach to any shared memory.
  auto SharedMemoryBaseAddr() const { return comm_->SharedMemoryBaseAddr(); }

  /// Getter functions
  session_id_type session_id() const { return cinfo_.session_id(); }
  uint64_t GetCacheMemSz() const { return cache_mem_sz_; }
  bool isSpill() const { return spill_; }
  int32_t GetNumConnections() const { return num_connections_; }
  int32_t GetPrefetchSize() const { return prefetch_size_; }
  int32_t GetClientId() const { return client_id_; }
  std::string GetHostname() const;
  int32_t GetPort() const;

  /// MergeOp will notify us when the server can't cache any more rows.
  /// We will stop any attempt to fetch any rows that are most likely
  /// not present at the server.
  void ServerRunningOutOfResources();

  /// \brief Check if a row is 100% cache miss at the server by checking the local information
  /// \param key row id to be tested
  /// \return true if not at the server
  bool KeyIsCacheMiss(row_id_type key) {
    if (cache_miss_keys_) {
      // Make sure it is fully built even though the pointer is not null
      Status rc = cache_miss_keys_wp_.Wait();
      if (rc.IsOk()) {
        return cache_miss_keys_->KeyIsCacheMiss(key);
      }
    }
    // Default to "not a miss" when the miss-key table isn't available (or the wait failed),
    // so callers will still attempt the fetch from the server.
    return false;
  }

  /// \brief Serialize a Tensor into the async buffer.
  /// \param row The TensorRow to be written asynchronously
  /// \return Status object
  Status AsyncWriteRow(const TensorRow &row);

  // Default size of the async write buffer
  constexpr static int64_t kAsyncBufferSize = 16 * 1048576L;  // 16M
  constexpr static int32_t kNumAsyncBuffer = 3;

  /// Force a final flush to the cache server. Must be called when receiving eoe.
  /// \return Status object (OK if async write was never enabled)
  Status FlushAsyncWriteBuffer() {
    if (async_buffer_stream_) {
      return async_buffer_stream_->SyncFlush(AsyncBufferStream::AsyncFlushFlag::kFlushBlocking);
    }
    return Status::OK();
  }

 private:
  mutable RWLock mux_;
  uint64_t cache_mem_sz_;
  bool spill_;
  // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
  // sharing of the cache.
  CacheClientInfo cinfo_;
  // The server_connection_id_ is the actual id we use for operations after the cache is built
  connection_id_type server_connection_id_;
  // Some magic cookie/id returned from the cache server.
  std::string cookie_;
  int32_t client_id_;
  std::vector<int32_t> cpu_list_;
  // Comm layer
  bool local_bypass_;
  int32_t num_connections_;
  int32_t prefetch_size_;
  mutable std::shared_ptr<CacheClientGreeter> comm_;
  std::atomic<bool> fetch_all_keys_;
  // Posted once cache_miss_keys_ below is fully constructed; KeyIsCacheMiss waits on it.
  WaitPost cache_miss_keys_wp_;
  /// A structure shared by all the prefetchers to know what keys are missing at the server.
  class CacheMissKeys {
   public:
    explicit CacheMissKeys(const std::vector<row_id_type> &v);
    ~CacheMissKeys() = default;
    /// This checks if a key is missing.
    /// \param key
    /// \return true if definitely a key miss
    bool KeyIsCacheMiss(row_id_type key);

   private:
    // Miss keys are represented compactly as a [min_, max_] range plus the set of
    // gaps inside it (presumably the keys NOT in that representation are cached).
    row_id_type min_;
    row_id_type max_;
    std::set<row_id_type> gap_;
  };
  std::unique_ptr<CacheMissKeys> cache_miss_keys_;

  /// A data stream of back-to-back serialized tensor rows.
  class AsyncBufferStream {
   public:
    AsyncBufferStream();
    ~AsyncBufferStream();

    /// \brief Initialize an async write buffer
    /// \param cc The owning CacheClient (non-owning back pointer)
    /// \return Status object
    Status Init(CacheClient *cc);

    /// A worker will call the API AsyncWrite to put a TensorRow into the data stream.
    /// A background thread will stream the data to the cache server.
    /// The result of calling AsyncWrite is not immediately known or it can be the last
    /// result of some previous flush.
    /// \note Need to call SyncFlush to do the final flush.
    Status AsyncWrite(const TensorRow &row);
    // NOTE(review): values are bit-flag style (1 and 1u << 2, skipping bit 1) — presumably
    // kCallerHasXLock can be OR'ed with a flush mode; confirm against SyncFlush call sites.
    enum class AsyncFlushFlag : int8_t { kFlushNone = 0, kFlushBlocking = 1, kCallerHasXLock = 1u << 2 };
    Status SyncFlush(AsyncFlushFlag flag);

    /// This maps a physical shared memory to the data stream.
    class AsyncWriter {
     public:
      friend class AsyncBufferStream;
      Status Write(int64_t sz, const std::vector<ReadableSlice> &v);

     private:
      std::shared_ptr<BatchCacheRowsRequest> rq;
      void *buffer_;
      int32_t num_ele_;      // How many tensor rows in this buffer
      int64_t bytes_avail_;  // Number of bytes remaining
    };

    /// \brief Release the shared memory during shutdown
    /// \note but needs comm layer to be alive.
    Status ReleaseBuffer();
    /// \brief Reset the AsyncBufferStream into its initial state
    /// \return Status object
    Status Reset();

   private:
    Status flush_rc_;
    std::mutex mux_;
    TaskGroup vg_;
    CacheClient *cc_;
    int64_t offset_addr_;
    AsyncWriter buf_arr_[kNumAsyncBuffer];  // Rotating pool of write buffers
    int32_t cur_;                           // Index of the buffer currently being filled
  };
  std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_CLIENT_H_