• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/cache_embedding/cache_embedding.h"
18 
19 #include <random>
20 #include <vector>
21 #include <list>
22 #include <queue>
23 #include <utility>
24 #include <memory>
25 #include <string>
26 #include <algorithm>
27 
28 #include "ops/sequence_ops.h"
29 #include "ops/other_ops.h"
30 #include "ops/nn_optimizer_ops.h"
31 #include "ops/nn_ops.h"
32 #include "ops/array_ops.h"
33 #include "ops/framework_ops.h"
34 #include "utils/hash_map.h"
35 #include "utils/hash_set.h"
36 #include "include/backend/optimizer/helper.h"
37 #include "frontend/optimizer/optimizer.h"
38 #include "ir/func_graph.h"
39 #include "utils/cache_embedding_hashmap_struct.h"
40 namespace mindspore {
41 namespace parallel {
42 using ParamMap = mindspore::HashMap<ParameterPtr, ParameterPtr>;
43 using ParamSet = mindspore::HashSet<ParameterPtr>;
44 using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>;
45 using AnfMap = mindspore::HashMap<AnfNodePtr, AnfNodePtr>;
46 using AnfSet = mindspore::HashSet<AnfNodePtr>;
47 
AddCacheParameters(const FuncGraphPtr & graph,const ParamSet & parameter_cache_enable_set)48 ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet &parameter_cache_enable_set) {
49   ParamMap cache_host_params_map;
50   for (auto &param : parameter_cache_enable_set) {
51     auto param_info = param->param_info();
52     if (param_info && param_info->cache_enable()) {
53       auto data_type = param->Type();
54       auto data_element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
55       auto type_id = data_element_type->type_id();
56       auto cache_shape = param_info->cache_shape();
57       auto ori_param_name = param->name();
58       auto new_tensor = std::make_shared<tensor::Tensor>(type_id, cache_shape);
59       ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
60       auto cache_name = ori_param_name + "_cache";
61       new_param_info->set_name(cache_name);
62       new_tensor->set_param_info(new_param_info);
63       auto cache_param = graph->AddFvParameter(cache_name, new_tensor);
64       cache_host_params_map[cache_param] = param;
65     }
66   }
67   return cache_host_params_map;
68 }
69 
CheckHostCacheParamSize(const ParamSet & parameter_cache_enable_set)70 bool CheckHostCacheParamSize(const ParamSet &parameter_cache_enable_set) {
71   int64_t host_size = 0;
72   int64_t cache_size = 0;
73   for (auto &host_param : parameter_cache_enable_set) {
74     auto tmp_host_size = host_param->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
75     auto host_param_info = host_param->param_info();
76     auto cache_shape = host_param_info->cache_shape();
77     if (cache_shape.empty()) {
78       MS_LOG(EXCEPTION) << "The value of cache_shape is empty.";
79     }
80     auto tmp_cache_size = cache_shape[0];
81     if ((host_size != 0 && tmp_host_size != host_size) || (cache_size != 0 && tmp_cache_size != cache_size)) {
82       MS_LOG(EXCEPTION)
83         << "If EmbeddingLookup are cache enable, vocab_size and vocab_cache_size of different cells must be the same.";
84     }
85     cache_size = tmp_cache_size;
86     host_size = tmp_host_size;
87   }
88   if (cache_size > host_size) {
89     MS_LOG(WARNING) << "vocab_cache_size > vocab_size, there is no need use cache.";
90     return false;
91   }
92   return true;
93 }
94 
ReplaceCacheParams(const FuncGraphPtr & graph,const ParamMap & map)95 void ReplaceCacheParams(const FuncGraphPtr &graph, const ParamMap &map) {
96   auto manager = graph->manager();
97   MS_EXCEPTION_IF_NULL(manager);
98   for (auto &ele : map) {
99     if (!manager->Replace(ele.second, ele.first)) {
100       MS_LOG(EXCEPTION) << "host param: " << ele.second->name() << ", replace node failed.";
101     }
102   }
103 }
104 
MapKeysToSet(const ParamMap & map)105 ParamSet MapKeysToSet(const ParamMap &map) {
106   ParamSet set;
107   for (auto &ele : map) {
108     set.insert(ele.first);
109   }
110   return set;
111 }
112 
FindParamCacheEnable(const FuncGraphPtr & graph)113 ParamSet FindParamCacheEnable(const FuncGraphPtr &graph) {
114   ParamSet parameter_cache_enable_set;
115   auto parameters = graph->parameters();
116   auto params_size = parameters.size();
117   for (size_t i = 0; i < params_size; ++i) {
118     auto param = parameters[i]->cast<ParameterPtr>();
119     auto param_info = param->param_info();
120     if (param_info && param_info->cache_enable()) {
121       parameter_cache_enable_set.insert(param);
122     }
123   }
124   return parameter_cache_enable_set;
125 }
126 
CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) {
  // Collect Unique CNodes carrying a truthy cache_enable attribute.
  // At most one such node is supported per graph.
  CNodePtrList unique_cache_enable;
  for (const auto &cnode : cnodes) {
    if (!IsPrimitiveCNode(cnode, prim::kPrimUnique)) {
      continue;
    }
    auto unique_prim = GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(unique_prim);
    auto cache_attr = unique_prim->GetAttr(kAttrCacheEnable);
    if (cache_attr != nullptr && GetValue<bool>(cache_attr)) {
      unique_cache_enable.push_back(cnode);
    }
  }
  if (unique_cache_enable.size() > 1) {
    MS_LOG(EXCEPTION) << "Support only one of Unique op cache enable, but got " << unique_cache_enable.size();
  }
  return unique_cache_enable;
}
146 
147 template <typename T>
MemCopyFromHostToCache(void * hashmap_addr,void * host_addr,void * cache_addr,size_t host_max,size_t cache_max,size_t hashmap_size,size_t col_size)148 void MemCopyFromHostToCache(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
149                             size_t hashmap_size, size_t col_size) {
150   auto host_data = static_cast<char *>(host_addr);
151   auto cache_data = static_cast<char *>(cache_addr);
152   auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
153   // default param type float
154   const size_t param_type_size = 4;
155   size_t single_col_bytes = param_type_size * col_size;
156   for (size_t i = 0; i < hashmap_size; ++i) {
157     if (!hashmap_data[i].IsEmpty()) {
158       size_t host_offset = single_col_bytes * static_cast<size_t>(hashmap_data[i].key_);
159       size_t cache_offset = single_col_bytes * static_cast<size_t>(hashmap_data[i].value_);
160       if (host_offset + single_col_bytes <= host_max) {
161         auto ret =
162           memcpy_s(cache_data + cache_offset, cache_max - cache_offset, host_data + host_offset, single_col_bytes);
163         if (ret != EOK) {
164           MS_LOG(EXCEPTION) << "Memcpy failed.";
165         }
166       }
167     }
168   }
169   MS_LOG(INFO) << "Memcpy from cache to host success!";
170 }
171 
BindAndInitCacheTensor(const ParamMap & param_pair_list,const ParameterPtr & hashmap)172 void BindAndInitCacheTensor(const ParamMap &param_pair_list, const ParameterPtr &hashmap) {
173   auto hashmap_tensor_value = hashmap->default_param();
174   auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
175   auto hashmap_size = hashmap_tensor->shape_c()[0];
176   auto hashmap_data_type = hashmap_tensor->data_type();
177   for (auto &ele : param_pair_list) {
178     auto host_tensor_value = ele.second->default_param();
179     auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
180     auto cache_tensor_value = ele.first->default_param();
181     auto cache_tensor = cache_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
182 
183     // bind host, cache, hashmap
184     host_tensor->set_cache_enable(true);
185     host_tensor->set_hashmap_tensor_ptr(hashmap_tensor);
186     host_tensor->set_cache_tensor_ptr(cache_tensor);
187 
188     // init cache tensor data
189     auto host_shape = host_tensor->shape_c();
190     auto cache_shape = cache_tensor->shape_c();
191     if (host_shape.size() != 2 && cache_shape.size() != 2 && host_shape[1] != cache_shape[1]) {
192       MS_LOG(EXCEPTION) << "Got host shape and cache shape invalid."
193                         << "host shape:" << host_shape << ", cache shape:" << cache_shape;
194     }
195     auto host_data_max_size = static_cast<size_t>(host_tensor->Size());
196     auto cache_data_max_size = static_cast<size_t>(cache_tensor->Size());
197     if (hashmap_data_type == TypeId::kNumberTypeInt32) {
198       MemCopyFromHostToCache<int32_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(),
199                                       host_data_max_size, cache_data_max_size, LongToSize(hashmap_size),
200                                       LongToSize(host_shape[1]));
201     } else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
202       MemCopyFromHostToCache<int64_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(),
203                                       host_data_max_size, cache_data_max_size, LongToSize(hashmap_size),
204                                       LongToSize(host_shape[1]));
205     } else {
206       MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64.";
207     }
208   }
209 }
210 
211 template <typename T>
InitHashMapData(void * data,const int64_t host_size,const int64_t cache_size,const size_t hashmap_size,const size_t byte_size)212 void InitHashMapData(void *data, const int64_t host_size, const int64_t cache_size, const size_t hashmap_size,
213                      const size_t byte_size) {
214   MS_LOG(INFO) << "Start init hashmap data.";
215   MS_EXCEPTION_IF_NULL(data);
216   HashmapEntry<T> *hashmap_data = static_cast<HashmapEntry<T> *>(data);
217   MS_EXCEPTION_IF_NULL(hashmap_data);
218   int ret = memset_s(hashmap_data, byte_size, 0, byte_size);
219   if (ret != EOK) {
220     MS_LOG(EXCEPTION) << "Memset failed.";
221   }
222   std::vector<T> host_range;
223   host_range.reserve(static_cast<T>(host_size));
224   for (int64_t i = 0; i < host_size; ++i) {
225     host_range.emplace_back(static_cast<T>(i));
226   }
227 #if defined(__APPLE__) || defined(_MSC_VER)
228   std::random_device rd;
229   std::mt19937 rng(rd());
230   std::shuffle(host_range.begin(), host_range.end(), rng);
231 #else
232   std::random_shuffle(host_range.begin(), host_range.end());
233 #endif
234   size_t size = static_cast<size_t>(cache_size);
235   size_t hashmap_count = 0;
236   for (size_t i = 0; i < size; ++i) {
237     auto random_key = host_range[i];
238     auto entry = HashFunc(random_key, hashmap_size);
239     size_t count = 1;
240     while (!hashmap_data[entry].IsEmpty() && !hashmap_data[entry].IsKey(random_key)) {
241       count += 1;
242       entry = (entry + 1) % static_cast<T>(hashmap_size);
243     }
244     if (hashmap_data[entry].IsEmpty()) {
245       hashmap_count++;
246       hashmap_data[entry].key_ = random_key;
247       hashmap_data[entry].value_ = SizeToInt(i);
248       hashmap_data[entry].step_ = kInitStep;
249       hashmap_data[entry].tag_ = SizeToInt(count);
250     }
251   }
252   MS_LOG(INFO) << "Hashmap init success, with " << hashmap_count << " / " << hashmap_size;
253 }
254 
InitHashMap(const FuncGraphPtr & func_graph,const int64_t host_size,const int64_t cache_size,TypeId type_id)255 AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size, const int64_t cache_size,
256                        TypeId type_id) {
257   // init new tensor
258   size_t hashmap_size = static_cast<size_t>(cache_size * kEmptyRate);
259   std::vector<int64_t> host_shape{static_cast<int64_t>(hashmap_size), 4};
260   auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
261   size_t byte_size = new_tensor->Size();
262   if (type_id == TypeId::kNumberTypeInt64) {
263     InitHashMapData<int64_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
264   } else {
265     InitHashMapData<int32_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
266   }
267   ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
268   std::string hashmap_name = "cache_hashmap";
269   new_param_info->set_name(hashmap_name);
270   new_tensor->set_param_info(new_param_info);
271   return func_graph->AddFvParameter(hashmap_name, new_tensor);
272 }
273 
InitStep(const FuncGraphPtr & func_graph,TypeId type_id)274 AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
275   std::vector<int64_t> host_shape{1};
276   auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
277   ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
278   std::string step_name = "cache_step";
279   new_param_info->set_name(step_name);
280   new_tensor->set_param_info(new_param_info);
281   return func_graph->AddFvParameter(step_name, new_tensor);
282 }
283 
// Build the MapCacheIdx CNode that translates raw embedding indices into cache
// slot indices. Side effects: adds the hashmap and step fv parameters to the
// graph and binds/initializes all cache tensors. Returns the MapCacheIdx node,
// whose abstract is a 4-tuple: (cache_idx, old_emb_idx, miss_emb_idx,
// swap_cache_idx); the last three have dynamic (-1) shapes.
AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices,
                             const ParamMap &cache_host_params_map) {
  // All cache params share one vocab/cache size (validated earlier by
  // CheckHostCacheParamSize), so sizes are read from the first entry only.
  auto iter = cache_host_params_map.begin();
  int64_t cache_size = iter->first->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
  int64_t host_size = iter->second->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
  // The hashmap/step tensors use the same integer type as the indices.
  auto indices_type = indices->Type();
  auto indices_element_type = indices_type->cast<mindspore::TensorTypePtr>()->element();
  auto indices_type_id = indices_element_type->type_id();
  auto hashmap = InitHashMap(func_graph, host_size, cache_size, indices_type_id);
  auto step = InitStep(func_graph, indices_type_id);
  auto max_num = NewValueNode(MakeValue(host_size));
  auto hashmap_param = hashmap->cast<ParameterPtr>();
  BindAndInitCacheTensor(cache_host_params_map, hashmap_param);
  // add rank_id
  // Each rank owns a host_size-sized slice of the global vocabulary; the
  // offset shifts incoming indices into this rank's local range.
  int64_t offset_value = 0;
  std::string rank_id_str = common::GetEnv("RANK_ID");
  if (!rank_id_str.empty()) {
    int64_t rank_id = atoi(rank_id_str.c_str());
    offset_value = rank_id * host_size;
  }
  auto offset = NewValueNode(MakeValue(offset_value));
  auto max_num_imm = std::make_shared<Int64Imm>(host_size);
  auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
  max_num->set_abstract(max_num_abstract_scalar);
  auto offset_imm = std::make_shared<Int64Imm>(offset_value);
  auto offset_abstract_scalar = std::make_shared<abstract::AbstractScalar>(offset_imm);
  offset->set_abstract(offset_abstract_scalar);

  // NOTE(review): this sets the attr on the shared kPrimMapCacheIdx singleton,
  // not on a clone — presumably intended since the op always runs on CPU here.
  PrimitivePtr map_cache_primitive = prim::kPrimMapCacheIdx;
  map_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  std::vector<AnfNodePtr> map_cache_nodes{NewValueNode(map_cache_primitive), hashmap, indices, step, max_num, offset};
  auto map_cache_idx = func_graph->NewCNode(map_cache_nodes);

  auto indices_ori_shp = indices->Shape();
  auto indices_shp = indices_ori_shp->cast<abstract::ShapePtr>();
  // Dynamic shape (-1 per dim): how many rows are old/missing/swapped is only
  // known at runtime.
  ShapeVector shape(indices_shp->shape().size(), -1);

  auto cache_idx = std::make_shared<abstract::AbstractTensor>(indices_element_type, indices_shp);
  auto old_emb_idx =
    std::make_shared<abstract::AbstractTensor>(indices_element_type, std::make_shared<abstract::Shape>(shape));
  auto miss_emb_idx =
    std::make_shared<abstract::AbstractTensor>(indices_element_type, std::make_shared<abstract::Shape>(shape));
  auto swap_emb_idx =
    std::make_shared<abstract::AbstractTensor>(indices_element_type, std::make_shared<abstract::Shape>(shape));

  std::vector<std::shared_ptr<abstract::AbstractBase>> elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
  auto abstract = std::make_shared<abstract::AbstractTuple>(elements);
  map_cache_idx->set_abstract(abstract);
  return map_cache_idx;
}
334 
AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
  // Build TupleGetItem(input, index) and propagate the element's abstract.
  MS_EXCEPTION_IF_NULL(func_graph);
  auto index_node = NewValueNode(SizeToLong(index));
  MS_EXCEPTION_IF_NULL(index_node);
  auto index_imm = std::make_shared<Int64Imm>(SizeToLong(index));
  index_node->set_abstract(std::make_shared<abstract::AbstractScalar>(index_imm));
  auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, index_node});
  auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
  tuple_getitem->set_abstract(input_abstract_tuple->elements()[index]);
  return tuple_getitem;
}
348 
CreateTupleGetItems(const FuncGraphPtr & func_graph,const AnfNodePtr & input,std::vector<AnfNodePtr> * outputs)349 void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input, std::vector<AnfNodePtr> *outputs) {
350   auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
351   auto size = input_abstract_tuple->elements().size();
352   MS_EXCEPTION_IF_NULL(outputs);
353   for (size_t i = 0; i < size; ++i) {
354     (*outputs).emplace_back(CreateTupleGetItem(func_graph, input, i));
355   }
356 }
357 
CreateEmbeddingLookup(const FuncGraphPtr & graph,AnfNodePtr params,AnfNodePtr indices)358 AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) {
359   MS_EXCEPTION_IF_NULL(graph);
360   PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
361   emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
362   ValueNodePtr offset_value_node = NewValueNode(static_cast<int64_t>(0));
363   std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices, offset_value_node};
364   auto emb_lookup = graph->NewCNode(emb_lookup_nodes);
365   return emb_lookup;
366 }
367 
CreateCacheSwapTable(const FuncGraphPtr & graph,ParameterPtr cache_table,AnfNodePtr swap_cache_idx,AnfNodePtr miss_value)368 AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx,
369                                 AnfNodePtr miss_value) {
370   MS_EXCEPTION_IF_NULL(graph);
371   PrimitivePtr cache_swap_table_primitive = std::make_shared<Primitive>(kCacheSwapTableOpName);
372   std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx,
373                                                  miss_value};
374   auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes);
375   return cache_swap_table;
376 }
377 
CreateUpdateCache(const FuncGraphPtr & graph,ParameterPtr params,AnfNodePtr old_emb_idx,AnfNodePtr old_value)378 AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx,
379                              AnfNodePtr old_value) {
380   MS_EXCEPTION_IF_NULL(graph);
381   PrimitivePtr update_cache_primitive = std::make_shared<Primitive>(kUpdateCacheOpName);
382   update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
383 
384   auto params_ori_shp = params->Shape();
385   MS_EXCEPTION_IF_NULL(params_ori_shp);
386   auto params_shp = params_ori_shp->cast<abstract::ShapePtr>();
387   MS_EXCEPTION_IF_NULL(params_shp);
388   auto params_shape = params_shp->shape();
389   auto max_size = params_shape[0];
390   auto max_size_node = NewValueNode(MakeValue(max_size));
391   auto max_num_imm = std::make_shared<Int64Imm>(max_size);
392   auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
393   max_size_node->set_abstract(max_num_abstract_scalar);
394 
395   std::vector<AnfNodePtr> update_cache_nodes{NewValueNode(update_cache_primitive), params, old_emb_idx, old_value,
396                                              max_size_node};
397   auto update_cache = graph->NewCNode(update_cache_nodes);
398   return update_cache;
399 }
400 
CreateEmbSwapUpdate(const FuncGraphPtr & graph,ParamMap param_pair_list,const AnfNodePtrList & map_cache_idx_node_outputs)401 NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_list,
402                                  const AnfNodePtrList &map_cache_idx_node_outputs) {
403   MS_EXCEPTION_IF_NULL(graph);
404   NodePairList node_pair_list;
405   for (auto &ele : param_pair_list) {
406     auto emb_lookup = CreateEmbeddingLookup(graph, ele.second, map_cache_idx_node_outputs[2]);
407     auto cache_swap_table = CreateCacheSwapTable(graph, ele.first, map_cache_idx_node_outputs[3], emb_lookup);
408     auto update_cache = CreateUpdateCache(graph, ele.second, map_cache_idx_node_outputs[1], cache_swap_table);
409     node_pair_list.emplace_back(std::make_pair(cache_swap_table, update_cache));
410   }
411   return node_pair_list;
412 }
413 
CreateControlDepend(const FuncGraphPtr & main_graph,const AnfNodePtr & prior_node,const AnfNodePtr & behind_node)414 void CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node) {
415   // Create control depend
416   MS_EXCEPTION_IF_NULL(main_graph);
417   auto manager = main_graph->manager();
418   MS_EXCEPTION_IF_NULL(manager);
419   AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node};
420   auto depend_cnode = main_graph->NewCNode(cd_inputs);
421   if (!manager->Replace(behind_node, depend_cnode)) {
422     MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed.";
423   }
424 }
425 
AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &invalid_nodes,
                        const AnfNodePtr &patron_node) {
  // Keep otherwise-dead nodes alive by returning
  // Depend(patron_node, MakeTuple(invalid_nodes)), which carries the
  // patron's abstract.
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
  tuple_inputs.insert(tuple_inputs.end(), invalid_nodes.begin(), invalid_nodes.end());
  auto make_tuple = graph->NewCNode(tuple_inputs);
  auto depend_cnode = graph->NewCNode({NewValueNode(prim::kPrimDepend), patron_node, make_tuple});
  depend_cnode->set_abstract(patron_node->abstract());
  return depend_cnode;
}
437 
// Collect all SparseGatherV2/EmbeddingLookup CNodes whose first input (through
// an optional Cast and a Load) is one of the cache-enabled parameters in
// |param_set|. All found gathers must share the same indices input; mixing
// cache and non-cache lookups in one graph is rejected.
CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const ParamSet &param_set) {
  size_t cnodes_size = cnodes.size();
  CNodePtrList sparse_gather_v2_with_cache;
  for (size_t i = 0; i < cnodes_size; ++i) {
    if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2) ||
        IsPrimitiveCNode(cnodes[i], prim::kPrimEmbeddingLookup)) {
      auto load_node = cnodes[i]->input(1);
      // The parameter may be wrapped in a Cast before the Load.
      if (IsPrimitiveCNode(load_node, prim::kPrimCast)) {
        load_node = load_node->cast<CNodePtr>()->input(1);
      }
      if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
        auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
        if (param_set.find(param_node) != param_set.end()) {
          sparse_gather_v2_with_cache.push_back(cnodes[i]);
        } else {
          // Fix: original message read "can't not support" (double negative).
          MS_LOG(EXCEPTION) << "EmbeddingLookup cannot support cache and no cache in the same graph.";
        }
      }
    }
  }
  if (sparse_gather_v2_with_cache.empty()) {
    MS_LOG(EXCEPTION) << "Can not find SparseGatherV2 with cache param.";
  }

  auto indices = sparse_gather_v2_with_cache[0]->input(2);
  for (auto &ele : sparse_gather_v2_with_cache) {
    if (ele->input(2) != indices) {
      // Fix: original message was garbled ("which with cache param  have different indices!.").
      MS_LOG(EXCEPTION) << "SparseGatherV2 nodes with cache param have different indices!";
    }
  }
  return sparse_gather_v2_with_cache;
}
470 
AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  // Find the single Gather user of |node|; any other count is an error.
  MS_EXCEPTION_IF_NULL(graph);
  AnfNodePtrList gather_users;
  auto user_set = graph->manager()->node_users()[node];
  for (auto &user : user_set) {
    if (IsPrimitiveCNode(user.first, prim::kPrimGather)) {
      gather_users.push_back(user.first);
    }
  }
  if (gather_users.size() != 1) {
    MS_LOG(EXCEPTION) << "SparseGatherV2 with cache can only used by one of gatherv2, but got "
                      << gather_users.size();
  }
  return gather_users[0];
}
486 
FindNoRefParams(const FuncGraphPtr & graph)487 AnfSet FindNoRefParams(const FuncGraphPtr &graph) {
488   AnfSet no_ref_params;
489   auto params = graph->parameters();
490   for (auto &anf_param : params) {
491     auto param = anf_param->cast<ParameterPtr>();
492     if (!param->has_default()) {
493       MS_LOG(INFO) << param->DebugString() << " has no default";
494       no_ref_params.insert(anf_param);
495     }
496   }
497   return no_ref_params;
498 }
499 
RemoveOriginParamFromSet(const CNodePtr & unique_node,AnfSet * no_ref_params)500 void RemoveOriginParamFromSet(const CNodePtr &unique_node, AnfSet *no_ref_params) {
501   std::queue<CNodePtr> que;
502   que.push(unique_node);
503   while (!que.empty()) {
504     auto node = que.front();
505     que.pop();
506     auto node_inputs = node->inputs();
507     for (auto &input : node_inputs) {
508       if (input->isa<CNode>()) {
509         que.push(input->cast<CNodePtr>());
510       } else if (input->isa<Parameter>()) {
511         size_t num = no_ref_params->erase(input);
512         if (num > 0) {
513           MS_LOG(INFO) << "Erase unique_node input from set success.";
514           return;
515         }
516       }
517     }
518   }
519   MS_LOG(EXCEPTION) << "Can not find any parameter that use by Unique.";
520 }
521 
AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &ori_input, const std::string &name) {
  // Add a "<name>_pipe" fv parameter with the same dtype and shape as
  // |ori_input|, used to stage values across pipeline steps.
  auto element_type = ori_input->Type()->cast<mindspore::TensorTypePtr>()->element();
  auto shape_ptr = ori_input->Shape()->cast<abstract::ShapePtr>();
  auto pipe_tensor = std::make_shared<tensor::Tensor>(element_type->type_id(), shape_ptr->shape());
  ParamInfoPtr info = std::make_shared<ParamInfo>();
  auto pipe_name = name + "_pipe";
  info->set_name(pipe_name);
  pipe_tensor->set_param_info(info);
  return graph->AddFvParameter(pipe_name, pipe_tensor);
}
536 
CreateOtherPipeParams(const FuncGraphPtr & graph,const AnfSet & no_ref_params)537 AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) {
538   AnfMap no_ref_pipe_param_map;
539   for (auto &param : no_ref_params) {
540     auto ori_param = param->cast<ParameterPtr>();
541     auto ori_name = ori_param->name();
542     auto new_param = CreateOutputNodeParam(graph, param, ori_name);
543     no_ref_pipe_param_map[param] = new_param;
544   }
545   return no_ref_pipe_param_map;
546 }
547 
AnfNodePtr CreateAssign(const FuncGraphPtr &graph, const AnfNodePtr &res_param, const AnfNodePtr &src_param,
                        bool is_dynamic = false) {
  // Build Assign(res_param, src_param); with is_dynamic, use DynamicAssign
  // pinned to CPU instead.
  auto assign_prim = prim::kPrimAssign;
  if (is_dynamic) {
    assign_prim = prim::kPrimDynamicAssign;
    // NOTE(review): this sets the attr on the shared kPrimDynamicAssign
    // singleton, not on a clone — presumably intended since every dynamic
    // assign here targets CPU; confirm if the primitive is reused elsewhere.
    assign_prim->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  }
  return graph->NewCNode({NewValueNode(assign_prim), res_param, src_param});
}
559 
// Return the TupleGetItem user of |node| that extracts output |index|.
// Raises when no such user exists.
AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t index) {
  auto manager = graph->manager();
  auto node_users = manager->node_users()[node];
  for (auto &node_user : node_users) {
    if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
      auto cnode = node_user.first->cast<CNodePtr>();
      auto node_index = cnode->input(2);
      if (node_index->isa<ValueNode>()) {
        auto value_node = node_index->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(value_node);
        auto item_idx = GetValue<int64_t>(value_node->value());
        if (item_idx == index) {
          return node_user.first;
        }
      }
    }
  }
  // Fix: original message read "Can't not find" (double negative).
  MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << ", outputs:" << index;
}
579 
ReplaceNoRefToParams(const FuncGraphPtr & graph,const AnfMap & no_ref_pipe_param_map,const AnfNodePtr & cache_idx_param,const AnfNodePtr & cache_idx,const AnfNodePtr & sparse_gatherv2_indices)580 void ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
581                           const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
582                           const AnfNodePtr &sparse_gatherv2_indices) {
583   auto manager = graph->manager();
584   MS_EXCEPTION_IF_NULL(manager);
585   const auto &node_users = manager->node_users();
586   // add other no ref pipe param and unique index dense
587   for (auto &ele : no_ref_pipe_param_map) {
588     const auto &user_set = node_users.at(ele.first);
589     auto assign_status = CreateAssign(graph, ele.second, ele.first);
590     for (const auto &user_node : user_set) {
591       CreateControlDepend(graph, user_node.first, assign_status);
592     }
593     if (!manager->Replace(ele.first, ele.second)) {
594       MS_LOG(EXCEPTION) << "pipe param: " << ele.first->DebugString() << ", replace node failed.";
595     }
596   }
597 
598   // add cache idx param
599   auto dynamic_assgin_status = CreateAssign(graph, cache_idx_param, cache_idx, true);
600   const auto &indices_user_set = node_users.at(sparse_gatherv2_indices);
601   for (const auto &user_node : indices_user_set) {
602     CreateControlDepend(graph, user_node.first, dynamic_assgin_status);
603   }
604   if (!manager->Replace(sparse_gatherv2_indices, cache_idx_param)) {
605     MS_LOG(EXCEPTION) << "cache idx param: " << cache_idx_param->DebugString() << ", replace node failed.";
606   }
607 }
608 
CacheEmbeddingForTrain(const FuncGraphPtr & graph,bool is_pipe,const CNodePtrList & cnodes,const CNodePtr & unique_node,const ParamSet & param_cache_enable_set)609 void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNodePtrList &cnodes,
610                             const CNodePtr &unique_node, const ParamSet &param_cache_enable_set) {
611   MS_EXCEPTION_IF_NULL(graph);
612   auto manager = graph->manager();
613   MS_EXCEPTION_IF_NULL(manager);
614   size_t cnodes_size = cnodes.size();
615   auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
616   if (cache_host_params_map.empty()) {
617     MS_LOG(EXCEPTION) << "host's cache parameter map is empty!";
618   }
619   auto param_set = MapKeysToSet(cache_host_params_map);
620   ReplaceCacheParams(graph, cache_host_params_map);
621   graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
622   MS_LOG(INFO) << "Graph is set cache enable.";
623 
624   CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
625   auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
626   auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);
627 
628   AnfNodePtrList map_cache_idx_node_outputs;
629   CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);
630 
631   auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
632   AnfNodePtrList invalid_nodes;
633   auto cache_idx = map_cache_idx_node_outputs[0];
634   if (!is_pipe) {
635     if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), cache_idx)) {
636       MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
637     }
638     for (auto &ele : node_pair_list) {
639       for (auto &gather_op : sparse_gatherv2_with_cache) {
640         CreateControlDepend(graph, ele.first, gather_op);
641       }
642       invalid_nodes.emplace_back(ele.second);
643     }
644   } else {
645     auto cache_idx_param = CreateOutputNodeParam(graph, unique_node->input(1), std::string("cache_idx"));
646     auto unique_index_reverse = FindCNodeOutput(graph, unique_node, 1);
647     auto unique_index_param = CreateOutputNodeParam(graph, unique_index_reverse, std::string("index_dense"));
648     auto no_ref_params = FindNoRefParams(graph);
649     RemoveOriginParamFromSet(unique_node, &no_ref_params);
650     auto no_ref_param_map = CreateOtherPipeParams(graph, no_ref_params);
651     no_ref_param_map[unique_index_reverse] = unique_index_param;
652     ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx, sparse_gatherv2_with_cache[0]->input(2));
653     std::transform(node_pair_list.begin(), node_pair_list.end(), std::back_inserter(invalid_nodes),
654                    [](const std::pair<AnfNodePtr, AnfNodePtr> &pair) { return pair.second; });
655   }
656   AnfNodePtr last_node = cnodes[cnodes_size - 1];
657   CNodePtr return_node;
658   if (last_node->isa<CNode>()) {
659     return_node = last_node->cast<CNodePtr>();
660   }
661   MS_EXCEPTION_IF_NULL(return_node);
662   if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
663     MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode.";
664   }
665   if (return_node->size() < 2) {
666     MS_LOG(EXCEPTION) << "Number of return node inputs should be greater than or equal to 2.";
667   }
668 
669   auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
670   if (!manager->Replace(return_node->input(1), depend_node)) {
671     MS_LOG(EXCEPTION) << "Depend replace node failed";
672   }
673 }
674 
CacheEmbeddingForEval(const FuncGraphPtr & graph,const CNodePtrList & cnodes,const CNodePtr & unique_node,const ParamSet & param_cache_enable_set)675 void CacheEmbeddingForEval(const FuncGraphPtr &graph, const CNodePtrList &cnodes, const CNodePtr &unique_node,
676                            const ParamSet &param_cache_enable_set) {
677   MS_EXCEPTION_IF_NULL(graph);
678   auto manager = graph->manager();
679   MS_EXCEPTION_IF_NULL(manager);
680   graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
681   MS_LOG(INFO) << "Graph is set cache enable.";
682   // replace GatherV2 to EmbeddingLookupCPU
683   auto indices = unique_node->input(1);
684   auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
685   for (auto &ele : sparse_gatherv2_with_cache) {
686     auto anf_ele = ele->cast<AnfNodePtr>();
687     auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
688     auto embedding_lookup = CreateEmbeddingLookup(graph, ele->input(1), indices);
689     if (!manager->Replace(gatherv2, embedding_lookup)) {
690       MS_LOG(EXCEPTION) << "Depend replace node failed";
691     }
692   }
693 }
694 
AddCacheEmbedding(const FuncGraphPtr & graph,bool is_pipe)695 void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe) {
696   MS_EXCEPTION_IF_NULL(graph);
697   std::list<CNodePtr> orders = graph->GetOrderedCnodes();
698   CNodePtrList cnodes(orders.cbegin(), orders.cend());
699   bool training = graph->has_flag("training");
700   auto param_cache_enable_set = FindParamCacheEnable(graph);
701   if (param_cache_enable_set.empty()) {
702     MS_LOG(INFO) << "Parameters are all not cache enable.";
703     return;
704   } else {
705     MS_LOG(INFO) << "Parameters have cache enable.";
706   }
707   if (!CheckHostCacheParamSize(param_cache_enable_set)) {
708     return;
709   }
710   for (auto &node : cnodes) {
711     if (IsPrimitiveCNode(node, prim::kPrimNPUAllocFloatStatus)) {
712       MS_LOG(EXCEPTION) << "Cache embedding haven't support loss scale yet.";
713     }
714   }
715   auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
716   if (unique_cache_enable.empty()) {
717     MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
718     return;
719   }
720   auto unique_node = unique_cache_enable[0];
721   if (training) {
722     // If training, create cache parameters corresponding to the host params with is cache_enable.
723     // Replace the host params. Create hashmap then insert MapCacheIdx op after Unique with has 'cache_enable' attr.
724     // Bind hashmap tensor ptr and cache tensor ptr to host tensor, so that we can flush values
725     // from cache to host in each epoch end.
726     // Create EmbeddingLookup(CPU), CacheSwapTable(Ascend), UpdateCache(CPU) for each pair of params, in order to
727     // flush miss values to cache params and write back old values to host params.
728     // If no use pipe in training, EmbeddingLookup and CacheSwapTable must execute before SparseGatherV2, so add
729     // ControlDepend between them. And add Depend for UpdateCache op and ControlDepnd op to add nodes into graph.
730     // If use pipe in training, create parameters for no ref param such as labels and MapCacheIdx output[0] and
731     // Unique output[1], in each step, it will train the data from last step, so that can hide the time of Unique
732     // and other cpu kernels. So in the first step, it's fake data.
733     CacheEmbeddingForTrain(graph, is_pipe, cnodes, unique_node, param_cache_enable_set);
734   } else {
735     // If eval, Use EmbeddingLookup(CPU) op to replace GatherV2.
736     // The network is the same as Host-Device mode.
737     CacheEmbeddingForEval(graph, cnodes, unique_node, param_cache_enable_set);
738   }
739 }
740 }  // namespace parallel
741 }  // namespace mindspore
742