• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_EMBEDDING_OPERATION_H_
18 #define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_EMBEDDING_OPERATION_H_
19 
20 #include <vector>
21 #include <memory>
22 #include <utility>
23 #include <set>
24 #include "runtime/hardware/device_context.h"
25 #include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
26 #include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h"
27 
28 namespace mindspore {
29 namespace runtime {
30 // One and two dimensional shape placeholder.
31 const ShapeVector kOneDimensionalShape = {1};
32 const ShapeVector kTwoDimensionalShape = {1, 1};
33 
34 const size_t kInputIndexZero = 0;
35 const size_t kInputIndexOne = 1;
36 const size_t kInputIndexTwo = 2;
37 
38 const size_t kCacheOpInputNum = 3;
39 const size_t kCacheOpOutputNum = 1;
40 
41 using device::DeviceContext;
42 using distributed::EmbeddingCacheStatisticsInfo;
43 using distributed::EmbeddingDeviceCache;
44 using distributed::EmbeddingHostCache;
45 using distributed::HashTableInfo;
46 using mindspore::session::KernelGraph;
47 
48 // Maximum number of threads for concurrent accelerated cache processing.
49 using distributed::kMaxThreadNum;
50 // Maximum number of feature ids processed per thread.
51 using distributed::kMaxIdsPerThread;
52 // Maximum number for retry find valid slot in device or host cache.
53 constexpr size_t kMaxRetryNum = 12000;
54 
55 class DeviceEmbeddingOperation {
56  public:
DeviceEmbeddingOperation(EmbeddingCachePrefetchActor * actor,device::DeviceContext * device_context,const std::pair<int,int> & local_embedding_slice_bounds,const std::pair<int,int> & local_device_cache_bounds,EmbeddingCacheStatisticsInfo * statistics_info,const size_t & stream_id)57   DeviceEmbeddingOperation(EmbeddingCachePrefetchActor *actor, device::DeviceContext *device_context,
58                            const std::pair<int, int> &local_embedding_slice_bounds,
59                            const std::pair<int, int> &local_device_cache_bounds,
60                            EmbeddingCacheStatisticsInfo *statistics_info, const size_t &stream_id)
61       : actor_(actor),
62         device_context_(device_context),
63         local_embedding_slice_bounds_(local_embedding_slice_bounds),
64         local_device_cache_bounds_(local_device_cache_bounds),
65         statistics_info_(statistics_info),
66         stream_id_(stream_id) {}
67 
68   virtual ~DeviceEmbeddingOperation() = default;
69 
70   virtual bool Initialize();
71 
72   // Analyze the hit/miss info of the local host cache and device cache, and calculate the swapping and
73   // mapping information of the missing feature id that needs to be inserted into the cache.
AnalyseCache(int * batch_ids,const size_t batch_ids_num,size_t data_step,const std::atomic_ulong * graph_running_step,bool * device_cache_need_wait_graph,bool * host_cache_need_wait_graph,int * indices,EmbeddingDeviceCache * embedding_device_cache,EmbeddingHostCache * embedding_host_cache,EmbeddingCacheStatisticsInfo * statistics_info)74   virtual bool AnalyseCache(int *batch_ids, const size_t batch_ids_num, size_t data_step,
75                             const std::atomic_ulong *graph_running_step, bool *device_cache_need_wait_graph,
76                             bool *host_cache_need_wait_graph, int *indices,
77                             EmbeddingDeviceCache *embedding_device_cache, EmbeddingHostCache *embedding_host_cache,
78                             EmbeddingCacheStatisticsInfo *statistics_info) {
79     return true;
80   }
81 
82   // Pull missing embeddings on the device cache from the local host.
83   virtual bool PullCacheFromLocalHostToDevice(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) = 0;
84 
85   // Push non-hotspot embeddings on the device cache to the local host cache.
86   virtual bool PushCacheFromDeviceToLocalHost(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) = 0;
87 
88   // Get the id range of each server's embedding table slice.
89   virtual void GetRemoteEmbeddingSliceBound(size_t vocab_size, size_t server_num,
90                                             std::vector<std::pair<size_t, size_t>> *remote_embedding_slice_bounds) = 0;
91 
92   // Async copy host memory to device.
93   static bool MemcpyHostToDeviceAsync(void *dst, const void *src, size_t size, const DeviceContext *device_context,
94                                       size_t stream_id);
95 
96   // Async copy device memory to host.
97   static bool MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size, const DeviceContext *device_context,
98                                       size_t stream_id);
99 
100   // Get all modified ids.
modified_ids()101   const mindspore::HashSet<int> &modified_ids() const { return modified_ids_; }
102 
103  protected:
104   // Parse the hit and swap out to device cache information of the currently preprocessed id of the local host cache.
105   bool ParseHostDataHostToDevice(int id, size_t data_step, size_t *cur_graph_running_step,
106                                  const std::atomic_ulong *latest_graph_running_step, bool *host_cache_need_wait_graph,
107                                  EmbeddingHostCache *embedding_host_cache,
108                                  EmbeddingCacheStatisticsInfo *statistics_info);
109 
110   // Parse the swap in information from device cache of the currently preprocessed id of the local host cache.
111   bool ParseHostDataDeviceToHost(size_t data_step, size_t *cur_graph_running_step,
112                                  const std::atomic_ulong *latest_graph_running_step, bool *host_cache_need_wait_graph,
113                                  EmbeddingDeviceCache *embedding_device_cache, EmbeddingHostCache *embedding_host_cache,
114                                  EmbeddingCacheStatisticsInfo *statistics_info);
115 
116   // Build a CNode of embedding cache look up kernel, which is used to look up local device
117   // embedding cache.
118   virtual void BuildEmbeddingCacheLookupKernel() = 0;
119 
120   // Build a CNode of embedding cache update kernel, which is used to update local
121   // device embedding cache.
122   virtual void BuildEmbeddingCacheUpdateKernel() = 0;
123 
124   static ParameterPtr NewParameter(const KernelGraphPtr &graph, TypePtr type, const ShapeVector &shape);
125 
126   static ValueNodePtr NewValueNode(int64_t value, const DeviceContext *device_context, size_t stream_id);
127 
128   static bool InferOpShape(const CNodePtr &kernel, const std::vector<kernel::KernelTensor *> &input_kernel_tensors,
129                            const std::vector<kernel::KernelTensor *> &output_kernel_tensors,
130                            const std::vector<abstract::AbstractBasePtr> &output_kernel_tensors_for_iner);
131 
132   // The actor which owns this operation.
133   EmbeddingCachePrefetchActor *actor_;
134 
135   // The device interface.
136   device::DeviceContext *device_context_;
137 
138   // Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range
139   // corresponding to the embedding table slice of the process.
140   std::pair<int, int> local_embedding_slice_bounds_;
141 
142   // Model parallelism is used between multiple workers, and local_device_cache_bounds_ records the local device cache
143   // range corresponding to the embedding table slice of the process.
144   std::pair<int, int> local_device_cache_bounds_;
145 
146   // The embedding cache look up kernel node(operator name: 'Gather' for dense mode and 'MapTensorGet' for sparse mode).
147   CNodePtr embedding_cache_lookup_node_{nullptr};
148 
149   // The embedding cache update kernel node(operator name: 'ScatterUpdate' for dense mode and 'MapTensorPut' for sparse
150   // mode).
151   CNodePtr embedding_cache_update_node_{nullptr};
152 
153   // The feature ids that have been initialized already.
154   mindspore::HashSet<int> initialized_ids_;
155 
156   // The feature ids whose embedding vectors are modified(trained).
157   mindspore::HashSet<int> modified_ids_;
158 
159   // Statistics on the cache hit rate of the host and device and the information used to update cache.
160   EmbeddingCacheStatisticsInfo *statistics_info_;
161 
162   // Cache embedding cache ops kernel graphs.
163   std::vector<KernelGraphPtr> embedding_cache_graphs_;
164 
165   // The device stream used to async memcpy operators and launch device kernels, such as embedding cache look up and
166   // update kernel.
167   size_t stream_id_{0};
168 
169  private:
170   DISABLE_COPY_AND_ASSIGN(DeviceEmbeddingOperation);
171 };
172 }  // namespace runtime
173 }  // namespace mindspore
174 #endif  // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_EMBEDDING_OPERATION_H_
175