/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #include "frontend/parallel/cache_embedding/cache_embedding.h"
18
19 #include <random>
20 #include <vector>
21 #include <list>
22 #include <queue>
23 #include <utility>
24 #include <memory>
25 #include <string>
26 #include <algorithm>
27
28 #include "ops/sequence_ops.h"
29 #include "ops/other_ops.h"
30 #include "ops/nn_optimizer_ops.h"
31 #include "ops/nn_ops.h"
32 #include "ops/array_ops.h"
33 #include "ops/framework_ops.h"
34 #include "utils/hash_map.h"
35 #include "utils/hash_set.h"
36 #include "include/backend/optimizer/helper.h"
37 #include "frontend/optimizer/optimizer.h"
38 #include "ir/func_graph.h"
39 #include "utils/cache_embedding_hashmap_struct.h"
40 namespace mindspore {
41 namespace parallel {
42 using ParamMap = mindspore::HashMap<ParameterPtr, ParameterPtr>;
43 using ParamSet = mindspore::HashSet<ParameterPtr>;
44 using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>;
45 using AnfMap = mindspore::HashMap<AnfNodePtr, AnfNodePtr>;
46 using AnfSet = mindspore::HashSet<AnfNodePtr>;
47
AddCacheParameters(const FuncGraphPtr & graph,const ParamSet & parameter_cache_enable_set)48 ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet ¶meter_cache_enable_set) {
49 ParamMap cache_host_params_map;
50 for (auto ¶m : parameter_cache_enable_set) {
51 auto param_info = param->param_info();
52 if (param_info && param_info->cache_enable()) {
53 auto data_type = param->Type();
54 auto data_element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
55 auto type_id = data_element_type->type_id();
56 auto cache_shape = param_info->cache_shape();
57 auto ori_param_name = param->name();
58 auto new_tensor = std::make_shared<tensor::Tensor>(type_id, cache_shape);
59 ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
60 auto cache_name = ori_param_name + "_cache";
61 new_param_info->set_name(cache_name);
62 new_tensor->set_param_info(new_param_info);
63 auto cache_param = graph->AddFvParameter(cache_name, new_tensor);
64 cache_host_params_map[cache_param] = param;
65 }
66 }
67 return cache_host_params_map;
68 }
69
CheckHostCacheParamSize(const ParamSet & parameter_cache_enable_set)70 bool CheckHostCacheParamSize(const ParamSet ¶meter_cache_enable_set) {
71 int64_t host_size = 0;
72 int64_t cache_size = 0;
73 for (auto &host_param : parameter_cache_enable_set) {
74 auto tmp_host_size = host_param->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
75 auto host_param_info = host_param->param_info();
76 auto cache_shape = host_param_info->cache_shape();
77 if (cache_shape.empty()) {
78 MS_LOG(EXCEPTION) << "The value of cache_shape is empty.";
79 }
80 auto tmp_cache_size = cache_shape[0];
81 if ((host_size != 0 && tmp_host_size != host_size) || (cache_size != 0 && tmp_cache_size != cache_size)) {
82 MS_LOG(EXCEPTION)
83 << "If EmbeddingLookup are cache enable, vocab_size and vocab_cache_size of different cells must be the same.";
84 }
85 cache_size = tmp_cache_size;
86 host_size = tmp_host_size;
87 }
88 if (cache_size > host_size) {
89 MS_LOG(WARNING) << "vocab_cache_size > vocab_size, there is no need use cache.";
90 return false;
91 }
92 return true;
93 }
94
ReplaceCacheParams(const FuncGraphPtr & graph,const ParamMap & map)95 void ReplaceCacheParams(const FuncGraphPtr &graph, const ParamMap &map) {
96 auto manager = graph->manager();
97 MS_EXCEPTION_IF_NULL(manager);
98 for (auto &ele : map) {
99 if (!manager->Replace(ele.second, ele.first)) {
100 MS_LOG(EXCEPTION) << "host param: " << ele.second->name() << ", replace node failed.";
101 }
102 }
103 }
104
MapKeysToSet(const ParamMap & map)105 ParamSet MapKeysToSet(const ParamMap &map) {
106 ParamSet set;
107 for (auto &ele : map) {
108 set.insert(ele.first);
109 }
110 return set;
111 }
112
FindParamCacheEnable(const FuncGraphPtr & graph)113 ParamSet FindParamCacheEnable(const FuncGraphPtr &graph) {
114 ParamSet parameter_cache_enable_set;
115 auto parameters = graph->parameters();
116 auto params_size = parameters.size();
117 for (size_t i = 0; i < params_size; ++i) {
118 auto param = parameters[i]->cast<ParameterPtr>();
119 auto param_info = param->param_info();
120 if (param_info && param_info->cache_enable()) {
121 parameter_cache_enable_set.insert(param);
122 }
123 }
124 return parameter_cache_enable_set;
125 }
126
// Return the Unique CNodes carrying a true "cache_enable" attribute.
// At most one such node is supported; finding more than one raises.
CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) {
  CNodePtrList unique_cache_enable;
  for (const auto &cnode : cnodes) {
    if (!IsPrimitiveCNode(cnode, prim::kPrimUnique)) {
      continue;
    }
    auto unique_prim = GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(unique_prim);
    auto attr_value = unique_prim->GetAttr(kAttrCacheEnable);
    if (attr_value != nullptr && GetValue<bool>(attr_value)) {
      unique_cache_enable.emplace_back(cnode);
    }
  }
  if (unique_cache_enable.size() > 1) {
    MS_LOG(EXCEPTION) << "Support only one of Unique op cache enable, but got " << unique_cache_enable.size();
  }
  return unique_cache_enable;
}
146
147 template <typename T>
MemCopyFromHostToCache(void * hashmap_addr,void * host_addr,void * cache_addr,size_t host_max,size_t cache_max,size_t hashmap_size,size_t col_size)148 void MemCopyFromHostToCache(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
149 size_t hashmap_size, size_t col_size) {
150 auto host_data = static_cast<char *>(host_addr);
151 auto cache_data = static_cast<char *>(cache_addr);
152 auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
153 // default param type float
154 const size_t param_type_size = 4;
155 size_t single_col_bytes = param_type_size * col_size;
156 for (size_t i = 0; i < hashmap_size; ++i) {
157 if (!hashmap_data[i].IsEmpty()) {
158 size_t host_offset = single_col_bytes * static_cast<size_t>(hashmap_data[i].key_);
159 size_t cache_offset = single_col_bytes * static_cast<size_t>(hashmap_data[i].value_);
160 if (host_offset + single_col_bytes <= host_max) {
161 auto ret =
162 memcpy_s(cache_data + cache_offset, cache_max - cache_offset, host_data + host_offset, single_col_bytes);
163 if (ret != EOK) {
164 MS_LOG(EXCEPTION) << "Memcpy failed.";
165 }
166 }
167 }
168 }
169 MS_LOG(INFO) << "Memcpy from cache to host success!";
170 }
171
BindAndInitCacheTensor(const ParamMap & param_pair_list,const ParameterPtr & hashmap)172 void BindAndInitCacheTensor(const ParamMap ¶m_pair_list, const ParameterPtr &hashmap) {
173 auto hashmap_tensor_value = hashmap->default_param();
174 auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
175 auto hashmap_size = hashmap_tensor->shape_c()[0];
176 auto hashmap_data_type = hashmap_tensor->data_type();
177 for (auto &ele : param_pair_list) {
178 auto host_tensor_value = ele.second->default_param();
179 auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
180 auto cache_tensor_value = ele.first->default_param();
181 auto cache_tensor = cache_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
182
183 // bind host, cache, hashmap
184 host_tensor->set_cache_enable(true);
185 host_tensor->set_hashmap_tensor_ptr(hashmap_tensor);
186 host_tensor->set_cache_tensor_ptr(cache_tensor);
187
188 // init cache tensor data
189 auto host_shape = host_tensor->shape_c();
190 auto cache_shape = cache_tensor->shape_c();
191 if (host_shape.size() != 2 && cache_shape.size() != 2 && host_shape[1] != cache_shape[1]) {
192 MS_LOG(EXCEPTION) << "Got host shape and cache shape invalid."
193 << "host shape:" << host_shape << ", cache shape:" << cache_shape;
194 }
195 auto host_data_max_size = static_cast<size_t>(host_tensor->Size());
196 auto cache_data_max_size = static_cast<size_t>(cache_tensor->Size());
197 if (hashmap_data_type == TypeId::kNumberTypeInt32) {
198 MemCopyFromHostToCache<int32_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(),
199 host_data_max_size, cache_data_max_size, LongToSize(hashmap_size),
200 LongToSize(host_shape[1]));
201 } else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
202 MemCopyFromHostToCache<int64_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(),
203 host_data_max_size, cache_data_max_size, LongToSize(hashmap_size),
204 LongToSize(host_shape[1]));
205 } else {
206 MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64.";
207 }
208 }
209 }
210
211 template <typename T>
InitHashMapData(void * data,const int64_t host_size,const int64_t cache_size,const size_t hashmap_size,const size_t byte_size)212 void InitHashMapData(void *data, const int64_t host_size, const int64_t cache_size, const size_t hashmap_size,
213 const size_t byte_size) {
214 MS_LOG(INFO) << "Start init hashmap data.";
215 MS_EXCEPTION_IF_NULL(data);
216 HashmapEntry<T> *hashmap_data = static_cast<HashmapEntry<T> *>(data);
217 MS_EXCEPTION_IF_NULL(hashmap_data);
218 int ret = memset_s(hashmap_data, byte_size, 0, byte_size);
219 if (ret != EOK) {
220 MS_LOG(EXCEPTION) << "Memset failed.";
221 }
222 std::vector<T> host_range;
223 host_range.reserve(static_cast<T>(host_size));
224 for (int64_t i = 0; i < host_size; ++i) {
225 host_range.emplace_back(static_cast<T>(i));
226 }
227 #if defined(__APPLE__) || defined(_MSC_VER)
228 std::random_device rd;
229 std::mt19937 rng(rd());
230 std::shuffle(host_range.begin(), host_range.end(), rng);
231 #else
232 std::random_shuffle(host_range.begin(), host_range.end());
233 #endif
234 size_t size = static_cast<size_t>(cache_size);
235 size_t hashmap_count = 0;
236 for (size_t i = 0; i < size; ++i) {
237 auto random_key = host_range[i];
238 auto entry = HashFunc(random_key, hashmap_size);
239 size_t count = 1;
240 while (!hashmap_data[entry].IsEmpty() && !hashmap_data[entry].IsKey(random_key)) {
241 count += 1;
242 entry = (entry + 1) % static_cast<T>(hashmap_size);
243 }
244 if (hashmap_data[entry].IsEmpty()) {
245 hashmap_count++;
246 hashmap_data[entry].key_ = random_key;
247 hashmap_data[entry].value_ = SizeToInt(i);
248 hashmap_data[entry].step_ = kInitStep;
249 hashmap_data[entry].tag_ = SizeToInt(count);
250 }
251 }
252 MS_LOG(INFO) << "Hashmap init success, with " << hashmap_count << " / " << hashmap_size;
253 }
254
InitHashMap(const FuncGraphPtr & func_graph,const int64_t host_size,const int64_t cache_size,TypeId type_id)255 AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size, const int64_t cache_size,
256 TypeId type_id) {
257 // init new tensor
258 size_t hashmap_size = static_cast<size_t>(cache_size * kEmptyRate);
259 std::vector<int64_t> host_shape{static_cast<int64_t>(hashmap_size), 4};
260 auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
261 size_t byte_size = new_tensor->Size();
262 if (type_id == TypeId::kNumberTypeInt64) {
263 InitHashMapData<int64_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
264 } else {
265 InitHashMapData<int32_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
266 }
267 ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
268 std::string hashmap_name = "cache_hashmap";
269 new_param_info->set_name(hashmap_name);
270 new_tensor->set_param_info(new_param_info);
271 return func_graph->AddFvParameter(hashmap_name, new_tensor);
272 }
273
InitStep(const FuncGraphPtr & func_graph,TypeId type_id)274 AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
275 std::vector<int64_t> host_shape{1};
276 auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
277 ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
278 std::string step_name = "cache_step";
279 new_param_info->set_name(step_name);
280 new_tensor->set_param_info(new_param_info);
281 return func_graph->AddFvParameter(step_name, new_tensor);
282 }
283
// Build the MapCacheIdx CNode that translates unique embedding ids into
// cache-local indices through the hashmap.  Also creates the hashmap and
// step FV parameters and binds/initializes all cache tensors.
// The node's output tuple is (cache_idx, old_emb_idx, miss_emb_idx,
// swap_emb_idx); only cache_idx keeps the static shape of `indices`, the
// other three are data-dependent and get -1 dims.
AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices,
                             const ParamMap &cache_host_params_map) {
  auto iter = cache_host_params_map.begin();
  // Row counts are read from the first pair only; CheckHostCacheParamSize
  // already guarantees all pairs agree on both sizes.
  int64_t cache_size = iter->first->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
  int64_t host_size = iter->second->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
  auto indices_type = indices->Type();
  auto indices_element_type = indices_type->cast<mindspore::TensorTypePtr>()->element();
  auto indices_type_id = indices_element_type->type_id();
  auto hashmap = InitHashMap(func_graph, host_size, cache_size, indices_type_id);
  auto step = InitStep(func_graph, indices_type_id);
  auto max_num = NewValueNode(MakeValue(host_size));
  auto hashmap_param = hashmap->cast<ParameterPtr>();
  BindAndInitCacheTensor(cache_host_params_map, hashmap_param);
  // add rank_id
  // The lookup offset is rank_id * host_size, with rank_id taken from the
  // RANK_ID environment variable (0 when unset).
  int64_t offset_value = 0;
  std::string rank_id_str = common::GetEnv("RANK_ID");
  if (!rank_id_str.empty()) {
    int64_t rank_id = atoi(rank_id_str.c_str());
    offset_value = rank_id * host_size;
  }
  auto offset = NewValueNode(MakeValue(offset_value));
  auto max_num_imm = std::make_shared<Int64Imm>(host_size);
  auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
  max_num->set_abstract(max_num_abstract_scalar);
  auto offset_imm = std::make_shared<Int64Imm>(offset_value);
  auto offset_abstract_scalar = std::make_shared<abstract::AbstractScalar>(offset_imm);
  offset->set_abstract(offset_abstract_scalar);

  // MapCacheIdx is pinned to CPU, next to the host-side embedding table.
  // NOTE(review): set_attr mutates the shared kPrimMapCacheIdx singleton
  // rather than a clone — presumably intentional; confirm.
  PrimitivePtr map_cache_primitive = prim::kPrimMapCacheIdx;
  map_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  std::vector<AnfNodePtr> map_cache_nodes{NewValueNode(map_cache_primitive), hashmap, indices, step, max_num, offset};
  auto map_cache_idx = func_graph->NewCNode(map_cache_nodes);

  auto indices_ori_shp = indices->Shape();
  auto indices_shp = indices_ori_shp->cast<abstract::ShapePtr>();
  // Dynamic shape (-1 in every dim) for the three data-dependent outputs.
  ShapeVector shape(indices_shp->shape().size(), -1);

  auto cache_idx = std::make_shared<abstract::AbstractTensor>(indices_element_type, indices_shp);
  auto old_emb_idx =
    std::make_shared<abstract::AbstractTensor>(indices_element_type, std::make_shared<abstract::Shape>(shape));
  auto miss_emb_idx =
    std::make_shared<abstract::AbstractTensor>(indices_element_type, std::make_shared<abstract::Shape>(shape));
  auto swap_emb_idx =
    std::make_shared<abstract::AbstractTensor>(indices_element_type, std::make_shared<abstract::Shape>(shape));

  std::vector<std::shared_ptr<abstract::AbstractBase>> elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
  auto abstract = std::make_shared<abstract::AbstractTuple>(elements);
  map_cache_idx->set_abstract(abstract);
  return map_cache_idx;
}
334
// Build TupleGetItem(input, index); the node's abstract is taken from the
// matching element of the input tuple's abstract.
AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto index_value = SizeToLong(index);
  auto idx = NewValueNode(index_value);
  MS_EXCEPTION_IF_NULL(idx);
  auto index_imm = std::make_shared<Int64Imm>(index_value);
  idx->set_abstract(std::make_shared<abstract::AbstractScalar>(index_imm));
  auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
  auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
  tuple_getitem->set_abstract(input_abstract_tuple->elements()[index]);
  return tuple_getitem;
}
348
CreateTupleGetItems(const FuncGraphPtr & func_graph,const AnfNodePtr & input,std::vector<AnfNodePtr> * outputs)349 void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input, std::vector<AnfNodePtr> *outputs) {
350 auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
351 auto size = input_abstract_tuple->elements().size();
352 MS_EXCEPTION_IF_NULL(outputs);
353 for (size_t i = 0; i < size; ++i) {
354 (*outputs).emplace_back(CreateTupleGetItem(func_graph, input, i));
355 }
356 }
357
CreateEmbeddingLookup(const FuncGraphPtr & graph,AnfNodePtr params,AnfNodePtr indices)358 AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) {
359 MS_EXCEPTION_IF_NULL(graph);
360 PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
361 emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
362 ValueNodePtr offset_value_node = NewValueNode(static_cast<int64_t>(0));
363 std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices, offset_value_node};
364 auto emb_lookup = graph->NewCNode(emb_lookup_nodes);
365 return emb_lookup;
366 }
367
CreateCacheSwapTable(const FuncGraphPtr & graph,ParameterPtr cache_table,AnfNodePtr swap_cache_idx,AnfNodePtr miss_value)368 AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx,
369 AnfNodePtr miss_value) {
370 MS_EXCEPTION_IF_NULL(graph);
371 PrimitivePtr cache_swap_table_primitive = std::make_shared<Primitive>(kCacheSwapTableOpName);
372 std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx,
373 miss_value};
374 auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes);
375 return cache_swap_table;
376 }
377
CreateUpdateCache(const FuncGraphPtr & graph,ParameterPtr params,AnfNodePtr old_emb_idx,AnfNodePtr old_value)378 AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx,
379 AnfNodePtr old_value) {
380 MS_EXCEPTION_IF_NULL(graph);
381 PrimitivePtr update_cache_primitive = std::make_shared<Primitive>(kUpdateCacheOpName);
382 update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
383
384 auto params_ori_shp = params->Shape();
385 MS_EXCEPTION_IF_NULL(params_ori_shp);
386 auto params_shp = params_ori_shp->cast<abstract::ShapePtr>();
387 MS_EXCEPTION_IF_NULL(params_shp);
388 auto params_shape = params_shp->shape();
389 auto max_size = params_shape[0];
390 auto max_size_node = NewValueNode(MakeValue(max_size));
391 auto max_num_imm = std::make_shared<Int64Imm>(max_size);
392 auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
393 max_size_node->set_abstract(max_num_abstract_scalar);
394
395 std::vector<AnfNodePtr> update_cache_nodes{NewValueNode(update_cache_primitive), params, old_emb_idx, old_value,
396 max_size_node};
397 auto update_cache = graph->NewCNode(update_cache_nodes);
398 return update_cache;
399 }
400
CreateEmbSwapUpdate(const FuncGraphPtr & graph,ParamMap param_pair_list,const AnfNodePtrList & map_cache_idx_node_outputs)401 NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_list,
402 const AnfNodePtrList &map_cache_idx_node_outputs) {
403 MS_EXCEPTION_IF_NULL(graph);
404 NodePairList node_pair_list;
405 for (auto &ele : param_pair_list) {
406 auto emb_lookup = CreateEmbeddingLookup(graph, ele.second, map_cache_idx_node_outputs[2]);
407 auto cache_swap_table = CreateCacheSwapTable(graph, ele.first, map_cache_idx_node_outputs[3], emb_lookup);
408 auto update_cache = CreateUpdateCache(graph, ele.second, map_cache_idx_node_outputs[1], cache_swap_table);
409 node_pair_list.emplace_back(std::make_pair(cache_swap_table, update_cache));
410 }
411 return node_pair_list;
412 }
413
CreateControlDepend(const FuncGraphPtr & main_graph,const AnfNodePtr & prior_node,const AnfNodePtr & behind_node)414 void CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node) {
415 // Create control depend
416 MS_EXCEPTION_IF_NULL(main_graph);
417 auto manager = main_graph->manager();
418 MS_EXCEPTION_IF_NULL(manager);
419 AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node};
420 auto depend_cnode = main_graph->NewCNode(cd_inputs);
421 if (!manager->Replace(behind_node, depend_cnode)) {
422 MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed.";
423 }
424 }
425
// Wrap `patron_node` in Depend(patron_node, MakeTuple(invalid_nodes)) so the
// otherwise-unreferenced `invalid_nodes` are kept alive in the graph.
AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &invalid_nodes,
                        const AnfNodePtr &patron_node) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> make_tuple_list{NewValueNode(prim::kPrimMakeTuple)};
  (void)std::copy(invalid_nodes.begin(), invalid_nodes.end(), std::back_inserter(make_tuple_list));
  auto make_tuple = graph->NewCNode(make_tuple_list);
  auto depend_cnode = graph->NewCNode({NewValueNode(prim::kPrimDepend), patron_node, make_tuple});
  // The Depend yields the patron's value, so reuse its abstract.
  depend_cnode->set_abstract(patron_node->abstract());
  return depend_cnode;
}
437
// Collect every SparseGatherV2 / EmbeddingLookup cnode that reads a
// cache-enabled parameter (possibly wrapped by Cast and Load).  All matches
// must share the same indices input; a lookup on a non-cached parameter or
// an empty result raises.
CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const ParamSet &param_set) {
  CNodePtrList sparse_gather_v2_with_cache;
  for (const auto &cnode : cnodes) {
    if (!IsPrimitiveCNode(cnode, prim::kPrimSparseGatherV2) && !IsPrimitiveCNode(cnode, prim::kPrimEmbeddingLookup)) {
      continue;
    }
    auto load_node = cnode->input(1);
    // The parameter may be cast before being loaded; look through the Cast.
    if (IsPrimitiveCNode(load_node, prim::kPrimCast)) {
      load_node = load_node->cast<CNodePtr>()->input(1);
    }
    if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
      auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
      if (param_set.find(param_node) == param_set.end()) {
        MS_LOG(EXCEPTION) << "EmbeddingLookup can't not support cache and no cache in the same graph.";
      }
      sparse_gather_v2_with_cache.push_back(cnode);
    }
  }
  if (sparse_gather_v2_with_cache.empty()) {
    MS_LOG(EXCEPTION) << "Can not find SparseGatherV2 with cache param.";
  }

  auto indices = sparse_gather_v2_with_cache[0]->input(2);
  for (const auto &ele : sparse_gather_v2_with_cache) {
    if (ele->input(2) != indices) {
      MS_LOG(EXCEPTION) << "SparseGatherV2 which with cache param have different indices!.";
    }
  }
  return sparse_gather_v2_with_cache;
}
470
// Find the single Gather user of `node`; any other user count raises.
AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  AnfNodePtrList gatherv2_nodes;
  for (const auto &user : graph->manager()->node_users()[node]) {
    if (IsPrimitiveCNode(user.first, prim::kPrimGather)) {
      gatherv2_nodes.emplace_back(user.first);
    }
  }
  if (gatherv2_nodes.size() != 1) {
    MS_LOG(EXCEPTION) << "SparseGatherV2 with cache can only used by one of gatherv2, but got "
                      << gatherv2_nodes.size();
  }
  return gatherv2_nodes[0];
}
486
FindNoRefParams(const FuncGraphPtr & graph)487 AnfSet FindNoRefParams(const FuncGraphPtr &graph) {
488 AnfSet no_ref_params;
489 auto params = graph->parameters();
490 for (auto &anf_param : params) {
491 auto param = anf_param->cast<ParameterPtr>();
492 if (!param->has_default()) {
493 MS_LOG(INFO) << param->DebugString() << " has no default";
494 no_ref_params.insert(anf_param);
495 }
496 }
497 return no_ref_params;
498 }
499
RemoveOriginParamFromSet(const CNodePtr & unique_node,AnfSet * no_ref_params)500 void RemoveOriginParamFromSet(const CNodePtr &unique_node, AnfSet *no_ref_params) {
501 std::queue<CNodePtr> que;
502 que.push(unique_node);
503 while (!que.empty()) {
504 auto node = que.front();
505 que.pop();
506 auto node_inputs = node->inputs();
507 for (auto &input : node_inputs) {
508 if (input->isa<CNode>()) {
509 que.push(input->cast<CNodePtr>());
510 } else if (input->isa<Parameter>()) {
511 size_t num = no_ref_params->erase(input);
512 if (num > 0) {
513 MS_LOG(INFO) << "Erase unique_node input from set success.";
514 return;
515 }
516 }
517 }
518 }
519 MS_LOG(EXCEPTION) << "Can not find any parameter that use by Unique.";
520 }
521
// Create a "<name>_pipe" FV parameter whose tensor matches the dtype and
// shape of `ori_input`.
AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &ori_input, const std::string &name) {
  auto element_type = ori_input->Type()->cast<mindspore::TensorTypePtr>()->element();
  auto shape_ptr = ori_input->Shape()->cast<abstract::ShapePtr>();
  auto new_tensor = std::make_shared<tensor::Tensor>(element_type->type_id(), shape_ptr->shape());
  auto new_param_name = name + "_pipe";
  ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
  new_param_info->set_name(new_param_name);
  new_tensor->set_param_info(new_param_info);
  return graph->AddFvParameter(new_param_name, new_tensor);
}
536
CreateOtherPipeParams(const FuncGraphPtr & graph,const AnfSet & no_ref_params)537 AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) {
538 AnfMap no_ref_pipe_param_map;
539 for (auto ¶m : no_ref_params) {
540 auto ori_param = param->cast<ParameterPtr>();
541 auto ori_name = ori_param->name();
542 auto new_param = CreateOutputNodeParam(graph, param, ori_name);
543 no_ref_pipe_param_map[param] = new_param;
544 }
545 return no_ref_pipe_param_map;
546 }
547
// Build Assign(res_param, src_param); when `is_dynamic` is true, a
// CPU-targeted DynamicAssign is emitted instead.
// NOTE(review): set_attr here mutates the shared prim::kPrimDynamicAssign
// singleton rather than a clone, affecting every later user of that
// primitive — looks intentional for this pass, but confirm.
AnfNodePtr CreateAssign(const FuncGraphPtr &graph, const AnfNodePtr &res_param, const AnfNodePtr &src_param,
                        bool is_dynamic = false) {
  auto assign_prim = prim::kPrimAssign;
  if (is_dynamic) {
    assign_prim = prim::kPrimDynamicAssign;
    // The dynamic variant must run on CPU alongside the cache-management ops.
    assign_prim->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  }
  std::vector<AnfNodePtr> assign_nodes{NewValueNode(assign_prim), res_param, src_param};
  auto assign_status = graph->NewCNode(assign_nodes);
  return assign_status;
}
559
// Find the TupleGetItem user of `node` that extracts element `index`.
// Raises (never returns) when no such user exists.
AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t index) {
  auto manager = graph->manager();
  auto node_users = manager->node_users()[node];
  for (auto &node_user : node_users) {
    if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
      auto cnode = node_user.first->cast<CNodePtr>();
      // Input 2 of TupleGetItem is the element index value node.
      auto node_index = cnode->input(2);
      if (node_index->isa<ValueNode>()) {
        auto value_node = node_index->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(value_node);
        auto item_idx = GetValue<int64_t>(value_node->value());
        if (item_idx == index) {
          return node_user.first;
        }
      }
    }
  }
  // Fix: the old message read "Can't not find" — a double negative.
  MS_LOG(EXCEPTION) << "Can not find " << node->DebugString() << ", outputs:" << index;
}
579
// Rewire the pipelined graph onto staged parameters:
//  - every no-ref input is replaced by its "_pipe" parameter, and an
//    Assign(pipe, original) is scheduled after each former user (the Depend
//    created by CreateControlDepend makes the Assign wait on the user);
//  - the gather indices are replaced by `cache_idx_param`, refreshed by a
//    DynamicAssign from `cache_idx` scheduled after each indices user.
void ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
                          const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
                          const AnfNodePtr &sparse_gatherv2_indices) {
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  const auto &node_users = manager->node_users();
  // add other no ref pipe param and unique index dense
  for (auto &ele : no_ref_pipe_param_map) {
    // Snapshot the users before the Replace below so the control depends are
    // attached to the original consumers.
    const auto &user_set = node_users.at(ele.first);
    auto assign_status = CreateAssign(graph, ele.second, ele.first);
    for (const auto &user_node : user_set) {
      CreateControlDepend(graph, user_node.first, assign_status);
    }
    if (!manager->Replace(ele.first, ele.second)) {
      MS_LOG(EXCEPTION) << "pipe param: " << ele.first->DebugString() << ", replace node failed.";
    }
  }

  // add cache idx param
  auto dynamic_assgin_status = CreateAssign(graph, cache_idx_param, cache_idx, true);
  const auto &indices_user_set = node_users.at(sparse_gatherv2_indices);
  for (const auto &user_node : indices_user_set) {
    CreateControlDepend(graph, user_node.first, dynamic_assgin_status);
  }
  if (!manager->Replace(sparse_gatherv2_indices, cache_idx_param)) {
    MS_LOG(EXCEPTION) << "cache idx param: " << cache_idx_param->DebugString() << ", replace node failed.";
  }
}
608
// Rewrite a training graph to use an embedding cache:
//  1. create cache parameters and swap them in for the host parameters;
//  2. insert MapCacheIdx after the Unique node to map ids to cache slots;
//  3. insert EmbeddingLookup / CacheSwapTable / UpdateCache per pair;
//  4. keep the otherwise-dead update ops alive by hanging them off the
//     return value with a Depend.
// `is_pipe` selects the pipelined variant that stages inputs through
// "_pipe" parameters (see ReplaceNoRefToParams).
void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNodePtrList &cnodes,
                            const CNodePtr &unique_node, const ParamSet &param_cache_enable_set) {
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  size_t cnodes_size = cnodes.size();
  auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
  if (cache_host_params_map.empty()) {
    MS_LOG(EXCEPTION) << "host's cache parameter map is empty!";
  }
  auto param_set = MapKeysToSet(cache_host_params_map);
  ReplaceCacheParams(graph, cache_host_params_map);
  graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
  MS_LOG(INFO) << "Graph is set cache enable.";

  CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
  // MapCacheIdx consumes output 0 of Unique (the unique ids).
  auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
  auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);

  AnfNodePtrList map_cache_idx_node_outputs;
  CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);

  auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
  // Nodes with no data users; they are kept alive via a final Depend.
  AnfNodePtrList invalid_nodes;
  auto cache_idx = map_cache_idx_node_outputs[0];
  if (!is_pipe) {
    // Non-pipelined: gather directly from the cache slot indices and make
    // each gather wait for the corresponding cache swap.
    if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), cache_idx)) {
      MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
    }
    for (auto &ele : node_pair_list) {
      for (auto &gather_op : sparse_gatherv2_with_cache) {
        CreateControlDepend(graph, ele.first, gather_op);
      }
      invalid_nodes.emplace_back(ele.second);
    }
  } else {
    // Pipelined: stage cache_idx, the dense index output of Unique, and all
    // no-ref inputs in "_pipe" parameters, then rewire users onto them.
    auto cache_idx_param = CreateOutputNodeParam(graph, unique_node->input(1), std::string("cache_idx"));
    auto unique_index_reverse = FindCNodeOutput(graph, unique_node, 1);
    auto unique_index_param = CreateOutputNodeParam(graph, unique_index_reverse, std::string("index_dense"));
    auto no_ref_params = FindNoRefParams(graph);
    // The input consumed by Unique itself must not be staged.
    RemoveOriginParamFromSet(unique_node, &no_ref_params);
    auto no_ref_param_map = CreateOtherPipeParams(graph, no_ref_params);
    no_ref_param_map[unique_index_reverse] = unique_index_param;
    ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx, sparse_gatherv2_with_cache[0]->input(2));
    std::transform(node_pair_list.begin(), node_pair_list.end(), std::back_inserter(invalid_nodes),
                   [](const std::pair<AnfNodePtr, AnfNodePtr> &pair) { return pair.second; });
  }
  // cnodes are expected in execution order; the last one must be Return.
  AnfNodePtr last_node = cnodes[cnodes_size - 1];
  CNodePtr return_node;
  if (last_node->isa<CNode>()) {
    return_node = last_node->cast<CNodePtr>();
  }
  MS_EXCEPTION_IF_NULL(return_node);
  if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
    MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode.";
  }
  if (return_node->size() < 2) {
    MS_LOG(EXCEPTION) << "Number of return node inputs should be greater than or equal to 2.";
  }

  // Anchor the side-effect-only nodes to the returned value.
  auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
  if (!manager->Replace(return_node->input(1), depend_node)) {
    MS_LOG(EXCEPTION) << "Depend replace node failed";
  }
}
674
CacheEmbeddingForEval(const FuncGraphPtr & graph,const CNodePtrList & cnodes,const CNodePtr & unique_node,const ParamSet & param_cache_enable_set)675 void CacheEmbeddingForEval(const FuncGraphPtr &graph, const CNodePtrList &cnodes, const CNodePtr &unique_node,
676 const ParamSet ¶m_cache_enable_set) {
677 MS_EXCEPTION_IF_NULL(graph);
678 auto manager = graph->manager();
679 MS_EXCEPTION_IF_NULL(manager);
680 graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
681 MS_LOG(INFO) << "Graph is set cache enable.";
682 // replace GatherV2 to EmbeddingLookupCPU
683 auto indices = unique_node->input(1);
684 auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
685 for (auto &ele : sparse_gatherv2_with_cache) {
686 auto anf_ele = ele->cast<AnfNodePtr>();
687 auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
688 auto embedding_lookup = CreateEmbeddingLookup(graph, ele->input(1), indices);
689 if (!manager->Replace(gatherv2, embedding_lookup)) {
690 MS_LOG(EXCEPTION) << "Depend replace node failed";
691 }
692 }
693 }
694
AddCacheEmbedding(const FuncGraphPtr & graph,bool is_pipe)695 void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe) {
696 MS_EXCEPTION_IF_NULL(graph);
697 std::list<CNodePtr> orders = graph->GetOrderedCnodes();
698 CNodePtrList cnodes(orders.cbegin(), orders.cend());
699 bool training = graph->has_flag("training");
700 auto param_cache_enable_set = FindParamCacheEnable(graph);
701 if (param_cache_enable_set.empty()) {
702 MS_LOG(INFO) << "Parameters are all not cache enable.";
703 return;
704 } else {
705 MS_LOG(INFO) << "Parameters have cache enable.";
706 }
707 if (!CheckHostCacheParamSize(param_cache_enable_set)) {
708 return;
709 }
710 for (auto &node : cnodes) {
711 if (IsPrimitiveCNode(node, prim::kPrimNPUAllocFloatStatus)) {
712 MS_LOG(EXCEPTION) << "Cache embedding haven't support loss scale yet.";
713 }
714 }
715 auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
716 if (unique_cache_enable.empty()) {
717 MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
718 return;
719 }
720 auto unique_node = unique_cache_enable[0];
721 if (training) {
722 // If training, create cache parameters corresponding to the host params with is cache_enable.
723 // Replace the host params. Create hashmap then insert MapCacheIdx op after Unique with has 'cache_enable' attr.
724 // Bind hashmap tensor ptr and cache tensor ptr to host tensor, so that we can flush values
725 // from cache to host in each epoch end.
726 // Create EmbeddingLookup(CPU), CacheSwapTable(Ascend), UpdateCache(CPU) for each pair of params, in order to
727 // flush miss values to cache params and write back old values to host params.
728 // If no use pipe in training, EmbeddingLookup and CacheSwapTable must execute before SparseGatherV2, so add
729 // ControlDepend between them. And add Depend for UpdateCache op and ControlDepnd op to add nodes into graph.
730 // If use pipe in training, create parameters for no ref param such as labels and MapCacheIdx output[0] and
731 // Unique output[1], in each step, it will train the data from last step, so that can hide the time of Unique
732 // and other cpu kernels. So in the first step, it's fake data.
733 CacheEmbeddingForTrain(graph, is_pipe, cnodes, unique_node, param_cache_enable_set);
734 } else {
735 // If eval, Use EmbeddingLookup(CPU) op to replace GatherV2.
736 // The network is the same as Host-Device mode.
737 CacheEmbeddingForEval(graph, cnodes, unique_node, param_cache_enable_set);
738 }
739 }
740 } // namespace parallel
741 } // namespace mindspore
742