• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_EMBEDDING_CACHE_SCHEDULER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_EMBEDDING_CACHE_SCHEDULER_H_
19 
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "utils/ms_utils.h"
#include "include/backend/kernel_graph.h"
#include "runtime/hardware/device_context.h"
#include "include/backend/visible.h"
27 
28 namespace mindspore {
29 namespace runtime {
30 using device::DeviceContext;
31 class EmbeddingCachePrefetchActor;
32 
33 // EmbeddingCacheScheduler could be used to build, schedule and finalize embedding cache prefetch actor
34 // to cache large embedding table of a large recommendation network model. The cache level is:
35 // Device Cache->Local Host Cache->Remote Cache. The embedding cache prefetch actor is used to perform Local
36 // and Device Cache hit analysis and cache prefetching.
37 class BACKEND_EXPORT EmbeddingCacheScheduler {
38  public:
39   static EmbeddingCacheScheduler &GetInstance();
40 
41   // Build and initialize embedding cache prefetch actor and save it by embedding_cache_prefetch_actor_.
42   void Initialize();
43 
44   // Set device address for embedding cache parameter.
45   void SetEmbedCachedParamAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
46   // Set data set channel name, used for multi dataset mode, such as predict after train.
47   void SetDataSetChannel(const std::string &actor_id, const std::vector<KernelGraphPtr> &graphs);
48 
49   // Initialize all embedding storage instances.
50   void InitEmbeddingStorage(const std::vector<AnfNodePtr> &parameters) const;
51 
52   // 1. Build network connection between local and remote cache for embedding cache prefetch actor.
53   // 2. Schedule and Run embedding cache prefetch actor.
54   // Since the embedding cache prefetch actor is spinning, and the actor is not in the actor set, start the actor in the
55   // Schedule interface.
56   void Schedule();
57 
58   // Record the number of global steps executed by the compute graph.
59   void IncreaseGraphStep(const std::string &actor_id) const;
60 
61   // Synchronize latest embedding table in local cache to remote.
62   void SyncEmbeddingTable() const;
63 
64   // Finalize embedding cache prefetch actor.
65   void Finalize(bool sync_embedding_table = true);
66 
67  private:
68   EmbeddingCacheScheduler() = default;
69   ~EmbeddingCacheScheduler() = default;
70   DISABLE_COPY_AND_ASSIGN(EmbeddingCacheScheduler);
71 
72   // Get ids number in a batch, not batch size.
73   void ParseBatchIdsNum(const KernelGraphPtr &graph);
74 
75   // Allocate device and local host memory for embedding cache table.
76   void AllocMemForEmbeddingCacheTable(const DeviceContext *device_context);
77 
78   // Embedding cache prefetch actor.
79   std::shared_ptr<EmbeddingCachePrefetchActor> embedding_cache_prefetch_actor_;
80 
81   // The flag indicates whether already parse batch ids number.
82   bool parsed_batch_ids_num_{false};
83 
84   // The flag indicates whether already allocate memory for embedding cache tables.
85   bool allocated_embed_cache_mem_{false};
86 
87   // The flag indicates whether the EmbeddingCacheScheduler is initialized.
88   bool initialized_{false};
89   // The flag indicates whether the EmbeddingCacheScheduler is scheduled.
90   bool scheduled_{false};
91   // The flag indicates whether the EmbeddingCacheScheduler is finalized.
92   bool finalized_{false};
93   // Ensure that the Finalize function is multithreaded safe.
94   std::mutex finalize_mutex_;
95 
96   // Record data set channel name, used for multi dataset mode, such as predict after train.
97   // Key: data prepare actor id for an actor set, Value: data set channel name.
98   mindspore::HashMap<std::string, std::string> data_prepare_aid_to_data_channel_;
99 };
100 }  // namespace runtime
101 }  // namespace mindspore
102 
103 #endif  // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_EMBEDDING_CACHE_SCHEDULER_H_
104