/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_DENSE_EMBEDDING_OPERATION_H_
#define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_DENSE_EMBEDDING_OPERATION_H_

#include <memory>
#include <vector>
#include <utility>
#include "runtime/graph_scheduler/actor/embedding_cache/device_embedding_operation.h"

namespace mindspore {
namespace runtime {
class DeviceDenseEmbeddingOperation : public DeviceEmbeddingOperation {
 public:
  DeviceDenseEmbeddingOperation(EmbeddingCachePrefetchActor *actor, device::DeviceContext *device_context,
                                const std::pair<int, int> &local_embedding_slice_bounds,
                                const std::pair<int, int> &local_device_cache_bounds,
                                EmbeddingCacheStatisticsInfo *statistics_info, const size_t &stream_id)
      : DeviceEmbeddingOperation(actor, device_context, local_embedding_slice_bounds, local_device_cache_bounds,
                                 statistics_info, stream_id) {}

  ~DeviceDenseEmbeddingOperation() override = default;

  // Analyse the hit/miss state of the current batch of ids on the device cache and record the ids that need to be
  // swapped between the device cache and the local host cache.
  bool AnalyseCache(int *batch_ids, const size_t batch_ids_num, size_t data_step,
                    const std::atomic_ulong *graph_running_step, bool *device_cache_need_wait_graph,
                    bool *host_cache_need_wait_graph, int *indices, EmbeddingDeviceCache *embedding_device_cache,
                    EmbeddingHostCache *embedding_host_cache, EmbeddingCacheStatisticsInfo *statistics_info) override;

  // Push non-hotspot embeddings from the device cache to the local host cache.
  bool PushCacheFromDeviceToLocalHost(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) override;

  // Pull embeddings that are missing on the device cache from the local host cache.
  bool PullCacheFromLocalHostToDevice(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) override;

  // Get the id range of each server's embedding table slice.
  void GetRemoteEmbeddingSliceBound(size_t vocab_size, size_t server_num,
                                    std::vector<std::pair<size_t, size_t>> *remote_embedding_slice_bounds) override;

 protected:
  // Build a CNode of the embedding cache lookup kernel (operator name: 'EmbeddingLookup'), which is used to look up
  // the local device embedding cache.
  void BuildEmbeddingCacheLookupKernel() override;
  // Build a CNode of the embedding cache update kernel (operator name: 'ScatterUpdate'), which is used to update
  // the local device embedding cache.
  void BuildEmbeddingCacheUpdateKernel() override;

 private:
  // Look up feature weights in the device embedding cache:
  // 1. Update the shape of the parameter node.
  // 2. Infer the shape for the embedding cache lookup kernel (operator name: 'EmbeddingLookup').
  // 3. Launch the embedding cache lookup kernel.
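  // The lookup is conceptually a row gather over a row-major [cache_size, embedding_size] table; for a float
  // table it behaves roughly like the following host-side sketch (illustrative only, the real work is done by
  // the launched device kernel):
  //   const auto *table = static_cast<const float *>(embedding_cache);
  //   const auto *ids = static_cast<const int *>(indices);
  //   auto *out = static_cast<float *>(outputs);
  //   for (size_t i = 0; i < indices_num; ++i) {
  //     std::copy_n(table + static_cast<size_t>(ids[i]) * embedding_size, embedding_size, out + i * embedding_size);
  //   }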
  bool LookupDeviceCache(void *indices, void *embedding_cache, size_t indices_num, size_t cache_size,
                         size_t embedding_size, void *outputs);

  // Update feature weights in the device embedding cache:
  // 1. Update the shape of the parameter node.
  // 2. Infer the shape for the embedding cache update kernel (operator name: 'ScatterUpdate').
  // 3. Launch the embedding cache update kernel.
  bool UpdateDeviceCache(void *indices, void *update_value, size_t indices_num, size_t cache_size,
                         size_t embedding_size, void *embedding_cache);

  // Batch-preprocess the ids of the current batch: mark the ids that already hit the device cache and the ids that
  // fall outside the range of the embedding table slice owned by this process.
  bool CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *out_range,
                               size_t data_step);

  // Thread execution function for method 'CheckCacheHitOrOutRange'.
  bool CheckCacheHitOrOutRangeFunc(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *out_range,
                                   size_t *hash_hit_count, size_t data_step);

  // Parse the hit and swap information of the id currently being preprocessed in the device cache.
  bool ParseDeviceData(int id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index,
                       size_t data_step, size_t *cur_graph_running_step,
                       const std::atomic_ulong *latest_graph_running_step, bool *device_cache_need_wait_graph,
                       EmbeddingDeviceCache *embedding_device_cache, EmbeddingCacheStatisticsInfo *statistics_info);

  DISABLE_COPY_AND_ASSIGN(DeviceDenseEmbeddingOperation);
};
}  // namespace runtime
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_DENSE_EMBEDDING_OPERATION_H_