• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_SPARSE_EMBEDDING_OPERATION_H_
18 #define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_SPARSE_EMBEDDING_OPERATION_H_
19 
20 #include <memory>
21 #include <vector>
22 #include <utility>
23 #include "runtime/graph_scheduler/actor/embedding_cache/device_embedding_operation.h"
24 #include "include/backend/device_address.h"
25 
26 namespace mindspore {
27 namespace runtime {
28 using device::DeviceAddress;
29 
30 class DeviceSparseEmbeddingOperation : public DeviceEmbeddingOperation {
31  public:
DeviceSparseEmbeddingOperation(EmbeddingCachePrefetchActor * actor,device::DeviceContext * device_context,const std::pair<int,int> & local_embedding_slice_bounds,const std::pair<int,int> & local_device_cache_bounds,EmbeddingCacheStatisticsInfo * statistics_info,const size_t & stream_id)32   DeviceSparseEmbeddingOperation(EmbeddingCachePrefetchActor *actor, device::DeviceContext *device_context,
33                                  const std::pair<int, int> &local_embedding_slice_bounds,
34                                  const std::pair<int, int> &local_device_cache_bounds,
35                                  EmbeddingCacheStatisticsInfo *statistics_info, const size_t &stream_id)
36       : DeviceEmbeddingOperation(actor, device_context, local_embedding_slice_bounds, local_device_cache_bounds,
37                                  statistics_info, stream_id) {}
38 
39   ~DeviceSparseEmbeddingOperation() override = default;
40 
41   bool Initialize() override;
42 
43   // Push non-hotspot embeddings on the device cache to the local host cache.
44   bool PushCacheFromDeviceToLocalHost(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) override;
45 
46   // Pull missing embeddings on the device cache from the local host.
47   bool PullCacheFromLocalHostToDevice(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) override;
48 
49   // Get the id range of each server's embedding table slice.
50   void GetRemoteEmbeddingSliceBound(size_t vocab_size, size_t server_num,
51                                     std::vector<std::pair<size_t, size_t>> *remote_embedding_slice_bounds) override;
52 
53  protected:
54   // Build a CNode of embedding cache look up kernel(operator name: 'MapTensorGet'), which is used to look up local
55   // device embedding cache.
56   void BuildEmbeddingCacheLookupKernel() override;
57   // Build a CNode of embedding cache update kernel(operator name: 'MapTensorPut'), which is used to update local
58   // device embedding cache.
59   void BuildEmbeddingCacheUpdateKernel() override;
60 
61  private:
62   // Build a CNode of embedding cache erase kernel(operator name: 'MapTensorErase'), which is used to erase local
63   // device embedding cache.
64   void BuildEmbeddingCacheEraseKernel();
65 
66   static ParameterPtr NewMapParameter(const KernelGraphPtr &graph, TypeId key_type, TypeId value_type,
67                                       const ShapeVector &value_shape);
68 
69   // Look up feature weights on Device Embedding Cache.
LookupDeviceCache(const DeviceAddress * embed_device_address,void * ids,void * embedding_cache,size_t ids_num,size_t embedding_size,void * outputs)70   bool LookupDeviceCache(const DeviceAddress *embed_device_address, void *ids, void *embedding_cache, size_t ids_num,
71                          size_t embedding_size, void *outputs) {
72     MS_LOG(EXCEPTION) << "Not implemented function.";
73   }
74 
75   // Update feature weights on Device Embedding Cache.
UpdateDeviceCache(void * ids,void * update_value,size_t indices_num,size_t embedding_size,void * embedding_cache,const DeviceAddress * embed_device_address)76   bool UpdateDeviceCache(void *ids, void *update_value, size_t indices_num, size_t embedding_size,
77                          void *embedding_cache, const DeviceAddress *embed_device_address) {
78     MS_LOG(EXCEPTION) << "Not implemented function.";
79   }
80 
81   // Erase feature embeddings on device embedding cache.
EraseDeviceCache(void * ids,size_t ids_num,void * embedding_cache,const DeviceAddress * embed_device_address)82   bool EraseDeviceCache(void *ids, size_t ids_num, void *embedding_cache, const DeviceAddress *embed_device_address) {
83     MS_LOG(EXCEPTION) << "Not implemented function.";
84   }
85 
86   // The embedding cache erase kernel node(operator name: 'MapTensorErase').
87   CNodePtr embedding_cache_erase_node_{nullptr};
88 
89   DISABLE_COPY_AND_ASSIGN(DeviceSparseEmbeddingOperation);
90 };
91 }  // namespace runtime
92 }  // namespace mindspore
93 #endif  // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_SPARSE_EMBEDDING_OPERATION_H_
94