/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/cache_embedding/cache_embedding.h"
#include <random>
#include <vector>
#include <list>
#include <queue>
#include <utility>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include <algorithm>
#include "backend/optimizer/common/helper.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/func_graph.h"
#include "utils/cache_embedding_hashmap_struct.h"
namespace mindspore {
namespace parallel {
using ParamMap = std::unordered_map<ParameterPtr, ParameterPtr>;
using ParamSet = std::unordered_set<ParameterPtr>;
using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>;
using AnfMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
using AnfSet = std::unordered_set<AnfNodePtr>;

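// For every cache-enabled host parameter, create a device-side "<name>_cache" weight parameter with the
// configured cache_shape and the same element type, and return a map from each cache parameter to its
// original host parameter.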
ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet &parameter_cache_enable_set) {
  ParamMap cache_host_params_map;
  for (auto &param : parameter_cache_enable_set) {
    auto param_info = param->param_info();
    if (param_info && param_info->cache_enable()) {
      auto data_type = param->Type();
      auto data_element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
      auto type_id = data_element_type->type_id();
      auto cache_shape = param_info->cache_shape();
      auto ori_param_name = param->name();
      auto new_tensor = std::make_shared<tensor::Tensor>(type_id, cache_shape);
      ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
      auto cache_name = ori_param_name + "_cache";
      new_param_info->set_name(cache_name);
      new_tensor->set_param_info(new_param_info);
      auto cache_param = graph->AddWeightParameter(cache_name);
      cache_param->set_default_param(MakeValue(new_tensor));
      cache_param->set_abstract(new_tensor->ToAbstract());
      cache_host_params_map[cache_param] = param;
    }
  }
  return cache_host_params_map;
}

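// Check that all cache-enabled parameters share the same vocab_size (host) and vocab_cache_size (cache),
// and that the cache is actually smaller than the host table; otherwise caching is pointless.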
bool CheckHostCacheParamSize(const ParamSet &parameter_cache_enable_set) {
  int64_t host_size = 0;
  int64_t cache_size = 0;
  for (auto &host_param : parameter_cache_enable_set) {
    auto tmp_host_size = host_param->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
    auto host_param_info = host_param->param_info();
    auto cache_shape = host_param_info->cache_shape();
    if (cache_shape.empty()) {
      MS_LOG(EXCEPTION) << "The value of cache_shape is empty.";
    }
    auto tmp_cache_size = cache_shape[0];
    if ((host_size != 0 && tmp_host_size != host_size) || (cache_size != 0 && tmp_cache_size != cache_size)) {
      MS_LOG(EXCEPTION)
        << "If EmbeddingLookup has cache enabled, vocab_size and vocab_cache_size of different cells must be the same.";
    }
    cache_size = tmp_cache_size;
    host_size = tmp_host_size;
  }
  if (cache_size > host_size) {
    MS_LOG(WARNING) << "vocab_cache_size > vocab_size, so there is no need to use cache.";
    return false;
  }
  return true;
}

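// Replace every host parameter in the graph with its cache parameter, so downstream ops read from the cache.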
void ReplaceCacheParams(const FuncGraphPtr &graph, const ParamMap &map) {
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (auto &ele : map) {
    if (!manager->Replace(ele.second, ele.first)) {
      MS_LOG(EXCEPTION) << "host param: " << ele.second->name() << ", replace node failed.";
    }
  }
}

ParamSet MapKeysToSet(const ParamMap &map) {
  ParamSet set;
  for (auto &ele : map) {
    set.insert(ele.first);
  }
  return set;
}

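// Collect all graph parameters whose ParamInfo marks them as cache enabled.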
ParamSet FindParamCacheEnable(const FuncGraphPtr &graph) {
  ParamSet parameter_cache_enable_set;
  auto parameters = graph->parameters();
  auto params_size = parameters.size();
  for (size_t i = 0; i < params_size; ++i) {
    auto param = parameters[i]->cast<ParameterPtr>();
    auto param_info = param->param_info();
    if (param_info && param_info->cache_enable()) {
      parameter_cache_enable_set.insert(param);
    }
  }
  return parameter_cache_enable_set;
}

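// Find Unique CNodes carrying the cache_enable attribute; at most one such node is allowed per graph.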
CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) {
  size_t cnodes_size = cnodes.size();
  CNodePtrList unique_cache_enable;
  for (size_t i = 0; i < cnodes_size; ++i) {
    if (IsPrimitiveCNode(cnodes[i], prim::kPrimUnique)) {
      auto unique_node = cnodes[i];
      auto unique_prim = GetCNodePrimitive(unique_node);
      MS_EXCEPTION_IF_NULL(unique_prim);
      auto attr_value = unique_prim->GetAttr(kAttrCacheEnable);
      if (attr_value != nullptr && GetValue<bool>(attr_value)) {
        unique_cache_enable.emplace_back(unique_node);
      }
    }
  }
  if (unique_cache_enable.size() > 1) {
    MS_LOG(EXCEPTION) << "Only one Unique op with cache enable is supported, but got " << unique_cache_enable.size();
  }
  return unique_cache_enable;
}

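// Walk the hashmap and, for every occupied entry, copy one embedding row from the host table (indexed by
// key_) into the cache table (indexed by value_), bounds-checking against the host buffer size.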
template <typename T>
void MemCopyFromHostToCache(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
                            size_t hashmap_size, size_t col_size) {
  auto host_data = static_cast<char *>(host_addr);
  auto cache_data = static_cast<char *>(cache_addr);
  auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
  // default param type float
  const size_t param_type_size = 4;
  size_t single_col_bytes = param_type_size * col_size;
  for (size_t i = 0; i < hashmap_size; ++i) {
    if (!hashmap_data[i].IsEmpty()) {
      size_t host_offset = single_col_bytes * static_cast<size_t>(hashmap_data[i].key_);
      size_t cache_offset = single_col_bytes * static_cast<size_t>(hashmap_data[i].value_);
      if (host_offset + single_col_bytes <= host_max) {
        auto ret =
          memcpy_s(cache_data + cache_offset, cache_max - cache_offset, host_data + host_offset, single_col_bytes);
        if (ret != 0) {
          MS_LOG(EXCEPTION) << "Memcpy failed.";
        }
      }
    }
  }
  MS_LOG(INFO) << "Memcpy from host to cache success!";
}

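// Bind each host tensor to its hashmap and cache tensors (so values can be flushed back at epoch end) and
// initialize the cache tensor data from the host tensor according to the hashmap layout.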
void BindAndInitCacheTensor(const ParamMap &param_pair_list, const ParameterPtr &hashmap) {
  auto hashmap_tensor_value = hashmap->default_param();
  auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
  auto hashmap_size = hashmap_tensor->shape_c()[0];
  auto hashmap_data_type = hashmap_tensor->data_type();
  for (auto &ele : param_pair_list) {
    auto host_tensor_value = ele.second->default_param();
    auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
    auto cache_tensor_value = ele.first->default_param();
    auto cache_tensor = cache_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();

    // bind host, cache, hashmap
    host_tensor->set_cache_enable(true);
    host_tensor->set_hashmap_tensor_ptr(hashmap_tensor);
    host_tensor->set_cache_tensor_ptr(cache_tensor);

    // init cache tensor data
    auto host_shape = host_tensor->shape_c();
    auto cache_shape = cache_tensor->shape_c();
    if (host_shape.size() != 2 || cache_shape.size() != 2 || host_shape[1] != cache_shape[1]) {
      MS_LOG(EXCEPTION) << "Got invalid host shape or cache shape."
                        << " host shape:" << host_shape << ", cache shape:" << cache_shape;
    }
    auto host_data_max_size = static_cast<size_t>(host_tensor->Size());
    auto cache_data_max_size = static_cast<size_t>(cache_tensor->Size());
    if (hashmap_data_type == TypeId::kNumberTypeInt32) {
      MemCopyFromHostToCache<int32_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(),
                                      host_data_max_size, cache_data_max_size, LongToSize(hashmap_size),
                                      LongToSize(host_shape[1]));
    } else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
      MemCopyFromHostToCache<int64_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(),
                                      host_data_max_size, cache_data_max_size, LongToSize(hashmap_size),
                                      LongToSize(host_shape[1]));
    } else {
      MS_LOG(ERROR) << "Hashmap dtype only supports int32 and int64.";
    }
  }
}

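// Zero the hashmap buffer, then insert a random subset of cache_size host rows using open addressing with
// linear probing: each entry stores the host row id (key_), the cache slot (value_) and the probe count (tag_).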
template <typename T>
void InitHashMapData(void *data, const int64_t host_size, const int64_t cache_size, const size_t hashmap_size,
                     const size_t byte_size) {
  MS_LOG(INFO) << "Start init hashmap data.";
  MS_EXCEPTION_IF_NULL(data);
  HashmapEntry<T> *hashmap_data = static_cast<HashmapEntry<T> *>(data);
  MS_EXCEPTION_IF_NULL(hashmap_data);
  int ret = memset_s(hashmap_data, byte_size, 0, byte_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "Memset failed.";
  }
  std::vector<T> host_range;
  host_range.reserve(static_cast<size_t>(host_size));
  for (int64_t i = 0; i < host_size; ++i) {
    host_range.emplace_back(static_cast<T>(i));
  }
  std::shuffle(host_range.begin(), host_range.end(), std::mt19937(std::random_device()()));
  size_t size = static_cast<size_t>(cache_size);
  size_t hashmap_count = 0;
  for (size_t i = 0; i < size; ++i) {
    auto random_key = host_range[i];
    auto entry = HashFunc(random_key, hashmap_size);
    size_t count = 1;
    while (!hashmap_data[entry].IsEmpty() && !hashmap_data[entry].IsKey(random_key)) {
      count += 1;
      entry = (entry + 1) % static_cast<T>(hashmap_size);
    }
    if (hashmap_data[entry].IsEmpty()) {
      hashmap_count++;
      hashmap_data[entry].key_ = random_key;
      hashmap_data[entry].value_ = static_cast<T>(i);
      hashmap_data[entry].step_ = kInitStep;
      hashmap_data[entry].tag_ = static_cast<T>(count);
    }
  }
  MS_LOG(INFO) << "Hashmap init success, filled " << hashmap_count << " / " << hashmap_size << " entries.";
}

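// Create the "cache_hashmap" weight parameter: a [cache_size * kEmptyRate, 4] tensor of int32/int64 entries
// (key, value, step, tag), pre-filled by InitHashMapData.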
AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size, const int64_t cache_size,
                       TypeId type_id) {
  // init new tensor
  size_t hashmap_size = static_cast<size_t>(cache_size * kEmptyRate);
  std::vector<int64_t> host_shape{static_cast<int64_t>(hashmap_size), 4};
  auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
  size_t byte_size = new_tensor->Size();
  if (type_id == TypeId::kNumberTypeInt64) {
    InitHashMapData<int64_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
  } else {
    InitHashMapData<int32_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
  }
  ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
  std::string hashmap_name = "cache_hashmap";
  new_param_info->set_name(hashmap_name);
  new_tensor->set_param_info(new_param_info);
  auto hashmap = func_graph->AddWeightParameter(hashmap_name);
  hashmap->set_default_param(MakeValue(new_tensor));
  hashmap->set_abstract(new_tensor->ToAbstract());
  return hashmap;
}

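// Create a single-element "cache_step" weight parameter used as the step input of the MapCacheIdx op.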
AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
  std::vector<int64_t> host_shape{1};
  auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
  ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
  std::string step_name = "cache_step";
  new_param_info->set_name(step_name);
  new_tensor->set_param_info(new_param_info);
  auto step = func_graph->AddWeightParameter(step_name);
  step->set_default_param(MakeValue(new_tensor));
  step->set_abstract(new_tensor->ToAbstract());
  return step;
}

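// Build the MapCacheIdx(CPU) CNode: it maps the unique host indices to cache indices through the hashmap and
// produces (cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx). A rank-dependent offset shifts indices in
// multi-rank runs, and the last three outputs get dynamic (-1) shapes bounded by the indices shape.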
AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices,
                             const ParamMap &cache_host_params_map) {
  auto iter = cache_host_params_map.begin();
  int64_t cache_size = iter->first->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
  int64_t host_size = iter->second->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
  auto indices_type = indices->Type();
  auto indices_element_type = indices_type->cast<mindspore::TensorTypePtr>()->element();
  auto indices_type_id = indices_element_type->type_id();
  auto hashmap = InitHashMap(func_graph, host_size, cache_size, indices_type_id);
  auto step = InitStep(func_graph, indices_type_id);
  auto max_num = NewValueNode(MakeValue(host_size));
  auto hashmap_param = hashmap->cast<ParameterPtr>();
  BindAndInitCacheTensor(cache_host_params_map, hashmap_param);
  // add rank_id
  int64_t offset_value = 0;
  std::string rank_id_str = common::GetEnv("RANK_ID");
  if (!rank_id_str.empty()) {
    int64_t rank_id = atoi(rank_id_str.c_str());
    offset_value = rank_id * host_size;
  }
  auto offset = NewValueNode(MakeValue(offset_value));
  auto max_num_imm = std::make_shared<Int64Imm>(host_size);
  auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
  max_num->set_abstract(max_num_abstract_scalar);
  auto offset_imm = std::make_shared<Int64Imm>(offset_value);
  auto offset_abstract_scalar = std::make_shared<abstract::AbstractScalar>(offset_imm);
  offset->set_abstract(offset_abstract_scalar);

  PrimitivePtr map_cache_primitive = prim::kPrimMapCacheIdx;
  map_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  std::vector<AnfNodePtr> map_cache_nodes{NewValueNode(map_cache_primitive), hashmap, indices, step, max_num, offset};
  auto map_cache_idx = func_graph->NewCNode(map_cache_nodes);

  auto indices_ori_shp = indices->Shape();
  auto indices_shp = indices_ori_shp->cast<abstract::ShapePtr>();
  ShapeVector shape;
  ShapeVector min_shape;
  ShapeVector max_shape;
  if (!indices_shp->max_shape().empty()) {
    max_shape = indices_shp->max_shape();
  } else {
    max_shape = indices_shp->shape();
  }
  for (size_t i = 0; i < max_shape.size(); i++) {
    shape.emplace_back(-1);
    min_shape.emplace_back(1);
  }

  auto cache_idx = std::make_shared<abstract::AbstractTensor>(indices_element_type, indices_shp);
  auto old_emb_idx = std::make_shared<abstract::AbstractTensor>(
    indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
  auto miss_emb_idx = std::make_shared<abstract::AbstractTensor>(
    indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
  auto swap_emb_idx = std::make_shared<abstract::AbstractTensor>(
    indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));

  std::vector<std::shared_ptr<abstract::AbstractBase>> elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
  auto abstract = std::make_shared<abstract::AbstractTuple>(elements);
  map_cache_idx->set_abstract(abstract);
  return map_cache_idx;
}

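// Helpers that create TupleGetItem CNodes: one for a single tuple element, and one that expands every element
// of a tuple-output node into a vector of TupleGetItem nodes.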
AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto idx = NewValueNode(SizeToLong(index));
  MS_EXCEPTION_IF_NULL(idx);
  auto imm = std::make_shared<Int64Imm>(SizeToLong(index));
  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  idx->set_abstract(abstract_scalar);
  auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
  auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
  auto tuple_getitem_abstract = input_abstract_tuple->elements()[index];
  tuple_getitem->set_abstract(tuple_getitem_abstract);
  return tuple_getitem;
}

void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input, std::vector<AnfNodePtr> *outputs) {
  auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
  auto size = input_abstract_tuple->elements().size();
  MS_EXCEPTION_IF_NULL(outputs);
  for (size_t i = 0; i < size; ++i) {
    (*outputs).emplace_back(CreateTupleGetItem(func_graph, input, i));
  }
}

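// Create an EmbeddingLookup CNode pinned to the CPU with offset 0.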
AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) {
  MS_EXCEPTION_IF_NULL(graph);
  PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
  emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  emb_lookup_primitive->set_attr(kAttrOffset, MakeValue<int64_t>(0));
  std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices};
  auto emb_lookup = graph->NewCNode(emb_lookup_nodes);
  return emb_lookup;
}

AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx,
                                AnfNodePtr miss_value) {
  MS_EXCEPTION_IF_NULL(graph);
  PrimitivePtr cache_swap_table_primitive = std::make_shared<Primitive>(kCacheSwapTableOpName);
  std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx,
                                                 miss_value};
  auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes);
  return cache_swap_table;
}

AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx,
                             AnfNodePtr old_value) {
  MS_EXCEPTION_IF_NULL(graph);
  PrimitivePtr update_cache_primitive = std::make_shared<Primitive>(kUpdateCacheOpName);
  update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));

  auto params_ori_shp = params->Shape();
  MS_EXCEPTION_IF_NULL(params_ori_shp);
  auto params_shp = params_ori_shp->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(params_shp);
  auto params_shape = params_shp->shape();
  auto max_size = params_shape[0];
  auto max_size_node = NewValueNode(MakeValue(max_size));
  auto max_num_imm = std::make_shared<Int64Imm>(max_size);
  auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
  max_size_node->set_abstract(max_num_abstract_scalar);

  std::vector<AnfNodePtr> update_cache_nodes{NewValueNode(update_cache_primitive), params, old_emb_idx, old_value,
                                             max_size_node};
  auto update_cache = graph->NewCNode(update_cache_nodes);
  return update_cache;
}

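// For each (cache param, host param) pair, build the swap pipeline: EmbeddingLookup(CPU) gathers missing rows
// from the host table, CacheSwapTable writes them into the cache, and UpdateCache(CPU) writes the evicted rows
// back to the host table. Returns (CacheSwapTable, UpdateCache) node pairs.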
NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_list,
                                 const AnfNodePtrList &map_cache_idx_node_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  NodePairList node_pair_list;
  for (auto &ele : param_pair_list) {
    auto emb_lookup = CreateEmbeddingLookup(graph, ele.second, map_cache_idx_node_outputs[2]);
    auto cache_swap_table = CreateCacheSwapTable(graph, ele.first, map_cache_idx_node_outputs[3], emb_lookup);
    auto update_cache = CreateUpdateCache(graph, ele.second, map_cache_idx_node_outputs[1], cache_swap_table);
    node_pair_list.emplace_back(std::make_pair(cache_swap_table, update_cache));
  }
  return node_pair_list;
}

void CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node) {
  // Create control depend
  MS_EXCEPTION_IF_NULL(main_graph);
  auto manager = main_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node};
  auto depend_cnode = main_graph->NewCNode(cd_inputs);
  if (!manager->Replace(behind_node, depend_cnode)) {
    MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed.";
  }
}

AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &invalid_nodes,
                        const AnfNodePtr &patron_node) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> make_tuple_list{NewValueNode(prim::kPrimMakeTuple)};
  std::copy(invalid_nodes.begin(), invalid_nodes.end(), std::back_inserter(make_tuple_list));
  auto make_tuple = graph->NewCNode(make_tuple_list);
  std::vector<AnfNodePtr> depend_list{NewValueNode(prim::kPrimDepend), patron_node, make_tuple};
  auto depend_cnode = graph->NewCNode(depend_list);
  depend_cnode->set_abstract(patron_node->abstract());
  return depend_cnode;
}

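// Find all SparseGatherV2/EmbeddingLookup CNodes whose table parameter is cache enabled, and verify that they
// all share the same indices input; mixing cached and non-cached lookups in one graph is rejected.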
CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const ParamSet &param_set) {
  size_t cnodes_size = cnodes.size();
  CNodePtrList sparse_gather_v2_with_cache;
  for (size_t i = 0; i < cnodes_size; ++i) {
    if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2) ||
        IsPrimitiveCNode(cnodes[i], prim::kPrimEmbeddingLookup)) {
      auto load_node = cnodes[i]->input(1);
      if (IsPrimitiveCNode(load_node, prim::kPrimCast)) {
        load_node = load_node->cast<CNodePtr>()->input(1);
      }
      if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
        auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
        if (param_set.find(param_node) != param_set.end()) {
          sparse_gather_v2_with_cache.push_back(cnodes[i]);
        } else {
          MS_LOG(EXCEPTION) << "EmbeddingLookup can not mix cache and non-cache params in the same graph.";
        }
      }
    }
  }
  if (sparse_gather_v2_with_cache.empty()) {
    MS_LOG(EXCEPTION) << "Cannot find SparseGatherV2 with a cache-enabled param.";
  }

  auto indices = sparse_gather_v2_with_cache[0]->input(2);
  for (auto &ele : sparse_gather_v2_with_cache) {
    if (ele->input(2) != indices) {
      MS_LOG(EXCEPTION) << "SparseGatherV2 nodes with cache-enabled params have different indices!";
    }
  }
  return sparse_gather_v2_with_cache;
}

AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(graph);
  AnfNodePtrList gatherv2_nodes;
  auto user_set = graph->manager()->node_users()[node];
  for (auto &ele : user_set) {
    if (IsPrimitiveCNode(ele.first, prim::kPrimGather)) {
      gatherv2_nodes.emplace_back(ele.first);
    }
  }
  if (gatherv2_nodes.size() != 1) {
    MS_LOG(EXCEPTION) << "SparseGatherV2 with cache can only be used by exactly one GatherV2, but got "
                      << gatherv2_nodes.size();
  }
  return gatherv2_nodes[0];
}

AnfSet FindNoRefParams(const FuncGraphPtr &graph) {
  AnfSet no_ref_params;
  auto params = graph->parameters();
  for (auto &anf_param : params) {
    auto param = anf_param->cast<ParameterPtr>();
    if (!param->has_default()) {
      MS_LOG(INFO) << param->DebugString() << " has no default";
      no_ref_params.insert(anf_param);
    }
  }
  return no_ref_params;
}

void RemoveOriginParamFromSet(const CNodePtr &unique_node, AnfSet *no_ref_params) {
  std::queue<CNodePtr> que;
  que.push(unique_node);
  while (!que.empty()) {
    auto node = que.front();
    que.pop();
    auto node_inputs = node->inputs();
    for (auto &input : node_inputs) {
      if (input->isa<CNode>()) {
        que.push(input->cast<CNodePtr>());
      } else if (input->isa<Parameter>()) {
        size_t num = no_ref_params->erase(input);
        if (num > 0) {
          MS_LOG(INFO) << "Erase unique_node input from set success.";
          return;
        }
      }
    }
  }
  MS_LOG(EXCEPTION) << "Cannot find any parameter that is used by Unique.";
}

AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &ori_input, const std::string &name) {
  auto ori_input_type = ori_input->Type();
  auto ori_input_element_type = ori_input_type->cast<mindspore::TensorTypePtr>()->element();
  auto ori_input_type_id = ori_input_element_type->type_id();
  auto ori_input_shp = ori_input->Shape();
  auto input_shp = ori_input_shp->cast<abstract::ShapePtr>();
  auto input_shape = input_shp->shape();
  auto new_tensor = std::make_shared<tensor::Tensor>(ori_input_type_id, input_shape);
  ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
  auto new_param_name = name + "_pipe";
  new_param_info->set_name(new_param_name);
  new_tensor->set_param_info(new_param_info);
  auto new_param = graph->AddWeightParameter(new_param_name);
  new_param->set_default_param(MakeValue(new_tensor));
  auto abs_tensor = new_tensor->ToAbstract();
  new_param->set_abstract(abs_tensor);
  return new_param->cast<AnfNodePtr>();
}

AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) {
  AnfMap no_ref_pipe_param_map;
  for (auto &param : no_ref_params) {
    auto ori_param = param->cast<ParameterPtr>();
    auto ori_name = ori_param->name();
    auto new_param = CreateOutputNodeParam(graph, param, ori_name);
    no_ref_pipe_param_map[param] = new_param;
  }
  return no_ref_pipe_param_map;
}

AnfNodePtr CreateAssign(const FuncGraphPtr &graph, const AnfNodePtr &res_param, const AnfNodePtr &src_param,
                        bool is_dynamic = false) {
  auto assign_prim = prim::kPrimAssign;
  if (is_dynamic) {
    assign_prim = prim::kPrimDynamicAssign;
    assign_prim->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  }
  std::vector<AnfNodePtr> assign_nodes{NewValueNode(assign_prim), res_param, src_param};
  auto assign_status = graph->NewCNode(assign_nodes);
  return assign_status;
}

AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t index) {
  auto manager = graph->manager();
  auto node_users = manager->node_users()[node];
  for (auto &node_user : node_users) {
    if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
      auto cnode = node_user.first->cast<CNodePtr>();
      auto node_index = cnode->input(2);
      if (node_index->isa<ValueNode>()) {
        auto value_node = node_index->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(value_node);
        auto item_idx = GetValue<int64_t>(value_node->value());
        if (item_idx == index) {
          return node_user.first;
        }
      }
    }
  }
  MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << ", outputs:" << index;
}

void ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
                          const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
                          const AnfNodePtr &sparse_gatherv2_indices) {
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto node_users = manager->node_users();
  // add other no ref pipe param and unique index dense
  for (auto &ele : no_ref_pipe_param_map) {
    auto user_set = node_users[ele.first];
    auto assign_status = CreateAssign(graph, ele.second, ele.first);
    for (auto user_node : user_set) {
      CreateControlDepend(graph, user_node.first, assign_status);
    }
    if (!manager->Replace(ele.first, ele.second)) {
      MS_LOG(EXCEPTION) << "pipe param: " << ele.first->DebugString() << ", replace node failed.";
    }
  }

  // add cache idx param
  auto dynamic_assign_status = CreateAssign(graph, cache_idx_param, cache_idx, true);
  auto indices_user_set = node_users[sparse_gatherv2_indices];
  for (auto &user_node : indices_user_set) {
    CreateControlDepend(graph, user_node.first, dynamic_assign_status);
  }
  if (!manager->Replace(sparse_gatherv2_indices, cache_idx_param)) {
    MS_LOG(EXCEPTION) << "cache idx param: " << cache_idx_param->DebugString() << ", replace node failed.";
  }
}

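// Training rewrite: add cache parameters and replace the host params, insert MapCacheIdx after Unique, build
// the swap/update pipeline for every param pair, and hang the side-effect nodes off the graph output with a
// Depend. In pipe mode, extra "_pipe" parameters decouple the CPU cache ops from the device step.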
void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNodePtrList &cnodes,
                            const CNodePtr &unique_node, const ParamSet &param_cache_enable_set) {
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  size_t cnodes_size = cnodes.size();
  auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
  auto param_set = MapKeysToSet(cache_host_params_map);
  ReplaceCacheParams(graph, cache_host_params_map);
  graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
  MS_LOG(INFO) << "Graph is set cache enable.";

  CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
  auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
  auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);

  AnfNodePtrList map_cache_idx_node_outputs;
  CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);

  auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
  AnfNodePtrList invalid_nodes;
  auto cache_idx = map_cache_idx_node_outputs[0];
  if (!is_pipe) {
    if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), cache_idx)) {
      MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
    }
    for (auto &ele : node_pair_list) {
      for (auto &gather_op : sparse_gatherv2_with_cache) {
        CreateControlDepend(graph, ele.first, gather_op);
      }
      invalid_nodes.emplace_back(ele.second);
    }
  } else {
    auto cache_idx_param = CreateOutputNodeParam(graph, unique_node->input(1), std::string("cache_idx"));
    auto unique_index_reverse = FindCNodeOutput(graph, unique_node, 1);
    auto unique_index_param = CreateOutputNodeParam(graph, unique_index_reverse, std::string("index_dense"));
    auto no_ref_params = FindNoRefParams(graph);
    RemoveOriginParamFromSet(unique_node, &no_ref_params);
    auto no_ref_param_map = CreateOtherPipeParams(graph, no_ref_params);
    no_ref_param_map[unique_index_reverse] = unique_index_param;
    ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx, sparse_gatherv2_with_cache[0]->input(2));
    std::transform(node_pair_list.begin(), node_pair_list.end(), std::back_inserter(invalid_nodes),
                   [](const std::pair<AnfNodePtr, AnfNodePtr> &pair) { return pair.second; });
  }
  AnfNodePtr last_node = cnodes[cnodes_size - 1];
  CNodePtr return_node;
  if (last_node->isa<CNode>()) {
    return_node = last_node->cast<CNodePtr>();
  }
  MS_EXCEPTION_IF_NULL(return_node);
  if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
    MS_LOG(EXCEPTION) << "The last cnode after sorting is not a Return cnode.";
  }
  if (return_node->inputs().size() < 2) {
    MS_LOG(EXCEPTION) << "Number of return node inputs should be greater than or equal to 2.";
  }

  auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
  if (!manager->Replace(return_node->input(1), depend_node)) {
    MS_LOG(EXCEPTION) << "Depend replace node failed";
  }
}

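// Inference rewrite: no cache is needed, so every GatherV2 fed by a cache-enabled SparseGatherV2 is simply
// replaced with an EmbeddingLookup(CPU) over the full host table, matching the Host-Device mode network.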
void CacheEmbeddingForEval(const FuncGraphPtr &graph, const CNodePtrList &cnodes, const CNodePtr &unique_node,
                           const ParamSet &param_cache_enable_set) {
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
  MS_LOG(INFO) << "Graph is set cache enable.";
  // replace GatherV2 with EmbeddingLookup(CPU)
  auto indices = unique_node->input(1);
  auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
  for (auto &ele : sparse_gatherv2_with_cache) {
    auto anf_ele = ele->cast<AnfNodePtr>();
    auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
    auto embedding_lookup = CreateEmbeddingLookup(graph, ele->input(1), indices);
    if (!manager->Replace(gatherv2, embedding_lookup)) {
      MS_LOG(EXCEPTION) << "EmbeddingLookup replace node failed";
    }
  }
}

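// Entry point of the pass: if any parameter has cache enabled (and the host/cache sizes are sane), find the
// cache-enabled Unique op and dispatch to the training or inference rewrite.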
void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe) {
  MS_EXCEPTION_IF_NULL(graph);
  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
  CNodePtrList cnodes(orders.begin(), orders.end());
  bool training = graph->has_flag("training");
  auto param_cache_enable_set = FindParamCacheEnable(graph);
  if (param_cache_enable_set.empty()) {
    MS_LOG(INFO) << "No parameter has cache enabled.";
    return;
  } else {
    MS_LOG(INFO) << "Parameters have cache enabled.";
  }
  if (!CheckHostCacheParamSize(param_cache_enable_set)) {
    return;
  }
  for (auto &node : cnodes) {
    if (IsPrimitiveCNode(node, prim::kPrimNPUAllocFloatStatus)) {
      MS_LOG(EXCEPTION) << "Cache embedding does not support loss scale yet.";
    }
  }
  auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
  if (unique_cache_enable.empty()) {
    MS_LOG(WARNING) << "Parameters have cache enabled, but no Unique op with cache enable was found.";
    return;
  }
  auto unique_node = unique_cache_enable[0];
  if (training) {
    // For training, create cache parameters corresponding to the host params that have cache_enable set and
    // replace the host params with them. Create the hashmap, then insert a MapCacheIdx op after the Unique op
    // that carries the 'cache_enable' attr.
    // Bind the hashmap and cache tensor pointers to the host tensor so that values can be flushed from cache
    // to host at the end of each epoch.
    // Create EmbeddingLookup(CPU), CacheSwapTable(Ascend) and UpdateCache(CPU) for each pair of params, in order
    // to flush missing values into the cache params and write old values back to the host params.
    // Without pipe mode, EmbeddingLookup and CacheSwapTable must execute before SparseGatherV2, so add a
    // ControlDepend between them, and add a Depend on the UpdateCache and ControlDepend ops to keep those nodes
    // in the graph.
    // With pipe mode, create "_pipe" parameters for no-ref params such as labels, MapCacheIdx output[0] and
    // Unique output[1]; each step then trains on the data of the previous step, which hides the latency of
    // Unique and the other CPU kernels. The first step therefore runs on fake data.
    CacheEmbeddingForTrain(graph, is_pipe, cnodes, unique_node, param_cache_enable_set);
  } else {
    // For inference, use an EmbeddingLookup(CPU) op to replace GatherV2.
    // The network is then the same as in Host-Device mode.
    CacheEmbeddingForEval(graph, cnodes, unique_node, param_cache_enable_set);
  }
}
}  // namespace parallel
}  // namespace mindspore