• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_DENSE_EMBEDDING_OPERATION_H_
18 #define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_DENSE_EMBEDDING_OPERATION_H_
19 
20 #include <memory>
21 #include <vector>
22 #include <utility>
23 #include "runtime/graph_scheduler/actor/embedding_cache/device_embedding_operation.h"
24 
25 namespace mindspore {
26 namespace runtime {
27 class DeviceDenseEmbeddingOperation : public DeviceEmbeddingOperation {
28  public:
DeviceDenseEmbeddingOperation(EmbeddingCachePrefetchActor * actor,device::DeviceContext * device_context,const std::pair<int,int> & local_embedding_slice_bounds,const std::pair<int,int> & local_device_cache_bounds,EmbeddingCacheStatisticsInfo * statistics_info,const size_t & stream_id)29   DeviceDenseEmbeddingOperation(EmbeddingCachePrefetchActor *actor, device::DeviceContext *device_context,
30                                 const std::pair<int, int> &local_embedding_slice_bounds,
31                                 const std::pair<int, int> &local_device_cache_bounds,
32                                 EmbeddingCacheStatisticsInfo *statistics_info, const size_t &stream_id)
33       : DeviceEmbeddingOperation(actor, device_context, local_embedding_slice_bounds, local_device_cache_bounds,
34                                  statistics_info, stream_id) {}
35 
36   ~DeviceDenseEmbeddingOperation() override = default;
37 
38   bool AnalyseCache(int *batch_ids, const size_t batch_ids_num, size_t data_step,
39                     const std::atomic_ulong *graph_running_step, bool *device_cache_need_wait_graph,
40                     bool *host_cache_need_wait_graph, int *indices, EmbeddingDeviceCache *embedding_device_cache,
41                     EmbeddingHostCache *embedding_host_cache, EmbeddingCacheStatisticsInfo *statistics_info) override;
42 
43   // Push non-hotspot embeddings on the device cache to the local host cache.
44   bool PushCacheFromDeviceToLocalHost(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) override;
45 
46   // Pull missing embeddings on the device cache from the local host.
47   bool PullCacheFromLocalHostToDevice(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) override;
48 
49   // Get the id range of each server's embedding table slice.
50   void GetRemoteEmbeddingSliceBound(size_t vocab_size, size_t server_num,
51                                     std::vector<std::pair<size_t, size_t>> *remote_embedding_slice_bounds) override;
52 
53  protected:
54   // Build a CNode of embedding cache look up kernel(operator name: 'EmbeddingLookup'), which is used to look up local
55   // device embedding cache.
56   void BuildEmbeddingCacheLookupKernel() override;
57   // Build a CNode of embedding cache update kernel(operator name: 'ScatterUpdate'), which is used to update local
58   // device embedding cache.
59   void BuildEmbeddingCacheUpdateKernel() override;
60 
61  private:
62   // Look up feature weights on Device Embedding Cache:
63   // 1. Update the shape of parameter node.
64   // 2. Infer shape for embedding cache look up kernel(operator name: 'EmbeddingLookup').
65   // 3. Launch embedding cache look up kernel.
66   bool LookupDeviceCache(void *indices, void *embedding_cache, size_t indices_num, size_t cache_size,
67                          size_t embedding_size, void *outputs);
68 
69   // Update feature weights on Device Embedding Cache:
70   // 1. Update the shape of parameter node.
71   // 2. Infer shape for embedding cache update kernel(operator name: 'ScatterUpdate').
72   // 3. Launch embedding cache update kernel.
73   bool UpdateDeviceCache(void *indices, void *update_value, size_t indices_num, size_t cache_size,
74                          size_t embedding_size, void *embedding_cache);
75 
76   // Batch preprocess the current batch ids information of cache hitting or exceeding the range of the embedding table
77   // slice corresponding to the process.
78   bool CheckCacheHitOrOutRange(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *out_range,
79                                size_t data_step);
80 
81   // Thread execution function of method 'CheckCacheHitOrOutRange'.
82   bool CheckCacheHitOrOutRangeFunc(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *out_range,
83                                    size_t *hash_hit_count, size_t data_step);
84 
85   // Parse the hit and swap information of the currently preprocessed id in the device cache.
86   bool ParseDeviceData(int id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index,
87                        size_t data_step, size_t *cur_graph_running_step,
88                        const std::atomic_ulong *latest_graph_running_step, bool *device_cache_need_wait_graph,
89                        EmbeddingDeviceCache *embedding_device_cache, EmbeddingCacheStatisticsInfo *statistics_info);
90 
91   DISABLE_COPY_AND_ASSIGN(DeviceDenseEmbeddingOperation);
92 };
93 }  // namespace runtime
94 }  // namespace mindspore
95 #endif  // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_DENSE_EMBEDDING_OPERATION_H_
96