1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_EMBEDDING_OPERATION_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_EMBEDDING_OPERATION_H_ 19 20 #include <vector> 21 #include <memory> 22 #include <utility> 23 #include <set> 24 #include "runtime/hardware/device_context.h" 25 #include "include/backend/distributed/embedding_cache/embedding_cache_utils.h" 26 #include "runtime/graph_scheduler/actor/embedding_cache/embedding_cache_prefetch_actor.h" 27 28 namespace mindspore { 29 namespace runtime { 30 // One and two dimensional shape placeholder. 31 const ShapeVector kOneDimensionalShape = {1}; 32 const ShapeVector kTwoDimensionalShape = {1, 1}; 33 34 const size_t kInputIndexZero = 0; 35 const size_t kInputIndexOne = 1; 36 const size_t kInputIndexTwo = 2; 37 38 const size_t kCacheOpInputNum = 3; 39 const size_t kCacheOpOutputNum = 1; 40 41 using device::DeviceContext; 42 using distributed::EmbeddingCacheStatisticsInfo; 43 using distributed::EmbeddingDeviceCache; 44 using distributed::EmbeddingHostCache; 45 using distributed::HashTableInfo; 46 using mindspore::session::KernelGraph; 47 48 // Maximum number of threads for concurrent accelerated cache processing. 49 using distributed::kMaxThreadNum; 50 // Maximum number of feature ids processed per thread. 51 using distributed::kMaxIdsPerThread; 52 // Maximum number for retry find valid slot in device or host cache. 53 constexpr size_t kMaxRetryNum = 12000; 54 55 class DeviceEmbeddingOperation { 56 public: DeviceEmbeddingOperation(EmbeddingCachePrefetchActor * actor,device::DeviceContext * device_context,const std::pair<int,int> & local_embedding_slice_bounds,const std::pair<int,int> & local_device_cache_bounds,EmbeddingCacheStatisticsInfo * statistics_info,const size_t & stream_id)57 DeviceEmbeddingOperation(EmbeddingCachePrefetchActor *actor, device::DeviceContext *device_context, 58 const std::pair<int, int> &local_embedding_slice_bounds, 59 const std::pair<int, int> &local_device_cache_bounds, 60 EmbeddingCacheStatisticsInfo *statistics_info, const size_t &stream_id) 61 : actor_(actor), 62 device_context_(device_context), 63 local_embedding_slice_bounds_(local_embedding_slice_bounds), 64 local_device_cache_bounds_(local_device_cache_bounds), 65 statistics_info_(statistics_info), 66 stream_id_(stream_id) {} 67 68 virtual ~DeviceEmbeddingOperation() = default; 69 70 virtual bool Initialize(); 71 72 // Analyze the hit/miss info of the local host cache and device cache, and calculate the swapping and 73 // mapping information of the missing feature id that needs to be inserted into the cache. AnalyseCache(int * batch_ids,const size_t batch_ids_num,size_t data_step,const std::atomic_ulong * graph_running_step,bool * device_cache_need_wait_graph,bool * host_cache_need_wait_graph,int * indices,EmbeddingDeviceCache * embedding_device_cache,EmbeddingHostCache * embedding_host_cache,EmbeddingCacheStatisticsInfo * statistics_info)74 virtual bool AnalyseCache(int *batch_ids, const size_t batch_ids_num, size_t data_step, 75 const std::atomic_ulong *graph_running_step, bool *device_cache_need_wait_graph, 76 bool *host_cache_need_wait_graph, int *indices, 77 EmbeddingDeviceCache *embedding_device_cache, EmbeddingHostCache *embedding_host_cache, 78 EmbeddingCacheStatisticsInfo *statistics_info) { 79 return true; 80 } 81 82 // Pull missing embeddings on the device cache from the local host. 83 virtual bool PullCacheFromLocalHostToDevice(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) = 0; 84 85 // Push non-hotspot embeddings on the device cache to the local host cache. 86 virtual bool PushCacheFromDeviceToLocalHost(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) = 0; 87 88 // Get the id range of each server's embedding table slice. 89 virtual void GetRemoteEmbeddingSliceBound(size_t vocab_size, size_t server_num, 90 std::vector<std::pair<size_t, size_t>> *remote_embedding_slice_bounds) = 0; 91 92 // Async copy host memory to device. 93 static bool MemcpyHostToDeviceAsync(void *dst, const void *src, size_t size, const DeviceContext *device_context, 94 size_t stream_id); 95 96 // Async copy device memory to host. 97 static bool MemcpyDeviceToHostAsync(void *dst, const void *src, size_t size, const DeviceContext *device_context, 98 size_t stream_id); 99 100 // Get all modified ids. modified_ids()101 const mindspore::HashSet<int> &modified_ids() const { return modified_ids_; } 102 103 protected: 104 // Parse the hit and swap out to device cache information of the currently preprocessed id of the local host cache. 105 bool ParseHostDataHostToDevice(int id, size_t data_step, size_t *cur_graph_running_step, 106 const std::atomic_ulong *latest_graph_running_step, bool *host_cache_need_wait_graph, 107 EmbeddingHostCache *embedding_host_cache, 108 EmbeddingCacheStatisticsInfo *statistics_info); 109 110 // Parse the swap in information from device cache of the currently preprocessed id of the local host cache. 111 bool ParseHostDataDeviceToHost(size_t data_step, size_t *cur_graph_running_step, 112 const std::atomic_ulong *latest_graph_running_step, bool *host_cache_need_wait_graph, 113 EmbeddingDeviceCache *embedding_device_cache, EmbeddingHostCache *embedding_host_cache, 114 EmbeddingCacheStatisticsInfo *statistics_info); 115 116 // Build a CNode of embedding cache look up kernel, which is used to look up local device 117 // embedding cache. 118 virtual void BuildEmbeddingCacheLookupKernel() = 0; 119 120 // Build a CNode of embedding cache update kernel, which is used to update local 121 // device embedding cache. 122 virtual void BuildEmbeddingCacheUpdateKernel() = 0; 123 124 static ParameterPtr NewParameter(const KernelGraphPtr &graph, TypePtr type, const ShapeVector &shape); 125 126 static ValueNodePtr NewValueNode(int64_t value, const DeviceContext *device_context, size_t stream_id); 127 128 static bool InferOpShape(const CNodePtr &kernel, const std::vector<kernel::KernelTensor *> &input_kernel_tensors, 129 const std::vector<kernel::KernelTensor *> &output_kernel_tensors, 130 const std::vector<abstract::AbstractBasePtr> &output_kernel_tensors_for_iner); 131 132 // The actor which owns this operation. 133 EmbeddingCachePrefetchActor *actor_; 134 135 // The device interface. 136 device::DeviceContext *device_context_; 137 138 // Model parallelism is used between multiple workers, and local_embedding_slice_bounds_ records the feature range 139 // corresponding to the embedding table slice of the process. 140 std::pair<int, int> local_embedding_slice_bounds_; 141 142 // Model parallelism is used between multiple workers, and local_device_cache_bounds_ records the local device cache 143 // range corresponding to the embedding table slice of the process. 144 std::pair<int, int> local_device_cache_bounds_; 145 146 // The embedding cache look up kernel node(operator name: 'Gather' for dense mode and 'MapTensorGet' for sparse mode). 147 CNodePtr embedding_cache_lookup_node_{nullptr}; 148 149 // The embedding cache update kernel node(operator name: 'ScatterUpdate' for dense mode and 'MapTensorPut' for sparse 150 // mode). 151 CNodePtr embedding_cache_update_node_{nullptr}; 152 153 // The feature ids that have been initialized already. 154 mindspore::HashSet<int> initialized_ids_; 155 156 // The feature ids whose embedding vectors are modified(trained). 157 mindspore::HashSet<int> modified_ids_; 158 159 // Statistics on the cache hit rate of the host and device and the information used to update cache. 160 EmbeddingCacheStatisticsInfo *statistics_info_; 161 162 // Cache embedding cache ops kernel graphs. 163 std::vector<KernelGraphPtr> embedding_cache_graphs_; 164 165 // The device stream used to async memcpy operators and launch device kernels, such as embedding cache look up and 166 // update kernel. 167 size_t stream_id_{0}; 168 169 private: 170 DISABLE_COPY_AND_ASSIGN(DeviceEmbeddingOperation); 171 }; 172 } // namespace runtime 173 } // namespace mindspore 174 #endif // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_EMBEDDING_OPERATION_H_ 175