1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/cache_embedding/cache_embedding.h"
18 #include <random>
19 #include <vector>
20 #include <list>
21 #include <queue>
22 #include <utility>
23 #include <memory>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <string>
27 #include <algorithm>
28 #include "backend/optimizer/common/helper.h"
29 #include "frontend/optimizer/optimizer.h"
30 #include "ir/func_graph.h"
31 #include "utils/cache_embedding_hashmap_struct.h"
32 namespace mindspore {
33 namespace parallel {
34 using ParamMap = std::unordered_map<ParameterPtr, ParameterPtr>;
35 using ParamSet = std::unordered_set<ParameterPtr>;
36 using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>;
37 using AnfMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
38 using AnfSet = std::unordered_set<AnfNodePtr>;
39
AddCacheParameters(const FuncGraphPtr & graph,const ParamSet & parameter_cache_enable_set)40 ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet ¶meter_cache_enable_set) {
41 ParamMap cache_host_params_map;
42 for (auto ¶m : parameter_cache_enable_set) {
43 auto param_info = param->param_info();
44 if (param_info && param_info->cache_enable()) {
45 auto data_type = param->Type();
46 auto data_element_type = data_type->cast<mindspore::TensorTypePtr>()->element();
47 auto type_id = data_element_type->type_id();
48 auto cache_shape = param_info->cache_shape();
49 auto ori_param_name = param->name();
50 auto new_tensor = std::make_shared<tensor::Tensor>(type_id, cache_shape);
51 ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
52 auto cache_name = ori_param_name + "_cache";
53 new_param_info->set_name(cache_name);
54 new_tensor->set_param_info(new_param_info);
55 auto cache_param = graph->AddWeightParameter(cache_name);
56 cache_param->set_default_param(MakeValue(new_tensor));
57 cache_param->set_abstract(new_tensor->ToAbstract());
58 cache_host_params_map[cache_param] = param;
59 }
60 }
61 return cache_host_params_map;
62 }
63
CheckHostCacheParamSize(const ParamSet & parameter_cache_enable_set)64 bool CheckHostCacheParamSize(const ParamSet ¶meter_cache_enable_set) {
65 int64_t host_size = 0;
66 int64_t cache_size = 0;
67 for (auto &host_param : parameter_cache_enable_set) {
68 auto tmp_host_size = host_param->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
69 auto host_param_info = host_param->param_info();
70 auto cache_shape = host_param_info->cache_shape();
71 if (cache_shape.empty()) {
72 MS_LOG(EXCEPTION) << "The value of cache_shape is empty.";
73 }
74 auto tmp_cache_size = cache_shape[0];
75 if ((host_size != 0 && tmp_host_size != host_size) || (cache_size != 0 && tmp_cache_size != cache_size)) {
76 MS_LOG(EXCEPTION)
77 << "If EmbeddingLookup are cache enable, vocab_size and vocab_cache_size of different cells must be the same.";
78 }
79 cache_size = tmp_cache_size;
80 host_size = tmp_host_size;
81 }
82 if (cache_size > host_size) {
83 MS_LOG(WARNING) << "vocab_cache_size > vocab_size, there is no need use cache.";
84 return false;
85 }
86 return true;
87 }
88
ReplaceCacheParams(const FuncGraphPtr & graph,const ParamMap & map)89 void ReplaceCacheParams(const FuncGraphPtr &graph, const ParamMap &map) {
90 auto manager = graph->manager();
91 MS_EXCEPTION_IF_NULL(manager);
92 for (auto &ele : map) {
93 if (!manager->Replace(ele.second, ele.first)) {
94 MS_LOG(EXCEPTION) << "host param: " << ele.second->name() << ", replace node failed.";
95 }
96 }
97 }
98
MapKeysToSet(const ParamMap & map)99 ParamSet MapKeysToSet(const ParamMap &map) {
100 ParamSet set;
101 for (auto &ele : map) {
102 set.insert(ele.first);
103 }
104 return set;
105 }
106
FindParamCacheEnable(const FuncGraphPtr & graph)107 ParamSet FindParamCacheEnable(const FuncGraphPtr &graph) {
108 ParamSet parameter_cache_enable_set;
109 auto parameters = graph->parameters();
110 auto params_size = parameters.size();
111 for (size_t i = 0; i < params_size; ++i) {
112 auto param = parameters[i]->cast<ParameterPtr>();
113 auto param_info = param->param_info();
114 if (param_info && param_info->cache_enable()) {
115 parameter_cache_enable_set.insert(param);
116 }
117 }
118 return parameter_cache_enable_set;
119 }
120
CNodePtrList FindUniqueCacheEnable(const CNodePtrList &cnodes) {
  // Collect Unique cnodes carrying a truthy cache_enable attribute.  At most
  // one such node is supported per graph.
  CNodePtrList unique_cache_enable;
  for (const auto &cnode : cnodes) {
    if (!IsPrimitiveCNode(cnode, prim::kPrimUnique)) {
      continue;
    }
    auto unique_prim = GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(unique_prim);
    auto attr_value = unique_prim->GetAttr(kAttrCacheEnable);
    if (attr_value != nullptr && GetValue<bool>(attr_value)) {
      unique_cache_enable.emplace_back(cnode);
    }
  }
  if (unique_cache_enable.size() > 1) {
    MS_LOG(EXCEPTION) << "Support only one of Unique op cache enable, but got " << unique_cache_enable.size();
  }
  return unique_cache_enable;
}
140
141 template <typename T>
MemCopyFromHostToCache(void * hashmap_addr,void * host_addr,void * cache_addr,size_t host_max,size_t cache_max,size_t hashmap_size,size_t col_size)142 void MemCopyFromHostToCache(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
143 size_t hashmap_size, size_t col_size) {
144 auto host_data = static_cast<char *>(host_addr);
145 auto cache_data = static_cast<char *>(cache_addr);
146 auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
147 // default param type float
148 const size_t param_type_size = 4;
149 size_t single_col_bytes = param_type_size * col_size;
150 for (size_t i = 0; i < hashmap_size; ++i) {
151 if (!hashmap_data[i].IsEmpty()) {
152 size_t host_offset = single_col_bytes * static_cast<size_t>(hashmap_data[i].key_);
153 size_t cache_offset = single_col_bytes * static_cast<size_t>(hashmap_data[i].value_);
154 if (host_offset + single_col_bytes <= host_max) {
155 auto ret =
156 memcpy_s(cache_data + cache_offset, cache_max - cache_offset, host_data + host_offset, single_col_bytes);
157 if (ret != 0) {
158 MS_LOG(EXCEPTION) << "Memcpy failed.";
159 }
160 }
161 }
162 }
163 MS_LOG(INFO) << "Memcpy from cache to host success!";
164 }
165
BindAndInitCacheTensor(const ParamMap & param_pair_list,const ParameterPtr & hashmap)166 void BindAndInitCacheTensor(const ParamMap ¶m_pair_list, const ParameterPtr &hashmap) {
167 auto hashmap_tensor_value = hashmap->default_param();
168 auto hashmap_tensor = hashmap_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
169 auto hashmap_size = hashmap_tensor->shape_c()[0];
170 auto hashmap_data_type = hashmap_tensor->data_type();
171 for (auto &ele : param_pair_list) {
172 auto host_tensor_value = ele.second->default_param();
173 auto host_tensor = host_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
174 auto cache_tensor_value = ele.first->default_param();
175 auto cache_tensor = cache_tensor_value->cast<std::shared_ptr<tensor::Tensor>>();
176
177 // bind host, cache, hashmap
178 host_tensor->set_cache_enable(true);
179 host_tensor->set_hashmap_tensor_ptr(hashmap_tensor);
180 host_tensor->set_cache_tensor_ptr(cache_tensor);
181
182 // init cache tensor data
183 auto host_shape = host_tensor->shape_c();
184 auto cache_shape = cache_tensor->shape_c();
185 if (host_shape.size() != 2 && host_shape.size() != 2 && host_shape[1] != cache_shape[1]) {
186 MS_LOG(EXCEPTION) << "Got host shape and cache shape invalid."
187 << "host shape:" << host_shape << ", cache shape:" << cache_shape;
188 }
189 auto host_data_max_size = static_cast<size_t>(host_tensor->Size());
190 auto cache_data_max_size = static_cast<size_t>(cache_tensor->Size());
191 if (hashmap_data_type == TypeId::kNumberTypeInt32) {
192 MemCopyFromHostToCache<int32_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(),
193 host_data_max_size, cache_data_max_size, LongToSize(hashmap_size),
194 LongToSize(host_shape[1]));
195 } else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
196 MemCopyFromHostToCache<int32_t>(hashmap_tensor->data_c(), host_tensor->data_c(), cache_tensor->data_c(),
197 host_data_max_size, cache_data_max_size, LongToSize(hashmap_size),
198 LongToSize(host_shape[1]));
199 } else {
200 MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64.";
201 }
202 }
203 }
204
205 template <typename T>
InitHashMapData(void * data,const int64_t host_size,const int64_t cache_size,const size_t hashmap_size,const size_t byte_size)206 void InitHashMapData(void *data, const int64_t host_size, const int64_t cache_size, const size_t hashmap_size,
207 const size_t byte_size) {
208 MS_LOG(INFO) << "Start init hashmap data.";
209 MS_EXCEPTION_IF_NULL(data);
210 HashmapEntry<T> *hashmap_data = static_cast<HashmapEntry<T> *>(data);
211 MS_EXCEPTION_IF_NULL(hashmap_data);
212 int ret = memset_s(hashmap_data, byte_size, 0, byte_size);
213 if (ret != 0) {
214 MS_LOG(EXCEPTION) << "Memset failed.";
215 }
216 std::vector<T> host_range;
217 host_range.reserve(static_cast<T>(host_size));
218 for (int64_t i = 0; i < host_size; ++i) {
219 host_range.emplace_back(static_cast<T>(i));
220 }
221 std::random_shuffle(host_range.begin(), host_range.end());
222 size_t size = static_cast<size_t>(cache_size);
223 size_t hashmap_count = 0;
224 for (size_t i = 0; i < size; ++i) {
225 auto random_key = host_range[i];
226 auto entry = HashFunc(random_key, hashmap_size);
227 size_t count = 1;
228 while (!hashmap_data[entry].IsEmpty() && !hashmap_data[entry].IsKey(random_key)) {
229 count += 1;
230 entry = (entry + 1) % static_cast<T>(hashmap_size);
231 }
232 if (hashmap_data[entry].IsEmpty()) {
233 hashmap_count++;
234 hashmap_data[entry].key_ = random_key;
235 hashmap_data[entry].value_ = static_cast<T>(i);
236 hashmap_data[entry].step_ = kInitStep;
237 hashmap_data[entry].tag_ = static_cast<T>(count);
238 }
239 }
240 MS_LOG(INFO) << "Hashmap init success, with " << hashmap_count << " / " << hashmap_size;
241 }
242
InitHashMap(const FuncGraphPtr & func_graph,const int64_t host_size,const int64_t cache_size,TypeId type_id)243 AnfNodePtr InitHashMap(const FuncGraphPtr &func_graph, const int64_t host_size, const int64_t cache_size,
244 TypeId type_id) {
245 // init new tensor
246 size_t hashmap_size = static_cast<size_t>(cache_size * kEmptyRate);
247 std::vector<int64_t> host_shape{static_cast<int64_t>(hashmap_size), 4};
248 auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
249 size_t byte_size = new_tensor->Size();
250 if (type_id == TypeId::kNumberTypeInt64) {
251 InitHashMapData<int64_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
252 } else {
253 InitHashMapData<int32_t>(new_tensor->data_c(), host_size, cache_size, hashmap_size, byte_size);
254 }
255 ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
256 std::string hashmap_name = "cache_hashmap";
257 new_param_info->set_name(hashmap_name);
258 new_tensor->set_param_info(new_param_info);
259 auto hashmap = func_graph->AddWeightParameter(hashmap_name);
260 hashmap->set_default_param(MakeValue(new_tensor));
261 hashmap->set_abstract(new_tensor->ToAbstract());
262 return hashmap;
263 }
264
InitStep(const FuncGraphPtr & func_graph,TypeId type_id)265 AnfNodePtr InitStep(const FuncGraphPtr &func_graph, TypeId type_id) {
266 std::vector<int64_t> host_shape{1};
267 auto new_tensor = std::make_shared<tensor::Tensor>(type_id, host_shape);
268 ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
269 std::string step_name = "cache_step";
270 new_param_info->set_name(step_name);
271 new_tensor->set_param_info(new_param_info);
272 auto step = func_graph->AddWeightParameter(step_name);
273 step->set_default_param(MakeValue(new_tensor));
274 step->set_abstract(new_tensor->ToAbstract());
275 return step;
276 }
277
// Build the MapCacheIdx(hashmap, indices, step, max_num, offset) cnode that
// translates host embedding ids into cache ids.  Also creates and binds the
// hashmap / step parameters and seeds the cache tensors.  Returns the
// MapCacheIdx node, whose tuple abstract is
// [cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx].
AnfNodePtr CreateMapCacheIdx(const FuncGraphPtr &func_graph, const AnfNodePtr &indices,
                             const ParamMap &cache_host_params_map) {
  // All entries share the same sizes (checked by CheckHostCacheParamSize),
  // so the first pair is representative.
  auto iter = cache_host_params_map.begin();
  int64_t cache_size = iter->first->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
  int64_t host_size = iter->second->abstract()->GetShapeTrack()->cast<abstract::ShapePtr>()->shape()[0];
  auto indices_type = indices->Type();
  auto indices_element_type = indices_type->cast<mindspore::TensorTypePtr>()->element();
  auto indices_type_id = indices_element_type->type_id();
  auto hashmap = InitHashMap(func_graph, host_size, cache_size, indices_type_id);
  auto step = InitStep(func_graph, indices_type_id);
  auto max_num = NewValueNode(MakeValue(host_size));
  auto hashmap_param = hashmap->cast<ParameterPtr>();
  BindAndInitCacheTensor(cache_host_params_map, hashmap_param);
  // add rank_id
  // Each rank owns a host_size-sized slice of the global table, so the id
  // offset is rank_id * host_size.  NOTE(review): atoi silently yields 0 on a
  // malformed RANK_ID — presumably acceptable; confirm.
  int64_t offset_value = 0;
  std::string rank_id_str = common::GetEnv("RANK_ID");
  if (!rank_id_str.empty()) {
    int64_t rank_id = atoi(rank_id_str.c_str());
    offset_value = rank_id * host_size;
  }
  auto offset = NewValueNode(MakeValue(offset_value));
  auto max_num_imm = std::make_shared<Int64Imm>(host_size);
  auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
  max_num->set_abstract(max_num_abstract_scalar);
  auto offset_imm = std::make_shared<Int64Imm>(offset_value);
  auto offset_abstract_scalar = std::make_shared<abstract::AbstractScalar>(offset_imm);
  offset->set_abstract(offset_abstract_scalar);

  // MapCacheIdx runs on CPU.  NOTE(review): set_attr mutates the globally
  // shared kPrimMapCacheIdx primitive rather than a clone — verify this is
  // intended.
  PrimitivePtr map_cache_primitive = prim::kPrimMapCacheIdx;
  map_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  std::vector<AnfNodePtr> map_cache_nodes{NewValueNode(map_cache_primitive), hashmap, indices, step, max_num, offset};
  auto map_cache_idx = func_graph->NewCNode(map_cache_nodes);

  // The three swap outputs have dynamic first dimensions: shape is all -1,
  // with min 1 per dim and max taken from the indices shape.
  auto indices_ori_shp = indices->Shape();
  auto indices_shp = indices_ori_shp->cast<abstract::ShapePtr>();
  ShapeVector shape;
  ShapeVector min_shape;
  ShapeVector max_shape;
  if (!indices_shp->max_shape().empty()) {
    max_shape = indices_shp->max_shape();
  } else {
    max_shape = indices_shp->shape();
  }
  for (size_t i = 0; i < max_shape.size(); i++) {
    shape.emplace_back(-1);
    min_shape.emplace_back(1);
  }

  // Output 0 (cache_idx) keeps the static indices shape; outputs 1-3 are the
  // dynamically-sized old/miss/swap index lists.
  auto cache_idx = std::make_shared<abstract::AbstractTensor>(indices_element_type, indices_shp);
  auto old_emb_idx = std::make_shared<abstract::AbstractTensor>(
    indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
  auto miss_emb_idx = std::make_shared<abstract::AbstractTensor>(
    indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));
  auto swap_emb_idx = std::make_shared<abstract::AbstractTensor>(
    indices_element_type, std::make_shared<abstract::Shape>(shape, min_shape, max_shape));

  std::vector<std::shared_ptr<abstract::AbstractBase>> elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
  auto abstract = std::make_shared<abstract::AbstractTuple>(elements);
  map_cache_idx->set_abstract(abstract);
  return map_cache_idx;
}
339
AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
  // Build TupleGetItem(input, index) and set its abstract from the matching
  // element of the input's tuple abstract.
  MS_EXCEPTION_IF_NULL(func_graph);
  auto index_value_node = NewValueNode(SizeToLong(index));
  MS_EXCEPTION_IF_NULL(index_value_node);
  auto index_imm = std::make_shared<Int64Imm>(SizeToLong(index));
  index_value_node->set_abstract(std::make_shared<abstract::AbstractScalar>(index_imm));
  auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, index_value_node});
  auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
  tuple_getitem->set_abstract(input_abstract_tuple->elements()[index]);
  return tuple_getitem;
}
353
CreateTupleGetItems(const FuncGraphPtr & func_graph,const AnfNodePtr & input,std::vector<AnfNodePtr> * outputs)354 void CreateTupleGetItems(const FuncGraphPtr &func_graph, const AnfNodePtr &input, std::vector<AnfNodePtr> *outputs) {
355 auto input_abstract_tuple = dyn_cast<abstract::AbstractTuple>(input->abstract());
356 auto size = input_abstract_tuple->elements().size();
357 MS_EXCEPTION_IF_NULL(outputs);
358 for (size_t i = 0; i < size; ++i) {
359 (*outputs).emplace_back(CreateTupleGetItem(func_graph, input, i));
360 }
361 }
362
CreateEmbeddingLookup(const FuncGraphPtr & graph,AnfNodePtr params,AnfNodePtr indices)363 AnfNodePtr CreateEmbeddingLookup(const FuncGraphPtr &graph, AnfNodePtr params, AnfNodePtr indices) {
364 MS_EXCEPTION_IF_NULL(graph);
365 PrimitivePtr emb_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
366 emb_lookup_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
367 emb_lookup_primitive->set_attr(kAttrOffset, MakeValue<int64_t>(0));
368 std::vector<AnfNodePtr> emb_lookup_nodes{NewValueNode(emb_lookup_primitive), params, indices};
369 auto emb_lookup = graph->NewCNode(emb_lookup_nodes);
370 return emb_lookup;
371 }
372
CreateCacheSwapTable(const FuncGraphPtr & graph,ParameterPtr cache_table,AnfNodePtr swap_cache_idx,AnfNodePtr miss_value)373 AnfNodePtr CreateCacheSwapTable(const FuncGraphPtr &graph, ParameterPtr cache_table, AnfNodePtr swap_cache_idx,
374 AnfNodePtr miss_value) {
375 MS_EXCEPTION_IF_NULL(graph);
376 PrimitivePtr cache_swap_table_primitive = std::make_shared<Primitive>(kCacheSwapTableOpName);
377 std::vector<AnfNodePtr> cache_swap_table_nodes{NewValueNode(cache_swap_table_primitive), cache_table, swap_cache_idx,
378 miss_value};
379 auto cache_swap_table = graph->NewCNode(cache_swap_table_nodes);
380 return cache_swap_table;
381 }
382
CreateUpdateCache(const FuncGraphPtr & graph,ParameterPtr params,AnfNodePtr old_emb_idx,AnfNodePtr old_value)383 AnfNodePtr CreateUpdateCache(const FuncGraphPtr &graph, ParameterPtr params, AnfNodePtr old_emb_idx,
384 AnfNodePtr old_value) {
385 MS_EXCEPTION_IF_NULL(graph);
386 PrimitivePtr update_cache_primitive = std::make_shared<Primitive>(kUpdateCacheOpName);
387 update_cache_primitive->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
388
389 auto params_ori_shp = params->Shape();
390 MS_EXCEPTION_IF_NULL(params_ori_shp);
391 auto params_shp = params_ori_shp->cast<abstract::ShapePtr>();
392 MS_EXCEPTION_IF_NULL(params_shp);
393 auto params_shape = params_shp->shape();
394 auto max_size = params_shape[0];
395 auto max_size_node = NewValueNode(MakeValue(max_size));
396 auto max_num_imm = std::make_shared<Int64Imm>(max_size);
397 auto max_num_abstract_scalar = std::make_shared<abstract::AbstractScalar>(max_num_imm);
398 max_size_node->set_abstract(max_num_abstract_scalar);
399
400 std::vector<AnfNodePtr> update_cache_nodes{NewValueNode(update_cache_primitive), params, old_emb_idx, old_value,
401 max_size_node};
402 auto update_cache = graph->NewCNode(update_cache_nodes);
403 return update_cache;
404 }
405
CreateEmbSwapUpdate(const FuncGraphPtr & graph,ParamMap param_pair_list,const AnfNodePtrList & map_cache_idx_node_outputs)406 NodePairList CreateEmbSwapUpdate(const FuncGraphPtr &graph, ParamMap param_pair_list,
407 const AnfNodePtrList &map_cache_idx_node_outputs) {
408 MS_EXCEPTION_IF_NULL(graph);
409 NodePairList node_pair_list;
410 for (auto &ele : param_pair_list) {
411 auto emb_lookup = CreateEmbeddingLookup(graph, ele.second, map_cache_idx_node_outputs[2]);
412 auto cache_swap_table = CreateCacheSwapTable(graph, ele.first, map_cache_idx_node_outputs[3], emb_lookup);
413 auto update_cache = CreateUpdateCache(graph, ele.second, map_cache_idx_node_outputs[1], cache_swap_table);
414 node_pair_list.emplace_back(std::make_pair(cache_swap_table, update_cache));
415 }
416 return node_pair_list;
417 }
418
CreateControlDepend(const FuncGraphPtr & main_graph,const AnfNodePtr & prior_node,const AnfNodePtr & behind_node)419 void CreateControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node) {
420 // Create control depend
421 MS_EXCEPTION_IF_NULL(main_graph);
422 auto manager = main_graph->manager();
423 MS_EXCEPTION_IF_NULL(manager);
424 AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node};
425 auto depend_cnode = main_graph->NewCNode(cd_inputs);
426 if (!manager->Replace(behind_node, depend_cnode)) {
427 MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed.";
428 }
429 }
430
AnfNodePtr CreateDepend(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &invalid_nodes,
                        const AnfNodePtr &patron_node) {
  // Keep invalid_nodes attached to the graph by returning
  // Depend(patron_node, MakeTuple(invalid_nodes...)); the result carries
  // patron_node's abstract.
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
  tuple_inputs.insert(tuple_inputs.end(), invalid_nodes.begin(), invalid_nodes.end());
  auto make_tuple = graph->NewCNode(tuple_inputs);
  auto depend_cnode = graph->NewCNode({NewValueNode(prim::kPrimDepend), patron_node, make_tuple});
  depend_cnode->set_abstract(patron_node->abstract());
  return depend_cnode;
}
442
// Collect every SparseGatherV2/EmbeddingLookup cnode whose weight input
// (possibly behind a Cast and a Load) is one of the cache-enabled parameters.
// Throws if none is found, if a gather uses a non-cached parameter (mixing is
// unsupported), or if the cached gathers disagree on their indices input.
CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const ParamSet &param_set) {
  size_t cnodes_size = cnodes.size();
  CNodePtrList sparse_gather_v2_with_cache;
  for (size_t i = 0; i < cnodes_size; ++i) {
    if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2) ||
        IsPrimitiveCNode(cnodes[i], prim::kPrimEmbeddingLookup)) {
      // The parameter may be wrapped in a Cast before the Load.
      auto load_node = cnodes[i]->input(1);
      if (IsPrimitiveCNode(load_node, prim::kPrimCast)) {
        load_node = load_node->cast<CNodePtr>()->input(1);
      }
      if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
        auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
        if (param_set.find(param_node) != param_set.end()) {
          sparse_gather_v2_with_cache.push_back(cnodes[i]);
        } else {
          // Fix: the old message's double negative ("can't not") inverted its meaning.
          MS_LOG(EXCEPTION) << "EmbeddingLookup can not support cache and no cache in the same graph.";
        }
      }
    }
  }
  if (sparse_gather_v2_with_cache.empty()) {
    MS_LOG(EXCEPTION) << "Can not find SparseGatherV2 with cache param.";
  }

  // All cache-enabled gathers must be fed by the same indices node.
  auto indices = sparse_gather_v2_with_cache[0]->input(2);
  for (auto &ele : sparse_gather_v2_with_cache) {
    if (ele->input(2) != indices) {
      MS_LOG(EXCEPTION) << "SparseGatherV2 which with cache param have different indices!.";
    }
  }
  return sparse_gather_v2_with_cache;
}
475
AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNodePtr &node) {
  // Locate the single Gather user of `node`; any other user count is an error.
  MS_EXCEPTION_IF_NULL(graph);
  AnfNodePtrList gatherv2_nodes;
  auto user_set = graph->manager()->node_users()[node];
  for (auto &user : user_set) {
    if (IsPrimitiveCNode(user.first, prim::kPrimGather)) {
      gatherv2_nodes.emplace_back(user.first);
    }
  }
  if (gatherv2_nodes.size() != 1) {
    MS_LOG(EXCEPTION) << "SparseGatherV2 with cache can only used by one of gatherv2, but got "
                      << gatherv2_nodes.size();
  }
  return gatherv2_nodes[0];
}
491
FindNoRefParams(const FuncGraphPtr & graph)492 AnfSet FindNoRefParams(const FuncGraphPtr &graph) {
493 AnfSet no_ref_params;
494 auto params = graph->parameters();
495 for (auto &anf_param : params) {
496 auto param = anf_param->cast<ParameterPtr>();
497 if (!param->has_default()) {
498 MS_LOG(INFO) << param->DebugString() << " has no default";
499 no_ref_params.insert(anf_param);
500 }
501 }
502 return no_ref_params;
503 }
504
RemoveOriginParamFromSet(const CNodePtr & unique_node,AnfSet * no_ref_params)505 void RemoveOriginParamFromSet(const CNodePtr &unique_node, AnfSet *no_ref_params) {
506 std::queue<CNodePtr> que;
507 que.push(unique_node);
508 while (!que.empty()) {
509 auto node = que.front();
510 que.pop();
511 auto node_inputs = node->inputs();
512 for (auto &input : node_inputs) {
513 if (input->isa<CNode>()) {
514 que.push(input->cast<CNodePtr>());
515 } else if (input->isa<Parameter>()) {
516 size_t num = no_ref_params->erase(input);
517 if (num > 0) {
518 MS_LOG(INFO) << "Erase unique_node input from set success.";
519 return;
520 }
521 }
522 }
523 }
524 MS_LOG(EXCEPTION) << "Can not find any parameter that use by Unique.";
525 }
526
AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &ori_input, const std::string &name) {
  // Create a weight parameter named "<name>_pipe" with the same dtype and
  // shape as ori_input, used to stage a value between steps.
  auto element_type = ori_input->Type()->cast<mindspore::TensorTypePtr>()->element();
  auto type_id = element_type->type_id();
  auto shape_ptr = ori_input->Shape()->cast<abstract::ShapePtr>();
  auto pipe_tensor = std::make_shared<tensor::Tensor>(type_id, shape_ptr->shape());
  auto pipe_name = name + "_pipe";
  ParamInfoPtr pipe_param_info = std::make_shared<ParamInfo>();
  pipe_param_info->set_name(pipe_name);
  pipe_tensor->set_param_info(pipe_param_info);
  auto pipe_param = graph->AddWeightParameter(pipe_name);
  pipe_param->set_default_param(MakeValue(pipe_tensor));
  pipe_param->set_abstract(pipe_tensor->ToAbstract());
  return pipe_param->cast<AnfNodePtr>();
}
545
CreateOtherPipeParams(const FuncGraphPtr & graph,const AnfSet & no_ref_params)546 AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) {
547 AnfMap no_ref_pipe_param_map;
548 for (auto ¶m : no_ref_params) {
549 auto ori_param = param->cast<ParameterPtr>();
550 auto ori_name = ori_param->name();
551 auto new_param = CreateOutputNodeParam(graph, param, ori_name);
552 no_ref_pipe_param_map[param] = new_param;
553 }
554 return no_ref_pipe_param_map;
555 }
556
AnfNodePtr CreateAssign(const FuncGraphPtr &graph, const AnfNodePtr &res_param, const AnfNodePtr &src_param,
                        bool is_dynamic = false) {
  // Build Assign(res_param, src_param); with is_dynamic, use DynamicAssign
  // pinned to CPU instead.
  // NOTE(review): set_attr mutates the globally shared kPrimDynamicAssign
  // primitive object — presumably harmless since the attribute value is
  // always the same, but worth confirming.
  auto assign_prim = prim::kPrimAssign;
  if (is_dynamic) {
    assign_prim = prim::kPrimDynamicAssign;
    assign_prim->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
  }
  AnfNodePtrList assign_inputs = {NewValueNode(assign_prim), res_param, src_param};
  auto assign_status = graph->NewCNode(assign_inputs);
  return assign_status;
}
568
// Return the TupleGetItem user of `node` that extracts tuple element `index`.
// Throws if no such user exists.
AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t index) {
  auto manager = graph->manager();
  auto node_users = manager->node_users()[node];
  for (auto &node_user : node_users) {
    if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
      auto cnode = node_user.first->cast<CNodePtr>();
      // Input 2 of TupleGetItem is the element index.
      auto node_index = cnode->input(2);
      if (node_index->isa<ValueNode>()) {
        auto value_node = node_index->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(value_node);
        auto item_idx = GetValue<int64_t>(value_node->value());
        if (item_idx == index) {
          return node_user.first;
        }
      }
    }
  }
  // Fix: removed the "Can't not" double negative from the error message.
  MS_LOG(EXCEPTION) << "Can not find " << node->DebugString() << ", outputs:" << index;
}
588
ReplaceNoRefToParams(const FuncGraphPtr & graph,const AnfMap & no_ref_pipe_param_map,const AnfNodePtr & cache_idx_param,const AnfNodePtr & cache_idx,const AnfNodePtr & sparse_gatherv2_indices)589 void ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
590 const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
591 const AnfNodePtr &sparse_gatherv2_indices) {
592 auto manager = graph->manager();
593 MS_EXCEPTION_IF_NULL(manager);
594 auto node_users = manager->node_users();
595 // add other no ref pipe param and unique index dense
596 for (auto &ele : no_ref_pipe_param_map) {
597 auto user_set = node_users[ele.first];
598 auto assign_status = CreateAssign(graph, ele.second, ele.first);
599 for (auto user_node : user_set) {
600 CreateControlDepend(graph, user_node.first, assign_status);
601 }
602 if (!manager->Replace(ele.first, ele.second)) {
603 MS_LOG(EXCEPTION) << "pipe param: " << ele.first->DebugString() << ", replace node failed.";
604 }
605 }
606
607 // add cache idx param
608 auto dynamic_assgin_status = CreateAssign(graph, cache_idx_param, cache_idx, true);
609 auto indices_user_set = node_users[sparse_gatherv2_indices];
610 for (auto &user_node : indices_user_set) {
611 CreateControlDepend(graph, user_node.first, dynamic_assgin_status);
612 }
613 if (!manager->Replace(sparse_gatherv2_indices, cache_idx_param)) {
614 MS_LOG(EXCEPTION) << "cache idx param: " << cache_idx_param->DebugString() << ", replace node failed.";
615 }
616 }
617
// Rewrite a training graph for embedding caching: replace the cache-enabled
// host parameters with cache parameters, insert a MapCacheIdx after Unique to
// translate host ids to cache ids, wire up the swap/update pipeline, and keep
// the side-effecting UpdateCache nodes alive via a Depend on the graph output.
// With is_pipe, inputs are additionally staged through "_pipe" parameters.
void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNodePtrList &cnodes,
                            const CNodePtr &unique_node, const ParamSet &param_cache_enable_set) {
  MS_EXCEPTION_IF_NULL(graph);
  auto manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  size_t cnodes_size = cnodes.size();
  auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
  auto param_set = MapKeysToSet(cache_host_params_map);
  ReplaceCacheParams(graph, cache_host_params_map);
  graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
  MS_LOG(INFO) << "Graph is set cache enable.";

  CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
  // Unique output 0 (the deduplicated ids) feeds MapCacheIdx.
  auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
  auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);

  AnfNodePtrList map_cache_idx_node_outputs;
  CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);

  auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
  AnfNodePtrList invalid_nodes;
  auto cache_idx = map_cache_idx_node_outputs[0];
  if (!is_pipe) {
    // Non-pipelined: feed the gathers the cache index directly and order each
    // CacheSwapTable before every gather via control depends; the UpdateCache
    // nodes are collected as otherwise-unused "invalid" nodes.
    if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), cache_idx)) {
      MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
    }
    for (auto &ele : node_pair_list) {
      for (auto &gather_op : sparse_gatherv2_with_cache) {
        CreateControlDepend(graph, ele.first, gather_op);
      }
      invalid_nodes.emplace_back(ele.second);
    }
  } else {
    // Pipelined: route the cache index and every no-default input through
    // "_pipe" staging parameters (see CreateOutputNodeParam /
    // ReplaceNoRefToParams); the Unique data input itself is excluded.
    auto cache_idx_param = CreateOutputNodeParam(graph, unique_node->input(1), std::string("cache_idx"));
    auto unique_index_reverse = FindCNodeOutput(graph, unique_node, 1);
    auto unique_index_param = CreateOutputNodeParam(graph, unique_index_reverse, std::string("index_dense"));
    auto no_ref_params = FindNoRefParams(graph);
    RemoveOriginParamFromSet(unique_node, &no_ref_params);
    auto no_ref_param_map = CreateOtherPipeParams(graph, no_ref_params);
    no_ref_param_map[unique_index_reverse] = unique_index_param;
    ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx, sparse_gatherv2_with_cache[0]->input(2));
    std::transform(node_pair_list.begin(), node_pair_list.end(), std::back_inserter(invalid_nodes),
                   [](const std::pair<AnfNodePtr, AnfNodePtr> &pair) { return pair.second; });
  }
  // cnodes is expected to be sorted so the last cnode is Return (the error
  // below says "after sorting").
  AnfNodePtr last_node = cnodes[cnodes_size - 1];
  CNodePtr return_node;
  if (last_node->isa<CNode>()) {
    return_node = last_node->cast<CNodePtr>();
  }
  MS_EXCEPTION_IF_NULL(return_node);
  if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
    MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode.";
  }
  if (return_node->inputs().size() < 2) {
    MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2.";
  }

  // Attach the collected side-effect nodes to the graph output so they are
  // not dropped as dead code.
  auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
  if (!manager->Replace(return_node->input(1), depend_node)) {
    MS_LOG(EXCEPTION) << "Depend replace node failed";
  }
}
680
CacheEmbeddingForEval(const FuncGraphPtr & graph,const CNodePtrList & cnodes,const CNodePtr & unique_node,const ParamSet & param_cache_enable_set)681 void CacheEmbeddingForEval(const FuncGraphPtr &graph, const CNodePtrList &cnodes, const CNodePtr &unique_node,
682 const ParamSet ¶m_cache_enable_set) {
683 MS_EXCEPTION_IF_NULL(graph);
684 auto manager = graph->manager();
685 MS_EXCEPTION_IF_NULL(manager);
686 graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
687 MS_LOG(INFO) << "Graph is set cache enable.";
688 // replace GatherV2 to EmbeddingLookupCPU
689 auto indices = unique_node->input(1);
690 auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
691 for (auto &ele : sparse_gatherv2_with_cache) {
692 auto anf_ele = ele->cast<AnfNodePtr>();
693 auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
694 auto embedding_lookup = CreateEmbeddingLookup(graph, ele->input(1), indices);
695 if (!manager->Replace(gatherv2, embedding_lookup)) {
696 MS_LOG(EXCEPTION) << "Depend replace node failed";
697 }
698 }
699 }
700
AddCacheEmbedding(const FuncGraphPtr & graph,bool is_pipe)701 void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe) {
702 MS_EXCEPTION_IF_NULL(graph);
703 std::list<CNodePtr> orders = graph->GetOrderedCnodes();
704 CNodePtrList cnodes(orders.begin(), orders.end());
705 bool training = graph->has_flag("training");
706 auto param_cache_enable_set = FindParamCacheEnable(graph);
707 if (param_cache_enable_set.empty()) {
708 MS_LOG(INFO) << "Parameters are all not cache enable.";
709 return;
710 } else {
711 MS_LOG(INFO) << "Parameters have cache enable.";
712 }
713 if (!CheckHostCacheParamSize(param_cache_enable_set)) {
714 return;
715 }
716 for (auto &node : cnodes) {
717 if (IsPrimitiveCNode(node, prim::kPrimNPUAllocFloatStatus)) {
718 MS_LOG(EXCEPTION) << "Cache embedding haven't support loss scale yet.";
719 }
720 }
721 auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
722 if (unique_cache_enable.empty()) {
723 MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
724 return;
725 }
726 auto unique_node = unique_cache_enable[0];
727 if (training) {
728 // If training, create cache parameters corresponding to the host params with is cache_enable.
729 // Replace the host params. Create hashmap then insert MapCacheIdx op after Unique with has 'cache_enable' attr.
730 // Bind hashmap tensor ptr and cache tensor ptr to host tensor, so that we can flush values
731 // from cache to host in each epoch end.
732 // Create EmbeddingLookup(CPU), CacheSwapTable(Ascend), UpdateCache(CPU) for each pair of params, in order to
733 // flush miss values to cache params and write back old values to host params.
734 // If no use pipe in training, EmbeddingLookup and CacheSwapTable must execute before SparseGatherV2, so add
735 // ControlDepend between them. And add Depend for UpdateCache op and ControlDepnd op to add nodes into graph.
736 // If use pipe in training, create parameters for no ref param such as labels and MapCacheIdx output[0] and
737 // Unique output[1], in each step, it will train the data from last step, so that can hide the time of Unique
738 // and other cpu kernels. So in the first step, it's fake data.
739 CacheEmbeddingForTrain(graph, is_pipe, cnodes, unique_node, param_cache_enable_set);
740 } else {
741 // If eval, Use EmbeddingLookup(CPU) op to replace GatherV2.
742 // The network is the same as Host-Device mode.
743 CacheEmbeddingForEval(graph, cnodes, unique_node, param_cache_enable_set);
744 }
745 }
746 } // namespace parallel
747 } // namespace mindspore
748