/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_EMBEDDING_CACHE_SCHEDULER_H_
#define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_EMBEDDING_CACHE_SCHEDULER_H_

#include <memory>
#include <mutex>
#include <vector>
#include <string>
#include "utils/ms_utils.h"
#include "utils/hash_map.h"
#include "include/backend/kernel_graph.h"
#include "runtime/hardware/device_context.h"
#include "include/backend/visible.h"

namespace mindspore {
namespace runtime {
using device::DeviceContext;
class EmbeddingCachePrefetchActor;

// EmbeddingCacheScheduler is used to build, schedule and finalize the embedding cache prefetch actor, which caches
// the large embedding tables of a recommendation network model. The cache hierarchy is:
// Device Cache -> Local Host Cache -> Remote Cache. The embedding cache prefetch actor performs Device and Local
// Host cache hit analysis and cache prefetching.
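//
// A minimal usage sketch of the scheduler lifecycle (an illustrative assumption about the typical call order, not a
// normative sequence; the real call sites live in the graph scheduler/compiler code):
//   auto &scheduler = EmbeddingCacheScheduler::GetInstance();
//   scheduler.Initialize();                                        // Build the prefetch actor.
//   scheduler.SetEmbedCachedParamAddress(device_context, graph);   // Bind device addresses of cache parameters.
//   scheduler.Schedule();                                          // Connect local/remote caches and run the actor.
//   // ... training or inference steps run here ...
//   scheduler.Finalize();                                          // Sync the embedding table and stop the actor.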
class BACKEND_EXPORT EmbeddingCacheScheduler {
 public:
  static EmbeddingCacheScheduler &GetInstance();

  // Build and initialize the embedding cache prefetch actor and store it in embedding_cache_prefetch_actor_.
  void Initialize();

  // Set the device address for the embedding cache parameters.
  void SetEmbedCachedParamAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
  // Set the data set channel name, used in multi-dataset mode, such as predicting after training.
  void SetDataSetChannel(const std::string &actor_id, const std::vector<KernelGraphPtr> &graphs);

  // Initialize all embedding storage instances.
  void InitEmbeddingStorage(const std::vector<AnfNodePtr> &parameters) const;

  // 1. Build the network connection between the local and remote caches for the embedding cache prefetch actor.
  // 2. Schedule and run the embedding cache prefetch actor.
  // Since the embedding cache prefetch actor spins continuously and does not belong to any actor set, it is started
  // in the Schedule interface.
  void Schedule();

  // Record the number of global steps executed by the compute graph.
  void IncreaseGraphStep(const std::string &actor_id) const;

  // Synchronize the latest embedding table in the local cache to the remote cache.
  void SyncEmbeddingTable() const;

  // Finalize the embedding cache prefetch actor.
  void Finalize(bool sync_embedding_table = true);

 private:
  EmbeddingCacheScheduler() = default;
  ~EmbeddingCacheScheduler() = default;
  DISABLE_COPY_AND_ASSIGN(EmbeddingCacheScheduler);

  // Parse the number of ids in a batch (not the batch size) from the kernel graph.
  void ParseBatchIdsNum(const KernelGraphPtr &graph);

  // Allocate device and local host memory for the embedding cache tables.
  void AllocMemForEmbeddingCacheTable(const DeviceContext *device_context);

  // Embedding cache prefetch actor.
  std::shared_ptr<EmbeddingCachePrefetchActor> embedding_cache_prefetch_actor_;

  // Indicates whether the batch ids number has already been parsed.
  bool parsed_batch_ids_num_{false};

  // Indicates whether memory has already been allocated for the embedding cache tables.
  bool allocated_embed_cache_mem_{false};

  // Indicates whether the EmbeddingCacheScheduler is initialized.
  bool initialized_{false};
  // Indicates whether the EmbeddingCacheScheduler is scheduled.
  bool scheduled_{false};
  // Indicates whether the EmbeddingCacheScheduler is finalized.
  bool finalized_{false};
  // Ensure that the Finalize function is thread safe.
  std::mutex finalize_mutex_;

  // Record the data set channel name, used in multi-dataset mode, such as predicting after training.
  // Key: data prepare actor id of an actor set, Value: data set channel name.
  mindspore::HashMap<std::string, std::string> data_prepare_aid_to_data_channel_;
};
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_EMBEDDING_CACHE_SCHEDULER_H_