/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_EMBEDDING_CACHE_PREFETCH_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_EMBEDDING_CACHE_PREFETCH_ACTOR_H_

#include <map>
#include <set>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include <tuple>
#include <random>

#include "runtime/graph_scheduler/actor/actor_common.h"
#include "ir/anf.h"
#include "include/backend/kernel_graph.h"
#include "include/backend/distributed/cluster/cluster_context.h"
#include "distributed/cluster/actor_route_table_proxy.h"
#include "include/backend/distributed/rpc/tcp/tcp_client.h"
#include "include/backend/distributed/rpc/tcp/tcp_server.h"
#include "utils/hash_map.h"
#include "include/common/random.h"
#include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
#include "include/backend/distributed/embedding_cache/blocking_queue.h"

// Note: After the code in ps/ps_cache is moved into runtime/addons/embedding_cache/,
// the following include files and using declarations of ps will be removed.
#include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
#include "include/backend/distributed/ps/ps_context.h"
using mindspore::ps::PSContext;
using mindspore::ps::PsDataChannel;
using mindspore::ps::PsDataPrefetch;

namespace mindspore {
namespace runtime {
using kernel::Address;
using kernel::AddressPtr;
using kernel::AddressPtrList;

class DeviceEmbeddingOperation;
class Sender;
class Receiver;
using SenderPtr = std::shared_ptr<Sender>;
using ReceiverPtr = std::shared_ptr<Receiver>;
using SendRecvPair = std::pair<SenderPtr, ReceiverPtr>;
using SendRecvPairList = std::vector<SendRecvPair>;

using distributed::EmbeddingCacheStatisticsInfo;
using distributed::EmbeddingDeviceCache;
using distributed::EmbeddingHostCache;
using distributed::HashTableInfo;
using distributed::kInvalidIndexValue;

using distributed::BlockingQueue;
using distributed::CacheAnalysis;
using distributed::IdsAndIndices;
using distributed::UniqueIds;
using BlockingQueueTuple =
  std::tuple<std::shared_ptr<BlockingQueue<UniqueIds>>, std::shared_ptr<BlockingQueue<CacheAnalysis>>,
             std::shared_ptr<BlockingQueue<IdsAndIndices>>>;

using distributed::cluster::ActorRouteTableProxy;
using distributed::cluster::ActorRouteTableProxyPtr;
using distributed::rpc::TCPClient;
using distributed::rpc::TCPServer;

using DataType = float;
using Generator = random::Philox;
using NormalDistribution = random::NormalDistribution<double>;
using ConstantDistribution = random::ConstantDistribution<DataType>;

constexpr size_t kPipelineStageNum = 4;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;

// The EmbeddingCachePrefetchActor is used in large embedding table scenarios. The cache hierarchy is: Device
// Cache -> Local Host Cache -> Remote Cache. This actor performs local host and device cache hit analysis and cache
// prefetching: the feature weights corresponding to the ids of subsequent batches are prefetched into the device
// cache in advance, so that prefetching is pipelined with the computation on the device side. Cache prefetching may
// involve RPC communication with the server side.
class EmbeddingCachePrefetchActor : public ActorBase {
 public:
  explicit EmbeddingCachePrefetchActor(device::DeviceContext *device_context)
      : ActorBase("EmbeddingCachePrefetchActor"), device_context_(device_context), cpu_device_context_(nullptr) {}

  ~EmbeddingCachePrefetchActor() override = default;

  // Initialize embedding cache prefetch actor.
  // 1. Build and link rpc operators between the local cache and the remote cache.
  // 2. Build network connections of the rpc operators.
  void Initialize();

  // Perform local cache hit analysis and prefetch the feature vectors corresponding to the next batch into the cache.
  void Run();

  // Increase the global step of the compute graph.
  void IncreaseGraphStep(const std::string &channel_name);

  // Sync the latest embedding table to remote.
  void SyncEmbeddingTable();
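
  // Illustrative lifecycle (a sketch only; the exact wiring lives in the runtime code, and `device_context` and
  // `channel_name` are assumed to be supplied by the caller):
  //
  //   auto actor = std::make_shared<EmbeddingCachePrefetchActor>(device_context);
  //   actor->Initialize();                     // Build and connect rpc operators.
  //   std::thread prefetch([&]() { actor->Run(); });
  //   actor->IncreaseGraphStep(channel_name);  // Called once per compute graph step.
  //   ...
  //   actor->SyncEmbeddingTable();             // Push the latest embeddings to remote.
  //   actor->Finalize(true);                   // Also finalizes the remote when 'finalize_remote' is true.
  //   prefetch.join();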

  // Finalize the embedding cache prefetch actor and push the latest embeddings from the local cache to the remote
  // cache.
  void Finalize(bool finalize_remote);

  // Wait for the compute graph to finish the current step when there is not enough free memory space in the cache,
  // so that the feature vectors used by the current step can be deleted from the cache.
  bool WaitGraphRun();

  // Reset EmbeddingHashMap for the device and local host caches.
  bool ResetEmbeddingHashMap();

  // Insert weights into the local host embedding cache.
  bool InsertLocalHostCache(size_t embedding_size, size_t insert_indices_size, const int *insert_indices,
                            const float *insert_data, float *hash_table_addr);

  // Lookup embeddings from the local host embedding cache.
  bool LookupLocalHostCache(size_t embedding_size, size_t indices_num, const float *hash_table_addr,
                            const int *indices_addr, float *output_addr);

 private:
  // Increase the current global step of the cache prefetching operation.
  bool IncreaseStep();

  // Update the current compute graph's step to the real global step at the time when this actor starts to prefetch
  // cache for a batch of ids.
  void set_current_graph_step() { graph_running_step_ = graph_step_.load(); }

  // Push non-hotspot embeddings in the local host cache to remote.
  bool PushCacheFromLocalHostToRemote(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis);

  // Pull the embeddings missing from the local cache from remote.
  bool PullCacheFromRemoteToLocalHost(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis);

  // Initialize local cache values using the random number generator.
  bool InitLocalCacheForNewIds(const HashTableInfo &hash_info);
  bool InitLocalCacheForNewIds(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis);

  // Lookup embeddings on remote and get the embeddings via RPC.
  bool PullEembeddingsFromRemote(int32_t param_key, const int *ids, size_t ids_num, std::vector<float> *outputs);
  // Push the local embedding cache entries that require eviction to remote.
  bool PushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num, const float *embeddings,
                              size_t embeddings_len);
  bool DoPushEmbeddingsToRemote(int32_t param_key, const int *ids, size_t ids_num, const float *embeddings,
                                size_t embeddings_len);

  // In a multi-server scenario, the embeddings need to be segmented, and each server saves the embeddings of
  // different feature id ranges. Therefore, when the local side performs a push or pull embeddings operation, the
  // ids and embeddings need to be divided before communicating with the corresponding remote:
  // Partition ids by remote embedding slice bound and get unique ids.
  bool PartitionIds(const int *ids, size_t ids_num, std::vector<std::vector<int>> *slice_ids_list);
  // Partition ids and embeddings by remote embedding slice bound.
  bool PartitionIdsAndEmbeddings(const int *ids, size_t ids_num, const float *embeddings, size_t embeddings_len,
                                 std::vector<std::vector<int>> *slice_ids_list,
                                 std::vector<std::vector<float>> *slice_embeddings_list);
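
  // Worked example for the partition above (values assumed for illustration): with two servers whose remote
  // embedding slice bounds are [0, 50) and [50, 100), partitioning the ids {3, 72, 48, 72} yields
  // slice_ids_list[0] = {3, 48} for server 0 and slice_ids_list[1] = {72} for server 1 (ids are unique per slice).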

  // Send content such as ids or embeddings to remote.
  // The parameter 'cache_operation' is the cache operation name, such as LookupEmbeddingCache and
  // UpdateEmbeddingCache.
  bool SendToRemote(const std::string &cache_operation, int32_t param_key, size_t server_rank_id, size_t embedding_dim,
                    const void *keys, size_t keys_len, const void *values = nullptr, size_t values_len = 0,
                    bool finalize_remote = false, bool sync = true);
  // Wait for the response of remote and get the returned result.
  // The parameter 'cache_operation' is the cache operation name, such as LookupEmbeddingCache and
  // UpdateEmbeddingCache.
  std::unique_ptr<std::vector<char>> ReceiveFromRemote(const std::string &cache_operation, int32_t param_key,
                                                       size_t server_rank_id) const;
  // Retrieve embeddings in input ids order.
  bool RetrieveEmbeddings(const int *ids, size_t ids_num, const std::vector<std::vector<int>> &slice_ids_list,
                          const std::vector<std::unique_ptr<std::vector<char>>> &slice_embeddings_list,
                          std::vector<float> *outputs) const;

  // Send a finalize request to remote and finalize it.
  bool FinalizeRemote();

  // Sync the latest local host embedding cache to remote.
  bool SyncHostEmbeddingTable();
  // Sync the latest device embedding cache to remote.
  bool SyncDeviceEmbeddingTable();

  // The cache prefetch phase may involve RPC communication with the server, implemented through Sender and
  // Receiver.
  // Build rpc operators.
  void BuildRpcOperators();
  // Link rpc operators and build network connections.
  void LinkRpcOperators();

  // Get the dataset channel name.
  const std::string &channel_name();
  // Set the dataset channel name.
  void set_channel_name(const std::string &channel_name);

  // When the device cache does not reach a 100% hit rate, the cache needs to be updated, which involves cache
  // insertion and deletion: push the non-hotspot embeddings on the local side to remote, and pull the embeddings
  // missing on the local side from remote.
  bool UpdateCache();

  // Do the lookup embedding table operation.
  void LookupEmbeddingTable(size_t indices_num, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
                            const int *indices_addr, float *output_addr);

  // Wait for the data channel to be ready.
  void WaitDataChannelInit();

  // Wait for parameters to be initialized on remote.
  // Prevents subsequent cache prefetching from failing due to the long initialization time of large parameters on
  // the remote side.
  void WaitInitParametersOnRemote();

  void CreateChannelLock(const std::string &channel_name);
  void CreateBlockQueue(const std::string &channel_name);

  // Perform local host and device cache hit/miss analysis and prefetch cache for missing embeddings via a
  // multi-stage pipeline. Data flow: unique id queue -> cache analysis queue -> id and indices queue.
  // See the stage/queue sketch after the task declarations below.
  void StartPrefetchCachePipeline(const std::string &channel_name);
  void StopPrefetchCachePipeline();
  void WaitPrefetchCacheFinish();

  // The four-stage pipeline tasks.
  void UniqueIdsTask(const std::string &channel_name);
  void AnalyseCacheTask(const std::string &channel_name);
  void UpdateCacheTask(const std::string &channel_name);
  void TransformIdsToIndicesTask(const std::string &channel_name);
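
  // A minimal sketch of how the four stage tasks could map onto the three queues in BlockingQueueTuple (inferred
  // from the names above; the exact producer/consumer wiring is an assumption of this comment):
  //
  //   UniqueIdsTask             -> produces unique ids into BlockingQueue<UniqueIds>.
  //   AnalyseCacheTask          -> consumes unique ids, produces hit/miss results into BlockingQueue<CacheAnalysis>.
  //   UpdateCacheTask           -> consumes analysis results and pushes/pulls embeddings to/from remote.
  //   TransformIdsToIndicesTask -> produces the cache indices for the ids into BlockingQueue<IdsAndIndices>.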

  // Set the current error information before finalizing the actor.
  void SetErrorInfo(const std::string &error_info);

  mindspore::HashMap<std::string, std::shared_ptr<PsDataChannel>> channel_locks_;
  mindspore::HashMap<std::string, std::shared_ptr<std::vector<std::thread>>> pipeline_stages_;
  mindspore::HashMap<std::string, BlockingQueueTuple> channel_to_queues_;

  // The operations for the embeddings on the device.
  DeviceEmbeddingOperation *emb_ops_{nullptr};

  // Record sender and receiver pairs for different cache operations, servers and parameter keys.
  // key: cache operation (such as LookupEmbeddingCache and UpdateEmbeddingCache)
  // value: sender and receiver pairs for this kind of cache operation.
  mindspore::HashMap<std::string, std::vector<SendRecvPairList>> rpc_operators_;

  // The device interface.
  device::DeviceContext *device_context_;
  // The CPU device context used for allocating rpc message data.
  device::DeviceContext *cpu_device_context_;
  // The device stream used for async memcpy operations and for launching device kernels, such as the embedding cache
  // lookup and update kernels.
  size_t stream_id_{0};

  // Full embedding table row number, not less than the total number of feature ids.
  size_t vocab_size_{0};

  // Embedding cache size (row number of the embedding cache) of the local host cache.
  size_t local_host_cache_size_{0};

  // Statistics on the cache hit rates of the host and device, and the information used to update the cache.
  EmbeddingCacheStatisticsInfo statistics_info_;

  // Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range
  // corresponding to the embedding table slice of this process.
  std::pair<int, int> local_embedding_slice_bounds_;

  // Model parallelism is used between multiple workers, and local_device_cache_bounds_ records the local device cache
  // range corresponding to the embedding table slice of this process.
  std::pair<int, int> local_device_cache_bounds_;

  // In a multi-server scenario, the embeddings need to be segmented, and each server saves the embeddings of
  // different feature id ranges; remote_embedding_slice_bounds_ records the feature range of the embedding table
  // slice on each server.
  std::vector<std::pair<size_t, size_t>> remote_embedding_slice_bounds_;

  // Total server number of the cluster.
  size_t server_num_{0};

  // The flag which indicates whether this actor is running to prefetch cache.
  std::atomic_bool running_{false};

  // The flag which indicates whether this actor is initialized.
  bool initialized_{false};
  // The flag which indicates whether this actor is finalized.
  bool finalized_{false};

  // Ensure that the Finalize function is thread safe.
  std::mutex finalize_mutex_;

  // The flag which indicates whether syncing the embedding table has finished.
  bool finish_sync_embedding_table_{false};
  std::mutex sync_embedding_table_mutex_;

  // The current global step of the compute graph.
  std::atomic_ulong graph_step_{0};
  // The compute graph's global step at the time when this actor starts to prefetch cache for a batch of ids.
  std::atomic_ulong graph_running_step_{0};
  // The current global step of the cache prefetching operation.
  size_t data_step_{0};

  // Dataset channel name, used in dataset switching scenarios.
  std::string channel_name_{""};
  // The mutex to access channel_name_.
  std::mutex channel_mutex_;

  // The flag which indicates whether initializing parameters on remote has finished.
  std::atomic_bool finish_init_parameters_on_remote_{false};

  // Data parser condition variable for prefetching cache, used to start and synchronize intermediate states of cache
  // prefetching.
  std::condition_variable data_parser_;
  // Data parser mutex for prefetching cache.
  std::mutex data_mutex_;

  // Whether the device cache prefetching process needs to wait for the compute graph to finish the current step when
  // there is not enough free memory space in the cache.
  bool device_cache_need_wait_graph_{false};
  // Whether the local host cache prefetching process needs to wait for the compute graph to finish the current step
  // when there is not enough free memory space in the cache.
  bool host_cache_need_wait_graph_{false};

  std::mutex pipeline_mutex_;
  // Record the latest user-related error information.
  std::string error_info_{""};
};

// RpcOperator is used to do rpc with other processes in distributed execution.
// RpcOperator uses an inter-process edge to identify paired rpc operators uniquely.
class RpcOperator {
 public:
  RpcOperator() : inter_process_edge_(""), route_table_proxy_(nullptr) {}
  virtual ~RpcOperator() = default;

  // Set the inter-process edge name for rpc operators.
  void set_inter_process_edge_name(const std::string &edge_name) { inter_process_edge_ = edge_name; }

  // Set the route table proxy for rpc operators.
  void set_actor_route_table_proxy(const ActorRouteTableProxyPtr &route_table_proxy) {
    route_table_proxy_ = route_table_proxy;
  }

 protected:
  // Unique edge name between rpc operators, format:
  // src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
  std::string inter_process_edge_;

  // Route table proxy for building network connections between nodes such as workers and servers.
  ActorRouteTableProxyPtr route_table_proxy_;
};

// Sender is used to send data to other processes.
class Sender : public RpcOperator {
 public:
  explicit Sender(device::DeviceContext *cpu_device_context)
      : server_url_(""), client_(nullptr), cpu_device_context_(cpu_device_context) {}
  ~Sender() override;

  // Send the buffer to the peer.
  bool Send(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
            const AddressPtrList &data_list, bool finalize_remote = false, bool sync = true) const;

  // Set the receiver paired with the sender to get the 'from url' from the receiver.
  void set_receiver(const ReceiverPtr &receiver) { receiver_ = receiver; }

  // Lookup the peer receiver's route and build the network connection.
  bool ConnectServer();
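
  // Illustrative use (a sketch only; buffer setup and ownership are simplified, and `ids_buf`, `ids_len` and
  // `ids_num` are assumed caller-side values):
  //
  //   std::vector<ShapeVector> shapes = {{static_cast<int64_t>(ids_num)}};
  //   std::vector<TypeId> data_types = {kNumberTypeInt32};
  //   AddressPtrList data_list = {std::make_shared<Address>(ids_buf, ids_len)};
  //   sender->set_receiver(receiver);  // Pair with the receiver that provides the 'from url'.
  //   if (sender->ConnectServer()) {
  //     (void)sender->Send(shapes, data_types, data_list);
  //   }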

 private:
  // Build the MessageBase including the dynamic shape protobuf, which will be sent to the peer receiver.
  // The message format is as below:
  // |--------22 bytes-------|-------sizeof(size_t)-------|-dynamic shape PB data size-|--real data size--|
  // |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|----real data-----|
  // The message.from (from url) must be set.
  std::unique_ptr<MessageBase> BuildRpcMessage(const std::vector<ShapeVector> &shapes,
                                               const std::vector<TypeId> data_types, const AddressPtrList &data_list,
                                               const std::string &from_url, const std::string &to_url,
                                               bool finalize_remote) const;

  // Free the message after it is sent to remote.
  bool FreeMessage(void *data);

  // Calculate the dynamic shape message size.
  size_t CalDataSize(const std::vector<ShapeVector> &shapes, const std::vector<TypeId> data_types,
                     const AddressPtrList &data_list, bool finalize_remote) const;

  // The url of the peer receiver's tcp server.
  std::string server_url_;

  std::unique_ptr<TCPClient> client_;

  // The sender and the receiver are used in pairs. The information sent by the sender contains the url of the
  // corresponding receiver, so a reference to the receiver is maintained in the sender.
  ReceiverPtr receiver_;

  // The CPU device context used for allocating rpc message data.
  device::DeviceContext *cpu_device_context_;
};

// Receiver is used to receive data from other processes.
class Receiver : public RpcOperator {
 public:
  explicit Receiver(device::DeviceContext *cpu_device_context)
      : ip_(""),
        port_(0),
        server_(nullptr),
        received_buffer_(nullptr),
        received_msg_(false),
        cpu_device_context_(cpu_device_context) {}
  ~Receiver() override;

  // Receive a message from the peer sender. This interface is synchronous and will wait for the message
  // until the timeout period is reached.
  std::unique_ptr<std::vector<char>> Receive();

  // Start the receiver server and register this server's address to the route table in the scheduler by proxy.
  bool StartServer();

  // Get the url of this receiver, format: ip:port.
  std::string get_url() const { return ip_ + ":" + std::to_string(port_); }

 private:
  // The message callback of the tcp server.
  MessageBase *HandleMessage(MessageBase *const msg);

  // Parse the dynamic shape protobuf message. The format is as below:
  // |--------22 bytes-------|-------sizeof(size_t)-------|-dynamic shape PB data size-|--real data size--|
  // |RPC_DYNAMIC_SHAPE_DATA | dynamic shape PB data size |---dynamic shape PB data----|----real data-----|
  // The output parameter 'data' contains the real data address and size.
  bool ParseDynamicShapeData(const char *msg_body, size_t msg_len, std::pair<const void *, size_t> *data) const;
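
  // Worked layout example for the message format above (numbers assumed for illustration, with
  // sizeof(size_t) == 8): for a message whose dynamic shape protobuf part is 40 bytes and whose real data part
  // is 400 bytes, the body is laid out as:
  //   [  0,  22)  "RPC_DYNAMIC_SHAPE_DATA" magic header
  //   [ 22,  30)  40 (the dynamic shape PB data size)
  //   [ 30,  70)  dynamic shape protobuf data
  //   [ 70, 470)  real data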

  // The callback set to the rpc module to allocate message memory (raw pointer).
  void *AllocateMessage(size_t size);

  // The network address of this receiver. It is generated automatically by the rpc module.
  std::string ip_;
  uint32_t port_;

  std::unique_ptr<TCPServer> server_;

  // The buffer used to save the received content of the message.
  std::unique_ptr<std::vector<char>> received_buffer_;

  // The flag which indicates whether a message has been received successfully.
  std::atomic_bool received_msg_;

  // The interface 'Receive' is synchronous; a condition variable is used to block the thread and wait for the
  // message.
  std::condition_variable received_msg_cv_;
  std::mutex received_msg_mtx_;

  // The CPU device context used for allocating rpc message data.
  device::DeviceContext *cpu_device_context_;
};

using EmbeddingCachePrefetchActorPtr = std::shared_ptr<EmbeddingCachePrefetchActor>;
}  // namespace runtime
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_EMBEDDING_CACHE_PREFETCH_ACTOR_H_