1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_SPARSE_EMBEDDING_OPERATION_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_SPARSE_EMBEDDING_OPERATION_H_ 19 20 #include <memory> 21 #include <vector> 22 #include <utility> 23 #include "runtime/graph_scheduler/actor/embedding_cache/device_embedding_operation.h" 24 #include "include/backend/device_address.h" 25 26 namespace mindspore { 27 namespace runtime { 28 using device::DeviceAddress; 29 30 class DeviceSparseEmbeddingOperation : public DeviceEmbeddingOperation { 31 public: DeviceSparseEmbeddingOperation(EmbeddingCachePrefetchActor * actor,device::DeviceContext * device_context,const std::pair<int,int> & local_embedding_slice_bounds,const std::pair<int,int> & local_device_cache_bounds,EmbeddingCacheStatisticsInfo * statistics_info,const size_t & stream_id)32 DeviceSparseEmbeddingOperation(EmbeddingCachePrefetchActor *actor, device::DeviceContext *device_context, 33 const std::pair<int, int> &local_embedding_slice_bounds, 34 const std::pair<int, int> &local_device_cache_bounds, 35 EmbeddingCacheStatisticsInfo *statistics_info, const size_t &stream_id) 36 : DeviceEmbeddingOperation(actor, device_context, local_embedding_slice_bounds, local_device_cache_bounds, 37 statistics_info, stream_id) {} 38 39 ~DeviceSparseEmbeddingOperation() override = default; 40 41 bool Initialize() override; 42 43 // Push non-hotspot embeddings on the device cache to the local host cache. 44 bool PushCacheFromDeviceToLocalHost(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) override; 45 46 // Pull missing embeddings on the device cache from the local host. 47 bool PullCacheFromLocalHostToDevice(const HashTableInfo &hash_info, const CacheAnalysis *cache_analysis) override; 48 49 // Get the id range of each server's embedding table slice. 50 void GetRemoteEmbeddingSliceBound(size_t vocab_size, size_t server_num, 51 std::vector<std::pair<size_t, size_t>> *remote_embedding_slice_bounds) override; 52 53 protected: 54 // Build a CNode of embedding cache look up kernel(operator name: 'MapTensorGet'), which is used to look up local 55 // device embedding cache. 56 void BuildEmbeddingCacheLookupKernel() override; 57 // Build a CNode of embedding cache update kernel(operator name: 'MapTensorPut'), which is used to update local 58 // device embedding cache. 59 void BuildEmbeddingCacheUpdateKernel() override; 60 61 private: 62 // Build a CNode of embedding cache erase kernel(operator name: 'MapTensorErase'), which is used to erase local 63 // device embedding cache. 64 void BuildEmbeddingCacheEraseKernel(); 65 66 static ParameterPtr NewMapParameter(const KernelGraphPtr &graph, TypeId key_type, TypeId value_type, 67 const ShapeVector &value_shape); 68 69 // Look up feature weights on Device Embedding Cache. LookupDeviceCache(const DeviceAddress * embed_device_address,void * ids,void * embedding_cache,size_t ids_num,size_t embedding_size,void * outputs)70 bool LookupDeviceCache(const DeviceAddress *embed_device_address, void *ids, void *embedding_cache, size_t ids_num, 71 size_t embedding_size, void *outputs) { 72 MS_LOG(EXCEPTION) << "Not implemented function."; 73 } 74 75 // Update feature weights on Device Embedding Cache. UpdateDeviceCache(void * ids,void * update_value,size_t indices_num,size_t embedding_size,void * embedding_cache,const DeviceAddress * embed_device_address)76 bool UpdateDeviceCache(void *ids, void *update_value, size_t indices_num, size_t embedding_size, 77 void *embedding_cache, const DeviceAddress *embed_device_address) { 78 MS_LOG(EXCEPTION) << "Not implemented function."; 79 } 80 81 // Erase feature embeddings on device embedding cache. EraseDeviceCache(void * ids,size_t ids_num,void * embedding_cache,const DeviceAddress * embed_device_address)82 bool EraseDeviceCache(void *ids, size_t ids_num, void *embedding_cache, const DeviceAddress *embed_device_address) { 83 MS_LOG(EXCEPTION) << "Not implemented function."; 84 } 85 86 // The embedding cache erase kernel node(operator name: 'MapTensorErase'). 87 CNodePtr embedding_cache_erase_node_{nullptr}; 88 89 DISABLE_COPY_AND_ASSIGN(DeviceSparseEmbeddingOperation); 90 }; 91 } // namespace runtime 92 } // namespace mindspore 93 #endif // MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_ACTOR_EMBEDDING_CACHE_DEVICE_SPARSE_EMBEDDING_OPERATION_H_ 94