• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"
18 
19 #include <memory>
20 #include <string>
21 #include <algorithm>
22 #include <utility>
23 #include <functional>
24 
25 #include "ops/sparse_op_name.h"
26 #include "ops/nn_op_name.h"
27 #include "ops/array_op_name.h"
28 #include "ops/sequence_ops.h"
29 #include "ops/framework_ops.h"
30 #include "ir/func_graph.h"
31 #include "abstract/abstract_function.h"
32 #include "include/common/utils/anfalgo.h"
33 #include "include/common/utils/utils.h"
34 #include "utils/ms_context.h"
35 
36 #include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
37 
38 namespace mindspore {
39 namespace parallel {
40 // One dimensional shape placeholder.
41 const ShapeVector kOneDimShape = {1};
42 // Two dimensional shape placeholder.
43 const ShapeVector kTwoDimsShape = {1, 1};
44 
45 // One dimensional shape placeholder.
46 const ShapeVector kOneDimDynamicShape = {-1};
47 // Two dimensional shape placeholder.
48 const ShapeVector kTwoDimsDynamicShape = {-1, -1};
49 
50 // The output tensor number of recv node.
51 const size_t kRecvNodeOutputNum = 3;
52 
53 // The input index of offset of EmbeddingLookup kernel.
54 constexpr size_t kEmbeddingLookupOffsetIdx = 2;
55 
56 // The dims of embedding table.
57 constexpr size_t kEmbeddingTableDims = 2;
58 
59 constexpr char kEmbeddingRemoteCacheNode[] = "EmbeddingRemoteCacheNode";
60 constexpr char kEmbeddingLocalCacheNode[] = "EmbeddingLocalCacheNode";
61 constexpr char kEnvEmbeddingCacheMemSizeInGBytes[] = "MS_EMBEDDING_REMOTE_CACHE_MEMORY_SIZE";
62 
63 namespace {
CreateFakeValueNode(const AnfNodePtr & origin_node)64 ValueNodePtr CreateFakeValueNode(const AnfNodePtr &origin_node) {
65   MS_EXCEPTION_IF_NULL(origin_node);
66   abstract::AbstractTensorPtr origin_abstract = origin_node->abstract()->cast<abstract::AbstractTensorPtr>();
67 
68   MS_EXCEPTION_IF_NULL(origin_abstract);
69   tensor::TensorPtr fake_tensor = std::make_shared<tensor::Tensor>(origin_abstract->element()->BuildType()->type_id(),
70                                                                    origin_abstract->shape()->shape());
71   MS_EXCEPTION_IF_NULL(fake_tensor);
72   fake_tensor->set_base_shape(origin_abstract->shape()->Clone());
73 
74   auto fake_value = NewValueNode(fake_tensor);
75   MS_EXCEPTION_IF_NULL(fake_value);
76   fake_value->set_abstract(fake_tensor->ToAbstract());
77   return fake_value;
78 }
79 
// Build a placeholder output node that matches the abstract of origin_output.
// A tuple output is replaced element-wise by fake tensors (all elements must be
// tensors); any other output is delegated to CreateFakeValueNode.
AnfNodePtr CreateOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_output);
  MS_EXCEPTION_IF_NULL(origin_output->abstract());

  // Non-tuple outputs need only a single fake tensor value node.
  if (!origin_output->abstract()->isa<abstract::AbstractTuple>()) {
    return CreateFakeValueNode(origin_output);
  }

  const auto &tuple_elements = origin_output->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
  abstract::AbstractBasePtrList fake_abstracts;
  std::vector<ValuePtr> fake_values;
  for (const auto &elem : tuple_elements) {
    MS_EXCEPTION_IF_NULL(elem);
    auto elem_tensor_abs = elem->cast<abstract::AbstractTensorPtr>();
    if (!elem_tensor_abs) {
      MS_LOG(EXCEPTION) << "Only support to replace tuple with all tensor elements.";
    }
    auto placeholder = std::make_shared<tensor::Tensor>(elem_tensor_abs->element()->BuildType()->type_id(),
                                                        elem_tensor_abs->shape()->shape());
    MS_EXCEPTION_IF_NULL(placeholder);
    fake_abstracts.push_back(placeholder->ToAbstract());
    fake_values.push_back(placeholder);
  }

  auto tuple_value_node = NewValueNode(std::make_shared<ValueTuple>(fake_values));
  MS_EXCEPTION_IF_NULL(tuple_value_node);
  tuple_value_node->set_abstract(std::make_shared<abstract::AbstractTuple>(fake_abstracts));
  return tuple_value_node;
}
111 }  // namespace
112 
GetEmbeddingLookupNodes()113 void PsEmbeddingCacheInserter::GetEmbeddingLookupNodes() {
114   MS_EXCEPTION_IF_NULL(root_graph_);
115   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root_graph_->get_return());
116   (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
117     MS_EXCEPTION_IF_NULL(node);
118     if (!node->isa<CNode>()) {
119       return;
120     }
121 
122     const std::string kernel_name = common::AnfAlgo::GetCNodeName(node);
123     if (kernel_name != kGatherOpName && kernel_name != kSparseGatherV2OpName && kernel_name != kMapTensorGetOpName) {
124       return;
125     }
126 
127     const PrimitivePtr &prim = common::AnfAlgo::GetCNodePrimitive(node);
128     MS_EXCEPTION_IF_NULL(prim);
129     if (!(prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole))) {
130       return;
131     }
132 
133     int64_t rank_id_attr = GetValue<int64_t>(prim->GetAttr(distributed::kOpLabelRankId));
134     std::string node_role_attr = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
135     if (rank_id_attr == rank_id_ && node_role_attr == node_role_) {
136       auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
137       shapes_to_nodes_[shape] = node;
138       MS_LOG(INFO) << "The shape: " << shape << " for node: " << node->fullname_with_scope();
139     }
140   });
141 }
142 
GetCacheEnableParameters()143 void PsEmbeddingCacheInserter::GetCacheEnableParameters() {
144   MS_EXCEPTION_IF_NULL(root_graph_);
145   const std::vector<AnfNodePtr> &parameters = root_graph_->parameters();
146   auto params_size = parameters.size();
147   for (size_t i = 0; i < params_size; ++i) {
148     MS_EXCEPTION_IF_NULL(parameters[i]);
149     if (!parameters[i]->isa<Parameter>()) {
150       MS_LOG(EXCEPTION) << "The node with name: " << parameters[i]->fullname_with_scope() << "is not a Parameter.";
151     }
152 
153     ParameterPtr param = parameters[i]->cast<ParameterPtr>();
154     MS_EXCEPTION_IF_NULL(param);
155     auto param_info = param->param_info();
156     if (param_info && param_info->key() != -1 && param_info->cache_enable()) {
157       keys_to_params_[param_info->key()] = param;
158       MS_LOG(INFO) << "Parameter[" << param->fullname_with_scope() << "], key[" << param_info->key() << "]";
159     }
160   }
161 }
162 
SetNodeAttr(const CNodePtr & node,const std::string & node_role) const163 void PsEmbeddingCacheInserter::SetNodeAttr(const CNodePtr &node, const std::string &node_role) const {
164   MS_EXCEPTION_IF_NULL(node);
165 
166   // Set attr for call node, call node hasn't primitive to save attrs, so save attrs into CNode.
167   if (common::AnfAlgo::IsCallNode(node)) {
168     node->AddAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice));
169     node->AddAttr(distributed::kOpLabelRankId, MakeValue(rank_id_));
170     node->AddAttr(distributed::kOpLabelRole, MakeValue(node_role));
171   } else {
172     common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), node);
173     common::AnfAlgo::SetNodeAttr(distributed::kOpLabelRankId, MakeValue(rank_id_), node);
174     common::AnfAlgo::SetNodeAttr(distributed::kOpLabelRole, MakeValue(node_role), node);
175   }
176 }
177 
SetAttrForAllNodes() const178 void PsEmbeddingCacheInserter::SetAttrForAllNodes() const {
179   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root_graph_->get_return());
180   (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
181     MS_EXCEPTION_IF_NULL(node);
182     if (!node->isa<CNode>()) {
183       return;
184     }
185     CNodePtr cnode = node->cast<CNodePtr>();
186     MS_EXCEPTION_IF_NULL(cnode);
187     SetNodeAttr(cnode);
188   });
189 }
190 
SetSendNodeAttr(const CNodePtr & send_node,int32_t param_key,const std::string & embedding_cache_op,const std::string & dst_role) const191 void PsEmbeddingCacheInserter::SetSendNodeAttr(const CNodePtr &send_node, int32_t param_key,
192                                                const std::string &embedding_cache_op,
193                                                const std::string &dst_role) const {
194   MS_EXCEPTION_IF_NULL(send_node);
195 
196   std::vector<uint32_t> dst_ranks;
197   std::vector<std::string> dst_roles = {dst_role};
198   std::vector<std::string> inter_process_edges;
199 
200   // Set inter process edges, send dst ranks, send dst roles.
201   for (uint32_t i = 0; i < worker_num_; i++) {
202     dst_ranks.push_back(i);
203     dst_roles.push_back(dst_role);
204     // Unique edge name: src role + src rank id -> dst role + dst rank id +embedding cache operation + parameter key.
205     inter_process_edges.push_back(distributed::kEnvRoleOfPServer + std::to_string(rank_id_) + "->" + dst_role +
206                                   std::to_string(i) + "_" + embedding_cache_op + "_" + distributed::kParameterKey +
207                                   std::to_string(param_key));
208   }
209 
210   common::AnfAlgo::SetNodeAttr(kAttrSendDstRanks, MakeValue(dst_ranks), send_node);
211   common::AnfAlgo::SetNodeAttr(kAttrSendDstRoles, MakeValue(dst_roles), send_node);
212   common::AnfAlgo::SetNodeAttr(kAttrSendSrcNodeName, MakeValue(std::string(kEmbeddingRemoteCacheNode)), send_node);
213   common::AnfAlgo::SetNodeAttr(kAttrSendDstNodeName, MakeValue(std::string(kEmbeddingLocalCacheNode)), send_node);
214 
215   common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), send_node);
216   common::AnfAlgo::SetNodeAttr(kAttrIsMuxRpcKernel, MakeValue(true), send_node);
217 }
218 
SetRecvNodeAttr(const CNodePtr & recv_node,const std::string & src_role) const219 void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role) const {
220   MS_EXCEPTION_IF_NULL(recv_node);
221 
222   std::vector<uint32_t> src_ranks;
223   std::vector<std::string> src_roles;
224   std::vector<std::string> inter_process_edges;
225 
226   // Set inter process edges, recv src ranks, recv src roles.
227   // Each server has only one Recv node, which needs to receive all requests from each worker. For example, different
228   // parameters on each worker have two operations: look up embedding and update embedding. Each operation will be
229   // performed by an independent Send node, so the Recv node on the server side will have multiple edges.
230   for (uint32_t i = 0; i < worker_num_; i++) {
231     for (const auto &item : keys_to_params_) {
232       int32_t param_key = item.first;
233       for (uint32_t k = 0; k < distributed::kEmbeddingCacheOps.size(); k++) {
234         src_ranks.push_back(i);
235         src_roles.push_back(src_role);
236         // Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter
237         // key.
238         inter_process_edges.push_back(src_role + std::to_string(i) + "->" + distributed::kEnvRoleOfPServer +
239                                       std::to_string(rank_id_) + "_" + distributed::kEmbeddingCacheOps[k] + "_" +
240                                       distributed::kParameterKey + std::to_string(param_key));
241       }
242     }
243   }
244 
245   common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRanks, MakeValue(src_ranks), recv_node);
246   common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRoles, MakeValue(src_roles), recv_node);
247   common::AnfAlgo::SetNodeAttr(kAttrRecvSrcNodeName, MakeValue(std::string(kEmbeddingLocalCacheNode)), recv_node);
248   common::AnfAlgo::SetNodeAttr(kAttrRecvDstNodeName, MakeValue(std::string(kEmbeddingRemoteCacheNode)), recv_node);
249 
250   common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), recv_node);
251   common::AnfAlgo::SetNodeAttr(kAttrIsMuxRpcKernel, MakeValue(true), recv_node);
252 }
253 
CreateReturnNode(const FuncGraphPtr graph,const AnfNodePtr & output_node) const254 CNodePtr PsEmbeddingCacheInserter::CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const {
255   MS_EXCEPTION_IF_NULL(graph);
256   MS_EXCEPTION_IF_NULL(output_node);
257 
258   // Create fake output value node to make sure the output abstract is the same for each subgraph.
259   auto fake_output_tensor = std::make_shared<tensor::Tensor>(1.0);
260   auto fake_output_value = NewValueNode(fake_output_tensor);
261   MS_EXCEPTION_IF_NULL(fake_output_value);
262   fake_output_value->set_abstract(fake_output_tensor->ToAbstract());
263 
264   // Create depend node.
265   auto depend_node = graph->NewCNode({NewValueNode(prim::kPrimDepend), fake_output_value, output_node});
266   MS_EXCEPTION_IF_NULL(depend_node);
267 
268   // Create return node.
269   std::vector<AnfNodePtr> return_inputs;
270   return_inputs.push_back(NewValueNode(prim::kPrimReturn));
271   return_inputs.push_back(depend_node);
272   auto return_node = graph->NewCNode(return_inputs);
273   MS_EXCEPTION_IF_NULL(return_node);
274 
275   return return_node;
276 }
277 
// Build the per-parameter embedding lookup service subgraph:
//   (embedding table, indices) -> lookup kernel -> RpcSend -> Return.
// `node` is the origin lookup node recorded by GetEmbeddingLookupNodes (its attrs,
// e.g. offset, seed the new kernel); `param` is the cache-enabled embedding table;
// `param_key` identifies the parameter for embedding-storage lookups.
// Returns the new subgraph after registering it with the root graph's manager.
FuncGraphPtr PsEmbeddingCacheInserter::ConstructEmbeddingLookupSubGraph(const AnfNodePtr &node,
                                                                        const ParameterPtr &param,
                                                                        int32_t param_key) const {
  MS_EXCEPTION_IF_NULL(param);
  MS_EXCEPTION_IF_NULL(node);

  // 1. Create subgraph and parameters.
  auto graph = std::make_shared<FuncGraph>();
  // First input: the embedding table, with the same abstract as the origin parameter.
  ParameterPtr input_param = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(input_param);
  MS_EXCEPTION_IF_NULL(param->abstract());
  input_param->set_abstract(param->abstract()->Clone());
  // Second input: the lookup indices, a dynamic-length 1-D int32 tensor.
  ParameterPtr input_indices = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(input_indices);
  input_indices->set_abstract(
    std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape)));

  // 2. Create embedding lookup node.
  auto embedding_cache_lookup_node = CreateEmbeddingLookupKernel(graph, input_param, input_indices, node);
  MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node);

  // The number of indices varies per request, so the kernel runs in dynamic-shape mode.
  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);
  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);

  // If persistent embedding storage exists for this key, tell the kernel to use it.
  if (embedding_storage_manager.Exists(param_key)) {
    common::AnfAlgo::SetNodeAttr(kAttrEnableEmbeddingStorage, MakeValue(true), embedding_cache_lookup_node);
    common::AnfAlgo::SetNodeAttr(kAttrParameterKey, MakeValue(param_key), embedding_cache_lookup_node);
  }

  // 3. Create RpcSend node.
  // The lookup result is sent back to the requesting worker over RPC.
  std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
  send_inputs.push_back(embedding_cache_lookup_node);
  CNodePtr send_node = graph->NewCNode(send_inputs);
  MS_EXCEPTION_IF_NULL(send_node);
  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), send_node);
  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), send_node);
  SetSendNodeAttr(send_node, param_key, distributed::kLookupEmbeddingCache);

  // 4. Create return node.
  CNodePtr return_node = CreateReturnNode(graph, send_node);
  MS_EXCEPTION_IF_NULL(return_node);
  graph->set_return(return_node);

  // Register the subgraph so the manager tracks its nodes and edges.
  MS_EXCEPTION_IF_NULL(root_graph_);
  auto manager = root_graph_->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(graph);
  return graph;
}
327 
// Build the per-parameter embedding update service subgraph:
//   (embedding table, indices, update values) -> update kernel -> Return.
// `param` is the cache-enabled embedding table; `param_key` identifies it for
// embedding-storage lookups. `node` is only null-checked here; the update kernel is
// built from the table's own shape information.
// Returns the new subgraph after registering it with the root graph's manager.
FuncGraphPtr PsEmbeddingCacheInserter::ConstructUpdateEmbeddingSubGraph(const ParameterPtr &param,
                                                                        const AnfNodePtr &node,
                                                                        int32_t param_key) const {
  MS_EXCEPTION_IF_NULL(param);
  MS_EXCEPTION_IF_NULL(node);

  // 1. Create subgraph and parameters.
  auto graph = std::make_shared<FuncGraph>();
  // First input: the embedding table, with the same abstract as the origin parameter.
  ParameterPtr input_param = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(input_param);
  MS_EXCEPTION_IF_NULL(param->abstract());
  input_param->set_abstract(param->abstract()->Clone());

  // Second input: the row indices to update, a dynamic-length 1-D int32 tensor.
  ParameterPtr input_indices = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(input_indices);
  input_indices->set_abstract(
    std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape)));

  // Third input: the new embedding values. Its shape copies the table's shape with the
  // leading (row count) dimension made dynamic, since the number of updated rows varies.
  ParameterPtr update_values = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(update_values);
  auto emb_shape = common::AnfAlgo::GetOutputInferShape(param, 0);
  if (emb_shape.empty()) {
    MS_LOG(EXCEPTION) << "Embedding table shape is empty.";
  }
  ShapeVector update_values_shape = emb_shape;
  const int64_t dynamic_dim = -1;
  update_values_shape[0] = dynamic_dim;
  update_values->set_abstract(
    std::make_shared<abstract::AbstractTensor>(kFloat32, std::make_shared<abstract::Shape>(update_values_shape)));

  // 2. Create embedding update node.
  auto embedding_cache_update_node = CreateEmbeddingUpdateKernel(graph, input_param, input_indices, update_values);
  MS_EXCEPTION_IF_NULL(embedding_cache_update_node);
  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_update_node);

  // If persistent embedding storage exists for this key, tell the kernel to use it.
  if (embedding_storage_manager.Exists(param_key)) {
    common::AnfAlgo::SetNodeAttr(kAttrEnableEmbeddingStorage, MakeValue(true), embedding_cache_update_node);
    common::AnfAlgo::SetNodeAttr(kAttrParameterKey, MakeValue(param_key), embedding_cache_update_node);
  }

  // 3. Create return node.
  // Updates have no result to send back; the kernel output is kept alive via Depend.
  CNodePtr return_node = CreateReturnNode(graph, embedding_cache_update_node);
  MS_EXCEPTION_IF_NULL(return_node);
  graph->set_return(return_node);

  // Register the subgraph so the manager tracks its nodes and edges.
  MS_EXCEPTION_IF_NULL(root_graph_);
  auto manager = root_graph_->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(graph);
  return graph;
}
379 
CreateEmbeddingLookupKernel(const FuncGraphPtr & graph,const ParameterPtr & input_param,const ParameterPtr & input_indices,const AnfNodePtr & origin_embedding_lookup_node) const380 CNodePtr PsEmbeddingCacheInserter::CreateEmbeddingLookupKernel(const FuncGraphPtr &graph,
381                                                                const ParameterPtr &input_param,
382                                                                const ParameterPtr &input_indices,
383                                                                const AnfNodePtr &origin_embedding_lookup_node) const {
384   MS_EXCEPTION_IF_NULL(graph);
385   MS_EXCEPTION_IF_NULL(input_param);
386   MS_EXCEPTION_IF_NULL(input_indices);
387   MS_EXCEPTION_IF_NULL(origin_embedding_lookup_node);
388 
389   std::vector<AnfNodePtr> embedding_lookup_inputs;
390   // Sparse format is true meaning embedding table implements in the form of hash, false means the form of tensor.
391   if (!distributed::EmbeddingCacheTableManager::GetInstance().is_sparse_format()) {
392     if (!common::AnfAlgo::HasNodeAttr(kAttrOffset, dyn_cast<CNode>(origin_embedding_lookup_node))) {
393       MS_LOG(EXCEPTION) << "Can not find offset attr of kernel: "
394                         << origin_embedding_lookup_node->fullname_with_scope();
395     }
396     int64_t offset = common::AnfAlgo::GetNodeAttr<int64_t>(origin_embedding_lookup_node, kAttrOffset);
397     ValueNodePtr offset_value_node = NewValueNode(offset);
398     MS_EXCEPTION_IF_NULL(offset_value_node);
399 
400     PrimitivePtr embedding_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
401     embedding_lookup_inputs = {NewValueNode(embedding_lookup_primitive), input_param, input_indices, offset_value_node};
402   } else {
403     PrimitivePtr embedding_lookup_primitive = std::make_shared<Primitive>(kMapTensorGetOpName);
404     embedding_lookup_primitive->set_attr(kAttrInsertDefaultValue, MakeValue(false));
405     embedding_lookup_inputs = {NewValueNode(embedding_lookup_primitive), input_param, input_indices};
406   }
407 
408   return graph->NewCNode(embedding_lookup_inputs);
409 }
410 
CreateEmbeddingUpdateKernel(const FuncGraphPtr & graph,const ParameterPtr & input_param,const ParameterPtr & input_indices,const ParameterPtr & update_values) const411 CNodePtr PsEmbeddingCacheInserter::CreateEmbeddingUpdateKernel(const FuncGraphPtr &graph,
412                                                                const ParameterPtr &input_param,
413                                                                const ParameterPtr &input_indices,
414                                                                const ParameterPtr &update_values) const {
415   MS_EXCEPTION_IF_NULL(graph);
416   MS_EXCEPTION_IF_NULL(input_param);
417   MS_EXCEPTION_IF_NULL(input_indices);
418   MS_EXCEPTION_IF_NULL(update_values);
419 
420   // Sparse format is true meaning embedding table implements in the form of hash, false means the form of tensor.
421   bool is_sparse_format = distributed::EmbeddingCacheTableManager::GetInstance().is_sparse_format();
422   PrimitivePtr embedding_update_primitive = is_sparse_format ? std::make_shared<Primitive>(kMapTensorPutOpName)
423                                                              : std::make_shared<Primitive>(kScatterUpdateOpName);
424   std::vector<AnfNodePtr> embedding_update_inputs{NewValueNode(embedding_update_primitive), input_param, input_indices,
425                                                   update_values};
426   return graph->NewCNode(embedding_update_inputs);
427 }
428 
CreateRecvNode() const429 CNodePtr PsEmbeddingCacheInserter::CreateRecvNode() const {
430   // 1. Create input parameter for RpcRecv node.
431   // The indices input.
432   MS_EXCEPTION_IF_NULL(root_graph_);
433   ParameterPtr input_indices = root_graph_->add_parameter();
434   MS_EXCEPTION_IF_NULL(input_indices);
435   input_indices->set_abstract(
436     std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape)));
437   auto fake_input_indices_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, kOneDimShape);
438   input_indices->set_default_param(fake_input_indices_tensor);
439 
440   // The update values input.
441   ParameterPtr update_values = root_graph_->add_parameter();
442   MS_EXCEPTION_IF_NULL(update_values);
443   update_values->set_abstract(
444     std::make_shared<abstract::AbstractTensor>(kFloat32, std::make_shared<abstract::Shape>(kTwoDimsDynamicShape)));
445   auto fake_update_values_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, kTwoDimsShape);
446   update_values->set_default_param(fake_update_values_tensor);
447 
448   // The service id input, used to choose service to execute.
449   ParameterPtr service_id = root_graph_->add_parameter();
450   MS_EXCEPTION_IF_NULL(service_id);
451   service_id->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
452   auto fake_id_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, kOneDimDynamicShape);
453   service_id->set_default_param(fake_id_tensor);
454 
455   // 2. Create a RpcRecv node.
456   std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
457   recv_inputs.push_back(input_indices);
458   recv_inputs.push_back(update_values);
459   recv_inputs.push_back(service_id);
460   MS_EXCEPTION_IF_NULL(root_graph_);
461   CNodePtr recv_node = root_graph_->NewCNode(recv_inputs);
462   MS_EXCEPTION_IF_NULL(recv_node);
463 
464   SetRecvNodeAttr(recv_node);
465   common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), recv_node);
466   common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), recv_node);
467 
468   return recv_node;
469 }
470 
// For every cache-enabled parameter, build its lookup and update service subgraphs
// and append one Partial node per subgraph to *make_tuple_inputs (the candidate list
// for the SwitchLayer built by ConstructEmbeddingCacheGraph).
// recv_outputs: the three tuple_getitem outputs of the Recv node — [0] indices,
// [1] update values ([2] is the service id, not consumed here).
// Returns false (with an error log) when recv_outputs has the wrong arity or a
// parameter's lookup node cannot be found by shape.
bool PsEmbeddingCacheInserter::ConstructEmbeddingCacheServicesSubGraphs(
  const std::vector<CNodePtr> &recv_outputs, std::vector<AnfNodePtr> *make_tuple_inputs) const {
  MS_EXCEPTION_IF_NULL(root_graph_);
  MS_EXCEPTION_IF_NULL(make_tuple_inputs);
  if (recv_outputs.size() != kRecvNodeOutputNum) {
    MS_LOG(ERROR) << "The output tensor number of recv node is not equal to " << kRecvNodeOutputNum;
    return false;
  }

  for (const auto &item : keys_to_params_) {
    int32_t key = item.first;
    ParameterPtr param = item.second;
    MS_EXCEPTION_IF_NULL(param);
    // The origin lookup node was indexed by the table's inferred shape in
    // GetEmbeddingLookupNodes; recover it the same way.
    auto shape = common::AnfAlgo::GetOutputInferShape(param, 0);
    auto iter = shapes_to_nodes_.find(shape);
    if (iter == shapes_to_nodes_.end()) {
      MS_LOG(ERROR) << "Can not find cnode for parameter(key[" << key << "]) with shape: " << shape;
      return false;
    }
    AnfNodePtr node = iter->second;

    // 1. Construct embedding lookup service sub graph.
    auto emb_lookup_sub_graph = ConstructEmbeddingLookupSubGraph(node, param, key);
    MS_EXCEPTION_IF_NULL(emb_lookup_sub_graph);
    auto emb_lookup_graph_value = NewValueNode(emb_lookup_sub_graph);
    MS_EXCEPTION_IF_NULL(emb_lookup_graph_value);
    auto emb_lookup_graph_value_abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(
      emb_lookup_sub_graph, abstract::AnalysisContext::DummyContext());
    emb_lookup_graph_value->set_abstract(emb_lookup_graph_value_abstract);

    // Partial binds the subgraph's inputs: (table param, indices from recv output 0).
    CNodePtr emb_lookup_partial_node =
      root_graph_->NewCNode({NewValueNode(prim::kPrimPartial), emb_lookup_graph_value, param, recv_outputs[0]});
    MS_EXCEPTION_IF_NULL(emb_lookup_partial_node);
    AbstractBasePtrList lookup_partial_args_spec_list = {param->abstract(), recv_outputs[0]->abstract()};
    emb_lookup_partial_node->set_abstract(std::make_shared<abstract::PartialAbstractClosure>(
      emb_lookup_graph_value_abstract, lookup_partial_args_spec_list, emb_lookup_partial_node));

    make_tuple_inputs->push_back(emb_lookup_partial_node);

    // 2. Construct updating embedding service sub graph.
    auto update_emb_sub_graph = ConstructUpdateEmbeddingSubGraph(param, node, key);
    MS_EXCEPTION_IF_NULL(update_emb_sub_graph);
    auto update_emb_graph_value = NewValueNode(update_emb_sub_graph);
    MS_EXCEPTION_IF_NULL(update_emb_graph_value);
    auto update_emb_graph_value_abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(
      update_emb_sub_graph, abstract::AnalysisContext::DummyContext());
    update_emb_graph_value->set_abstract(update_emb_graph_value_abstract);

    // Partial binds (table param, indices from recv output 0, values from recv output 1).
    CNodePtr update_emb_partial_node = root_graph_->NewCNode(
      {NewValueNode(prim::kPrimPartial), update_emb_graph_value, param, recv_outputs[0], recv_outputs[1]});
    MS_EXCEPTION_IF_NULL(update_emb_partial_node);
    AbstractBasePtrList update_partial_args_spec_list = {param->abstract(), recv_outputs[0]->abstract(),
                                                         recv_outputs[1]->abstract()};
    update_emb_partial_node->set_abstract(std::make_shared<abstract::PartialAbstractClosure>(
      update_emb_graph_value_abstract, update_partial_args_spec_list, update_emb_partial_node));

    make_tuple_inputs->push_back(update_emb_partial_node);
  }

  return true;
}
532 
// Rewire the root graph into the server-side embedding cache service graph:
//   Recv -> tuple_getitems -> MakeTuple(partials) -> SwitchLayer(service id) -> call,
// then splice the call in front of the origin output via Depend so the origin
// output's abstract is preserved while the service pipeline executes.
bool PsEmbeddingCacheInserter::ConstructEmbeddingCacheGraph() const {
  // 1. Create recv node for server.
  CNodePtr recv_node = CreateRecvNode();
  MS_EXCEPTION_IF_NULL(recv_node);
  // Indices 0/1/2 select the recv outputs: indices, update values, service id.
  auto value_node_0 = NewValueNode(static_cast<int64_t>(0));
  auto value_node_1 = NewValueNode(static_cast<int64_t>(1));
  auto value_node_2 = NewValueNode(static_cast<int64_t>(2));
  std::vector<AnfNodePtr> getitem_input0{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_0};
  std::vector<AnfNodePtr> getitem_input1{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_1};
  std::vector<AnfNodePtr> getitem_input2{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_2};

  MS_EXCEPTION_IF_NULL(root_graph_);
  auto getitem_0 = root_graph_->NewCNode(getitem_input0);
  auto getitem_1 = root_graph_->NewCNode(getitem_input1);
  auto getitem_2 = root_graph_->NewCNode(getitem_input2);
  // The tuple_getitem nodes used to get the outputs of recv node.
  std::vector<CNodePtr> getitems = {getitem_0, getitem_1, getitem_2};

  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};

  // 2. Construct the embedding cache services subgraphs, including embedding lookup and update operations, and
  // package the subgraphs corresponding to the related operations into the partial.
  RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheServicesSubGraphs(getitems, &make_tuple_inputs),
                           "Construct embedding cache services sub graphs failed.");

  auto make_tuple_node = root_graph_->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple_node);

  // 3. Create switch layer and call node, used to select and execute the subgraph corresponding to the service
  // requested. getitem_2 (the service id from the recv node) indexes into the partials tuple.
  std::vector<AnfNodePtr> switch_layer_inputs = {NewValueNode(prim::kPrimSwitchLayer), getitem_2, make_tuple_node};
  auto switch_layer_node = root_graph_->NewCNode(switch_layer_inputs);
  MS_EXCEPTION_IF_NULL(switch_layer_node);

  // Calling the selected partial with no extra args executes the chosen service subgraph.
  CNodePtr call_node = root_graph_->NewCNode({switch_layer_node});
  MS_EXCEPTION_IF_NULL(call_node);

  // 4. Replace origin output and useless nodes of origin function graph.
  // The new output mimics the origin output's abstract (fake values) while Depend
  // keeps the service call in the execution graph.
  AnfNodePtr old_output = root_graph_->output();
  AnfNodePtr new_output = CreateOutputNode(root_graph_, old_output);
  auto final_output_node = root_graph_->NewCNode({NewValueNode(prim::kPrimDepend), new_output, call_node});
  MS_EXCEPTION_IF_NULL(final_output_node);

  auto graph_manager = root_graph_->manager();
  MS_EXCEPTION_IF_NULL(graph_manager);
  return graph_manager->Replace(root_graph_->output(), final_output_node);
}
580 
BuildDenseEmbeddingStorages()581 void PsEmbeddingCacheInserter::BuildDenseEmbeddingStorages() {
582   for (const auto &item : keys_to_params_) {
583     int32_t key = item.first;
584     ParameterPtr param = item.second;
585     MS_EXCEPTION_IF_NULL(param);
586 
587     auto param_info = param->param_info();
588     MS_EXCEPTION_IF_NULL(param_info);
589     if (!param_info->use_persistent_storage()) {
590       MS_LOG(INFO) << "No need to use embedding storage for this parameter(key): " << key;
591       continue;
592     }
593 
594     const std::vector<int64_t> &origin_shape = param_info->origin_shape();
595     size_t origin_capacity = LongToSize(origin_shape.front());
596     size_t origin_emb_dim = LongToSize(origin_shape.back());
597     MS_LOG(INFO) << "Get a parameter for embedding storage: " << param->name() << ", origin emb_dim: " << origin_emb_dim
598                  << ", origin capacity: " << origin_capacity;
599 
600     const std::vector<int64_t> &slice_shape = param_info->parameter_shape();
601     if (slice_shape.size() != kEmbeddingTableDims) {
602       MS_LOG(EXCEPTION)
603         << "When build embedding storage, Embedding table should be 2 dims for embedding cache mode, but got: "
604         << slice_shape.size() << " dims, param name: " << param->name() << ", param key: " << key;
605     }
606     size_t capacity = LongToSize(slice_shape.front());
607     size_t emb_dim = LongToSize(slice_shape.back());
608 
609     auto shape = common::AnfAlgo::GetOutputInferShape(param, 0);
610     auto iter = shapes_to_nodes_.find(shape);
611     if (iter == shapes_to_nodes_.end()) {
612       MS_LOG(EXCEPTION) << "Can not find cnode for parameter(key[" << key << "]) with shape: " << shape;
613     }
614     AnfNodePtr node = iter->second;
615     const size_t output_index0 = 0;
616     const size_t key_index = 1;
617 
618     TypeId key_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, key_index);
619     TypeId value_type = common::AnfAlgo::GetOutputInferDataType(node, output_index0);
620     // Create dense or sparse embedding storage and add into embedding storage manager.
621     distributed::CreateEmbeddingStorage(std::make_pair(key_type, value_type), key, emb_dim, capacity);
622     MS_LOG(INFO) << "Add a new dense embedding storage, key: " << key << ", emb_dim: " << emb_dim
623                  << ", capacity: " << capacity << ", origin emb_dim:" << origin_emb_dim
624                  << ", origin capacity: " << origin_capacity;
625   }
626 }
627 
BuildSparseEmbeddingStorages()628 void PsEmbeddingCacheInserter::BuildSparseEmbeddingStorages() {
629   if (common::GetEnv(kEnvEmbeddingCacheMemSizeInGBytes).empty()) {
630     return;
631   }
632   const size_t cache_size_in_gbytes = std::stoul(common::GetEnv(kEnvEmbeddingCacheMemSizeInGBytes));
633   const size_t cache_size_in_bytes = cache_size_in_gbytes << 30;
634 
635   for (const auto &item : keys_to_params_) {
636     int32_t key = item.first;
637     ParameterPtr param = item.second;
638     MS_EXCEPTION_IF_NULL(param);
639 
640     auto param_info = param->param_info();
641     MS_EXCEPTION_IF_NULL(param_info);
642     param_info->set_use_persistent_storage(true);
643 
644     const auto &abstract_base = common::AnfAlgo::FrontendGetNodeAbstractByIndex(param, 0);
645     MS_EXCEPTION_IF_NULL(abstract_base);
646     if (!abstract_base->isa<abstract::AbstractMapTensor>()) {
647       MS_LOG(EXCEPTION) << "Parameter:" << param->DebugString() << " is not a map tensor type.";
648     }
649     const auto &abstract = abstract_base->cast<abstract::AbstractMapTensorPtr>();
650     MS_EXCEPTION_IF_NULL(abstract);
651 
652     MS_EXCEPTION_IF_NULL(abstract->value_shape());
653     const auto &value_shape = abstract->value_shape()->shape();
654     size_t emb_dim = LongToSize(
655       std::accumulate(value_shape.begin(), value_shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()));
656 
657     const auto &map_tensor_type = abstract->map_tensor_type();
658     MS_EXCEPTION_IF_NULL(map_tensor_type);
659     auto key_type = map_tensor_type->key_dtype();
660     auto value_type = map_tensor_type->value_dtype();
661     MS_EXCEPTION_IF_NULL(key_type);
662     MS_EXCEPTION_IF_NULL(value_type);
663     size_t map_element_size = GetTypeByte(key_type) + (GetTypeByte(value_type) * emb_dim);
664     MS_EXCEPTION_IF_ZERO("map_element_size", map_element_size);
665 
666     size_t capacity = cache_size_in_bytes / map_element_size;
667     TypeId key_type_id = key_type->type_id();
668     TypeId value_type_id = value_type->type_id();
669 
670     // Create dense or sparse embedding storage and add into embedding storage manager.
671     distributed::CreateEmbeddingStorage(std::make_pair(key_type_id, value_type_id), key, emb_dim, capacity);
672     MS_LOG(INFO) << "Add a new sparse embedding storage, key: " << key << ", emb_dim: " << emb_dim
673                  << ", capacity: " << capacity;
674   }
675 }
676 
BuildEmbeddingStorages()677 void PsEmbeddingCacheInserter::BuildEmbeddingStorages() {
678   if (embedding_cache_table_manager.is_sparse_format()) {
679     BuildSparseEmbeddingStorages();
680   } else {
681     BuildDenseEmbeddingStorages();
682   }
683 }
684 
bool PsEmbeddingCacheInserter::Run() {
  // Top-level driver: collects cache-related nodes/parameters, builds their backing storages,
  // rewrites the origin func graph into the server-side embedding cache graph, then tags all
  // CNodes. Returns false only if graph construction fails (logged by the macro below).
  // Get EmbeddingLookup nodes which are executed on server from origin function graph.
  GetEmbeddingLookupNodes();

  // Get parameters enabled embedding cache of origin function graph.
  GetCacheEnableParameters();

  // Build embedding storages for parameters enabled embedding cache to read/write embedding table from/to persistent
  // storage.
  BuildEmbeddingStorages();

  // Construct the embedding cache graph of server.
  RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheGraph(), "Construct embedding cache graph failed.");

  // Set attr(device target attr and graph split label) for all CNodes.
  SetAttrForAllNodes();

  MS_EXCEPTION_IF_NULL(root_graph_);
  // Need renormalize to infer shape and set abstract.
  root_graph_->set_flag(kFlagNeedRenormalize, true);
  return true;
}
707 }  // namespace parallel
708 }  // namespace mindspore
709