/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <string>
#include <limits>
#include "runtime/graph_scheduler/actor/embedding_cache/device_sparse_embedding_operation.h"
#include "ops/sparse_op_name.h"
#include "ir/map_tensor.h"

namespace mindspore {
namespace runtime {
bool DeviceSparseEmbeddingOperation::Initialize() {
  RETURN_IF_FALSE_WITH_LOG(DeviceEmbeddingOperation::Initialize(), "Initialize device embedding operation failed.");
  BuildEmbeddingCacheEraseKernel();
  return true;
}

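// Swap-out path: stage the evicted ids on device, look up their embedding values, erase those entries from the
// device hash table, then copy the values back to the host and insert them into the local host cache.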
bool DeviceSparseEmbeddingOperation::PushCacheFromDeviceToLocalHost(const HashTableInfo &hash_info,
                                                                    const CacheAnalysis *cache_analysis) {
  MS_ERROR_IF_NULL(cache_analysis);
  auto statistics_info = cache_analysis->statistics_info_;
  auto embedding_device_cache = cache_analysis->embedding_device_cache_;
  auto embedding_host_cache = cache_analysis->embedding_host_cache_;
  MS_ERROR_IF_NULL(statistics_info);
  MS_ERROR_IF_NULL(embedding_device_cache);
  MS_ERROR_IF_NULL(embedding_host_cache);

  auto swap_indices_size = statistics_info->device_to_host_size_;
  if (swap_indices_size == 0) {
    return true;
  }

  auto device_cache_device_to_host_ids = embedding_device_cache->device_to_host_ids.get();
  auto host_cache_device_to_host_index = embedding_host_cache->device_to_host_index.get();
  MS_ERROR_IF_NULL(device_cache_device_to_host_ids);
  MS_ERROR_IF_NULL(host_cache_device_to_host_index);
  auto hash_table_addr = reinterpret_cast<float *>(hash_info.address.addr);
  auto host_hash_table_addr = hash_info.host_address;
  auto embedding_size = hash_info.embedding_size;
  auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);

  // Copy the original ids to the temporary index buffer.
  int *tmp_swap_ids = embedding_cache_table_manager.hash_swap_index_addr_;
  RETURN_IF_FALSE_WITH_LOG(MemcpyHostToDeviceAsync(tmp_swap_ids, device_cache_device_to_host_ids,
                                                   swap_indices_size * sizeof(int), device_context_, stream_id_),
                           "Memcpy host to device asynchronously failed.");

  RETURN_IF_FALSE_WITH_LOG(LookupDeviceCache(hash_info.device_address, tmp_swap_ids, hash_table_addr, swap_indices_size,
                                             embedding_size, embedding_cache_table_manager.hash_swap_value_addr_),
                           "Lookup device cache failed.");

  // Erase the swapped-out ids from the device hash table.
  RETURN_IF_FALSE_WITH_LOG(EraseDeviceCache(tmp_swap_ids, swap_indices_size, hash_table_addr, hash_info.device_address),
                           "Erase device cache failed.");

  RETURN_IF_FALSE_WITH_LOG(
    MemcpyDeviceToHostAsync(swap_out_data.get(), embedding_cache_table_manager.hash_swap_value_addr_,
                            swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
    "Memcpy device to host asynchronously failed.");

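  // The device-to-host copy above is asynchronous, so synchronize the stream before the host reads swap_out_data.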
  MS_ERROR_IF_NULL(device_context_);
  MS_ERROR_IF_NULL(device_context_->device_res_manager_);
  RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
  RETURN_IF_FALSE_WITH_LOG(
    actor_->InsertLocalHostCache(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
                                 swap_out_data.get(), host_hash_table_addr),
    "Insert local host cache failed.");
  return true;
}

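// Swap-in path: gather the values for the incoming ids from the local host cache, copy both values and ids to the
// device, then write them into the device hash table.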
bool DeviceSparseEmbeddingOperation::PullCacheFromLocalHostToDevice(const HashTableInfo &hash_info,
                                                                    const CacheAnalysis *cache_analysis) {
  MS_ERROR_IF_NULL(cache_analysis);
  auto statistics_info = cache_analysis->statistics_info_;
  auto embedding_device_cache = cache_analysis->embedding_device_cache_;
  auto embedding_host_cache = cache_analysis->embedding_host_cache_;
  MS_ERROR_IF_NULL(statistics_info);
  MS_ERROR_IF_NULL(embedding_device_cache);
  MS_ERROR_IF_NULL(embedding_host_cache);

  auto swap_indices_size = statistics_info->host_to_device_size_;
  if (swap_indices_size == 0) {
    return true;
  }

  auto host_cache_host_to_device_index = embedding_host_cache->host_to_device_index.get();
  auto device_cache_host_to_device_ids = embedding_device_cache->host_to_device_ids.get();
  MS_ERROR_IF_NULL(host_cache_host_to_device_index);
  MS_ERROR_IF_NULL(device_cache_host_to_device_ids);

  auto embedding_size = hash_info.embedding_size;
  MS_ERROR_IF_NULL(hash_info.address.addr);
  auto hash_table_addr = reinterpret_cast<float *>(hash_info.address.addr);
  MS_ERROR_IF_NULL(hash_info.host_address);
  auto host_hash_table_addr = hash_info.host_address;
  auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
  RETURN_IF_FALSE_WITH_LOG(actor_->LookupLocalHostCache(embedding_size, swap_indices_size, host_hash_table_addr,
                                                        host_cache_host_to_device_index, swap_out_data.get()),
                           "Lookup local host cache failed.");

  RETURN_IF_FALSE_WITH_LOG(
    MemcpyHostToDeviceAsync(embedding_cache_table_manager.hash_swap_value_addr_, swap_out_data.get(),
                            swap_indices_size * embedding_size * sizeof(float), device_context_, stream_id_),
    "Memcpy host to device asynchronously failed.");
  // Copy the original ids to the temporary index buffer.
  RETURN_IF_FALSE_WITH_LOG(
    MemcpyHostToDeviceAsync(embedding_cache_table_manager.hash_swap_index_addr_, device_cache_host_to_device_ids,
                            swap_indices_size * sizeof(int), device_context_, stream_id_),
    "Memcpy host to device asynchronously failed.");

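  // Write the swapped-in ids and values into the device hash table; UpdateDeviceCache is expected to launch the
  // MapTensorPut kernel built by BuildEmbeddingCacheUpdateKernel.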
  RETURN_IF_FALSE_WITH_LOG(UpdateDeviceCache(embedding_cache_table_manager.hash_swap_index_addr_,
                                             embedding_cache_table_manager.hash_swap_value_addr_, swap_indices_size,
                                             embedding_size, hash_table_addr, hash_info.device_address),
                           "Update device embedding cache failed.");
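  // Synchronize so the asynchronous copies and the update complete before the temporary host buffer goes out of scope.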
  MS_ERROR_IF_NULL(device_context_);
  MS_ERROR_IF_NULL(device_context_->device_res_manager_);
  RETURN_IF_FALSE_WITH_LOG(device_context_->device_res_manager_->SyncStream(stream_id_), "Synchronize stream failed.");
  return true;
}

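// With a single server there is no need to shard ids across servers; the whole range from 0 to INT32_MAX maps to
// one slice.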
void DeviceSparseEmbeddingOperation::GetRemoteEmbeddingSliceBound(
  size_t vocab_size, size_t server_num, std::vector<std::pair<size_t, size_t>> *remote_embedding_slice_bounds) {
  if (server_num != 1) {
    MS_LOG(EXCEPTION)
      << "Sparse mode does not support multiple servers currently, so the server number should be 1, but got: "
      << server_num;
  }

  MS_EXCEPTION_IF_NULL(remote_embedding_slice_bounds);
  // Sparse mode does not support multiple servers currently, so ids do not need to be split and the id range is
  // simply 0 to INT32_MAX.
  (void)remote_embedding_slice_bounds->emplace_back(0, INT32_MAX);
}

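// Each Build*Kernel helper below wraps one map-tensor operator in a minimal single-op KernelGraph. The graph id is
// set to UINT32_MAX, presumably as a sentinel outside the normal graph-id space; the call is written as
// (std::numeric_limits<uint32_t>::max)() to sidestep any max() macro (e.g. from windows.h).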
void DeviceSparseEmbeddingOperation::BuildEmbeddingCacheLookupKernel() {
  auto graph = std::make_shared<KernelGraph>();
  MS_EXCEPTION_IF_NULL(graph);
  graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
  embedding_cache_graphs_.push_back(graph);

  // 1. Create parameter nodes which are the inputs of the embedding cache lookup kernel (operator name: 'MapTensorGet').
  ParameterPtr input_param = NewMapParameter(graph, kNumberTypeInt32, kNumberTypeFloat32, kOneDimensionalShape);
  ParameterPtr input_ids = NewParameter(graph, kInt32, kOneDimensionalShape);

  // 2. Create a CNode for operator MapTensorGet.
  PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kMapTensorGetOpName);
  emb_lookup_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
  emb_lookup_primitive->set_attr(kAttrOutputIsDynamicShape, MakeValue(true));
  emb_lookup_primitive->set_attr(kAttrInsertDefaultValue, MakeValue(false));

  std::vector<AnfNodePtr> emb_lookup_input_nodes{mindspore::NewValueNode(emb_lookup_primitive), input_param, input_ids};
  embedding_cache_lookup_node_ = graph->NewCNode(emb_lookup_input_nodes);
  MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node_);
  auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, kTwoDimensionalShape);
  embedding_cache_lookup_node_->set_abstract(abstract);

  // 3. Kernel build process.
  MS_EXCEPTION_IF_NULL(device_context_);
  MS_EXCEPTION_IF_NULL(device_context_->GetKernelExecutor(false));
  device_context_->GetKernelExecutor(false)->CreateKernel({embedding_cache_lookup_node_});
  AnfAlgo::SetStreamId(stream_id_, embedding_cache_lookup_node_.get());
}

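// MapTensorPut writes (id, value) pairs into the map parameter. Unlike the lookup kernel, only the input is flagged
// as dynamic, and the node's abstract is simply the map parameter's own.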
void DeviceSparseEmbeddingOperation::BuildEmbeddingCacheUpdateKernel() {
  auto graph = std::make_shared<KernelGraph>();
  MS_EXCEPTION_IF_NULL(graph);
  graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
  embedding_cache_graphs_.push_back(graph);

  // 1. Create parameter nodes which are the inputs of the embedding cache update kernel (operator name: 'MapTensorPut').
  ParameterPtr input_param = NewMapParameter(graph, kNumberTypeInt32, kNumberTypeFloat32, kOneDimensionalShape);
  ParameterPtr input_ids = NewParameter(graph, kInt32, kOneDimensionalShape);
  ParameterPtr update_values = NewParameter(graph, kFloat32, kTwoDimensionalShape);

  // 2. Create a CNode for operator MapTensorPut.
  PrimitivePtr embedding_cache_update_primitive = std::make_shared<Primitive>(kMapTensorPutOpName);
  embedding_cache_update_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));

  std::vector<AnfNodePtr> embedding_cache_update_input_nodes{mindspore::NewValueNode(embedding_cache_update_primitive),
                                                             input_param, input_ids, update_values};
  embedding_cache_update_node_ = graph->NewCNode(embedding_cache_update_input_nodes);
  MS_EXCEPTION_IF_NULL(embedding_cache_update_node_);
  embedding_cache_update_node_->set_abstract(input_param->abstract());

  // 3. Kernel build process.
  MS_EXCEPTION_IF_NULL(device_context_);
  MS_EXCEPTION_IF_NULL(device_context_->GetKernelExecutor(false));
  device_context_->GetKernelExecutor(false)->CreateKernel({embedding_cache_update_node_});
  AnfAlgo::SetStreamId(stream_id_, embedding_cache_update_node_.get());
}

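// MapTensorErase removes the given ids from the map parameter; EraseDeviceCache, used by the eviction step in
// PushCacheFromDeviceToLocalHost, is expected to launch this kernel.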
void DeviceSparseEmbeddingOperation::BuildEmbeddingCacheEraseKernel() {
  auto graph = std::make_shared<KernelGraph>();
  MS_EXCEPTION_IF_NULL(graph);
  graph->set_graph_id((std::numeric_limits<uint32_t>::max)());
  embedding_cache_graphs_.push_back(graph);

  // 1. Create parameter nodes which are the inputs of the embedding cache erase kernel (operator name: 'MapTensorErase').
  ParameterPtr input_param = NewMapParameter(graph, kNumberTypeInt32, kNumberTypeFloat32, kOneDimensionalShape);
  ParameterPtr input_ids = NewParameter(graph, kInt32, kOneDimensionalShape);

  // 2. Create a CNode for operator MapTensorErase.
  PrimitivePtr embedding_cache_erase_primitive = std::make_shared<Primitive>(kMapTensorEraseOpName);
  embedding_cache_erase_primitive->set_attr(kAttrInputIsDynamicShape, MakeValue(true));
  embedding_cache_erase_primitive->set_attr(kAttrOutputIsDynamicShape, MakeValue(true));

  std::vector<AnfNodePtr> embedding_cache_erase_input_nodes{mindspore::NewValueNode(embedding_cache_erase_primitive),
                                                            input_param, input_ids};
  embedding_cache_erase_node_ = graph->NewCNode(embedding_cache_erase_input_nodes);
  MS_EXCEPTION_IF_NULL(embedding_cache_erase_node_);
  embedding_cache_erase_node_->set_abstract(input_param->abstract());

  // 3. Kernel build process.
  MS_EXCEPTION_IF_NULL(device_context_);
  MS_EXCEPTION_IF_NULL(device_context_->GetKernelExecutor(false));
  device_context_->GetKernelExecutor(false)->CreateKernel({embedding_cache_erase_node_});
  AnfAlgo::SetStreamId(stream_id_, embedding_cache_erase_node_.get());
}

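// Creates a graph parameter whose abstract is an AbstractMapTensor, selects the default-format map-tensor kernel
// build info for it, and registers it as a graph input.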
ParameterPtr DeviceSparseEmbeddingOperation::NewMapParameter(const KernelGraphPtr &graph, TypeId key_type,
                                                             TypeId value_type, const ShapeVector &value_shape) {
  MS_EXCEPTION_IF_NULL(graph);

  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  auto map_tensor = std::make_shared<tensor::MapTensor>(key_type, value_type, value_shape, nullptr);
  auto abstract = std::make_shared<abstract::AbstractMapTensor>(map_tensor);
  param->set_abstract(abstract);

  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  std::vector<std::string> formats = {kOpFormat_DEFAULT};
  std::vector<TypeId> types = {kObjectTypeMapTensorType};
  kernel_build_info_builder->SetOutputsFormat(formats);
  kernel_build_info_builder->SetOutputsDeviceType(types);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());

  auto mutable_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(mutable_inputs);
  mutable_inputs->push_back(param);

  return param;
}
}  // namespace runtime
}  // namespace mindspore