/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"

#include <memory>
#include <string>
#include <algorithm>
#include <utility>
#include <functional>

#include "ops/sparse_op_name.h"
#include "ops/nn_op_name.h"
#include "ops/array_op_name.h"
#include "ops/sequence_ops.h"
#include "ops/framework_ops.h"
#include "ir/func_graph.h"
#include "abstract/abstract_function.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/utils.h"
#include "utils/ms_context.h"

#include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"

namespace mindspore {
namespace parallel {
// One dimensional static shape placeholder.
const ShapeVector kOneDimShape = {1};
// Two dimensional static shape placeholder.
const ShapeVector kTwoDimsShape = {1, 1};

// One dimensional dynamic shape placeholder.
const ShapeVector kOneDimDynamicShape = {-1};
// Two dimensional dynamic shape placeholder.
const ShapeVector kTwoDimsDynamicShape = {-1, -1};

// The number of output tensors of the recv node.
const size_t kRecvNodeOutputNum = 3;

// The input index of the offset of the EmbeddingLookup kernel.
constexpr size_t kEmbeddingLookupOffsetIdx = 2;

// The number of dims of the embedding table.
constexpr size_t kEmbeddingTableDims = 2;

constexpr char kEmbeddingRemoteCacheNode[] = "EmbeddingRemoteCacheNode";
constexpr char kEmbeddingLocalCacheNode[] = "EmbeddingLocalCacheNode";
constexpr char kEnvEmbeddingCacheMemSizeInGBytes[] = "MS_EMBEDDING_REMOTE_CACHE_MEMORY_SIZE";

namespace {
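// Create a value node holding a fake tensor with the same type and shape as the origin node's output, used as a
// stand-in so the graph keeps a consistent abstract without depending on real data.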
ValueNodePtr CreateFakeValueNode(const AnfNodePtr &origin_node) {
  MS_EXCEPTION_IF_NULL(origin_node);
  abstract::AbstractTensorPtr origin_abstract = origin_node->abstract()->cast<abstract::AbstractTensorPtr>();

  MS_EXCEPTION_IF_NULL(origin_abstract);
  tensor::TensorPtr fake_tensor = std::make_shared<tensor::Tensor>(origin_abstract->element()->BuildType()->type_id(),
                                                                   origin_abstract->shape()->shape());
  MS_EXCEPTION_IF_NULL(fake_tensor);
  fake_tensor->set_base_shape(origin_abstract->shape()->Clone());

  auto fake_value = NewValueNode(fake_tensor);
  MS_EXCEPTION_IF_NULL(fake_value);
  fake_value->set_abstract(fake_tensor->ToAbstract());
  return fake_value;
}

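// Create a replacement for the origin graph output: a tuple of fake tensors if the origin output is a tuple of
// tensors, otherwise a single fake value node.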
AnfNodePtr CreateOutputNode(const FuncGraphPtr &func_graph, const AnfNodePtr &origin_output) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_output);
  MS_EXCEPTION_IF_NULL(origin_output->abstract());
  if (origin_output->abstract()->isa<abstract::AbstractTuple>()) {
    abstract::AbstractBasePtrList new_elements_abs;
    std::vector<ValuePtr> new_elements_values;

    auto tuple_elements = origin_output->abstract()->cast<abstract::AbstractTuplePtr>()->elements();
    for (const auto &element : tuple_elements) {
      MS_EXCEPTION_IF_NULL(element);
      auto tensor_abstract = element->cast<abstract::AbstractTensorPtr>();
      if (!tensor_abstract) {
        MS_LOG(EXCEPTION) << "Only tuples whose elements are all tensors can be replaced.";
      }
      auto fake_tensor = std::make_shared<tensor::Tensor>(tensor_abstract->element()->BuildType()->type_id(),
                                                          tensor_abstract->shape()->shape());
      MS_EXCEPTION_IF_NULL(fake_tensor);
      new_elements_abs.push_back(fake_tensor->ToAbstract());
      new_elements_values.push_back(fake_tensor);
    }
    ValueTuplePtr value_tuple = std::make_shared<ValueTuple>(new_elements_values);
    auto value_tuple_abs = std::make_shared<abstract::AbstractTuple>(new_elements_abs);
    auto value_tuple_node = NewValueNode(value_tuple);
    MS_EXCEPTION_IF_NULL(value_tuple_node);
    value_tuple_node->set_abstract(value_tuple_abs);
    return value_tuple_node;
  } else {
    return CreateFakeValueNode(origin_output);
  }
}
}  // namespace

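// Collect the embedding lookup nodes (Gather, SparseGatherV2 or MapTensorGet) labeled to run on this server
// (matching rank id and role), indexed by the inferred shape of their embedding table input.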
void PsEmbeddingCacheInserter::GetEmbeddingLookupNodes() {
  MS_EXCEPTION_IF_NULL(root_graph_);
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root_graph_->get_return());
  (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      return;
    }

    const std::string kernel_name = common::AnfAlgo::GetCNodeName(node);
    if (kernel_name != kGatherOpName && kernel_name != kSparseGatherV2OpName && kernel_name != kMapTensorGetOpName) {
      return;
    }

    const PrimitivePtr &prim = common::AnfAlgo::GetCNodePrimitive(node);
    MS_EXCEPTION_IF_NULL(prim);
    if (!(prim->HasAttr(distributed::kOpLabelRankId) && prim->HasAttr(distributed::kOpLabelRole))) {
      return;
    }

    int64_t rank_id_attr = GetValue<int64_t>(prim->GetAttr(distributed::kOpLabelRankId));
    std::string node_role_attr = GetValue<std::string>(prim->GetAttr(distributed::kOpLabelRole));
    if (rank_id_attr == rank_id_ && node_role_attr == node_role_) {
      auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
      shapes_to_nodes_[shape] = node;
      MS_LOG(INFO) << "The shape: " << shape << " for node: " << node->fullname_with_scope();
    }
  });
}

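// Collect the root graph parameters that enable embedding cache (cache_enable is set and a valid parameter key is
// assigned), indexed by parameter key.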
void PsEmbeddingCacheInserter::GetCacheEnableParameters() {
  MS_EXCEPTION_IF_NULL(root_graph_);
  const std::vector<AnfNodePtr> &parameters = root_graph_->parameters();
  auto params_size = parameters.size();
  for (size_t i = 0; i < params_size; ++i) {
    MS_EXCEPTION_IF_NULL(parameters[i]);
    if (!parameters[i]->isa<Parameter>()) {
      MS_LOG(EXCEPTION) << "The node with name: " << parameters[i]->fullname_with_scope() << " is not a Parameter.";
    }

    ParameterPtr param = parameters[i]->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param);
    auto param_info = param->param_info();
    if (param_info && param_info->key() != -1 && param_info->cache_enable()) {
      keys_to_params_[param_info->key()] = param;
      MS_LOG(INFO) << "Parameter[" << param->fullname_with_scope() << "], key[" << param_info->key() << "]";
    }
  }
}

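// Label a single CNode with the CPU device target and the graph split label (rank id and node role), so the graph
// splitter can assign it to this process.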
void PsEmbeddingCacheInserter::SetNodeAttr(const CNodePtr &node, const std::string &node_role) const {
  MS_EXCEPTION_IF_NULL(node);

  // Set attrs for the call node. A call node has no primitive to hold attrs, so save them on the CNode itself.
  if (common::AnfAlgo::IsCallNode(node)) {
    node->AddAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice));
    node->AddAttr(distributed::kOpLabelRankId, MakeValue(rank_id_));
    node->AddAttr(distributed::kOpLabelRole, MakeValue(node_role));
  } else {
    common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), node);
    common::AnfAlgo::SetNodeAttr(distributed::kOpLabelRankId, MakeValue(rank_id_), node);
    common::AnfAlgo::SetNodeAttr(distributed::kOpLabelRole, MakeValue(node_role), node);
  }
}

void PsEmbeddingCacheInserter::SetAttrForAllNodes() const {
  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(root_graph_->get_return());
  (void)std::for_each(all_nodes.begin(), all_nodes.end(), [this](const AnfNodePtr &node) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      return;
    }
    CNodePtr cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    SetNodeAttr(cnode);
  });
}

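// Set the routing attrs of an RpcSend node: one edge per worker, each identified by a unique edge name built from
// the source/destination roles and ranks, the embedding cache operation and the parameter key.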
void PsEmbeddingCacheInserter::SetSendNodeAttr(const CNodePtr &send_node, int32_t param_key,
                                               const std::string &embedding_cache_op,
                                               const std::string &dst_role) const {
  MS_EXCEPTION_IF_NULL(send_node);

  std::vector<uint32_t> dst_ranks;
  std::vector<std::string> dst_roles;
  std::vector<std::string> inter_process_edges;

  // Set inter process edges, send dst ranks, send dst roles.
  for (uint32_t i = 0; i < worker_num_; i++) {
    dst_ranks.push_back(i);
    dst_roles.push_back(dst_role);
    // Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter key.
    inter_process_edges.push_back(distributed::kEnvRoleOfPServer + std::to_string(rank_id_) + "->" + dst_role +
                                  std::to_string(i) + "_" + embedding_cache_op + "_" + distributed::kParameterKey +
                                  std::to_string(param_key));
  }
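  // For illustration only: with worker_num_ == 2, the loop above produces two edge names of the form
  // "<kEnvRoleOfPServer><rank_id_>-><dst_role>0_<embedding_cache_op>_<kParameterKey><param_key>", and the same with
  // destination rank 1.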

  common::AnfAlgo::SetNodeAttr(kAttrSendDstRanks, MakeValue(dst_ranks), send_node);
  common::AnfAlgo::SetNodeAttr(kAttrSendDstRoles, MakeValue(dst_roles), send_node);
  common::AnfAlgo::SetNodeAttr(kAttrSendSrcNodeName, MakeValue(std::string(kEmbeddingRemoteCacheNode)), send_node);
  common::AnfAlgo::SetNodeAttr(kAttrSendDstNodeName, MakeValue(std::string(kEmbeddingLocalCacheNode)), send_node);

  common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), send_node);
  common::AnfAlgo::SetNodeAttr(kAttrIsMuxRpcKernel, MakeValue(true), send_node);
}

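// Set the routing attrs of the server's RpcRecv node, which accepts the requests of every embedding cache operation
// for every cached parameter from every worker.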
void PsEmbeddingCacheInserter::SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role) const {
  MS_EXCEPTION_IF_NULL(recv_node);

  std::vector<uint32_t> src_ranks;
  std::vector<std::string> src_roles;
  std::vector<std::string> inter_process_edges;

  // Set inter process edges, recv src ranks, recv src roles.
  // Each server has only one Recv node, which needs to receive all requests from every worker. For example, each
  // parameter on each worker has two operations: look up embedding and update embedding. Each operation is performed
  // by an independent Send node, so the Recv node on the server side has multiple incoming edges.
  for (uint32_t i = 0; i < worker_num_; i++) {
    for (const auto &item : keys_to_params_) {
      int32_t param_key = item.first;
      for (uint32_t k = 0; k < distributed::kEmbeddingCacheOps.size(); k++) {
        src_ranks.push_back(i);
        src_roles.push_back(src_role);
        // Unique edge name: src role + src rank id -> dst role + dst rank id + embedding cache operation + parameter
        // key.
        inter_process_edges.push_back(src_role + std::to_string(i) + "->" + distributed::kEnvRoleOfPServer +
                                      std::to_string(rank_id_) + "_" + distributed::kEmbeddingCacheOps[k] + "_" +
                                      distributed::kParameterKey + std::to_string(param_key));
      }
    }
  }
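  // In total, the Recv node has worker_num_ * keys_to_params_.size() * kEmbeddingCacheOps.size() incoming edges.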

  common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRanks, MakeValue(src_ranks), recv_node);
  common::AnfAlgo::SetNodeAttr(kAttrRecvSrcRoles, MakeValue(src_roles), recv_node);
  common::AnfAlgo::SetNodeAttr(kAttrRecvSrcNodeName, MakeValue(std::string(kEmbeddingLocalCacheNode)), recv_node);
  common::AnfAlgo::SetNodeAttr(kAttrRecvDstNodeName, MakeValue(std::string(kEmbeddingRemoteCacheNode)), recv_node);

  common::AnfAlgo::SetNodeAttr(kAttrInterProcessEdgeNames, MakeValue(inter_process_edges), recv_node);
  common::AnfAlgo::SetNodeAttr(kAttrIsMuxRpcKernel, MakeValue(true), recv_node);
}

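// Build the return of a service subgraph: a Depend node chains the real output node behind a scalar fake output, so
// every subgraph returns the same abstract while the service node is still executed.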
CNodePtr PsEmbeddingCacheInserter::CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(output_node);

  // Create a fake output value node to make sure the output abstract is the same for each subgraph.
  auto fake_output_tensor = std::make_shared<tensor::Tensor>(1.0);
  auto fake_output_value = NewValueNode(fake_output_tensor);
  MS_EXCEPTION_IF_NULL(fake_output_value);
  fake_output_value->set_abstract(fake_output_tensor->ToAbstract());

  // Create depend node.
  auto depend_node = graph->NewCNode({NewValueNode(prim::kPrimDepend), fake_output_value, output_node});
  MS_EXCEPTION_IF_NULL(depend_node);

  // Create return node.
  std::vector<AnfNodePtr> return_inputs;
  return_inputs.push_back(NewValueNode(prim::kPrimReturn));
  return_inputs.push_back(depend_node);
  auto return_node = graph->NewCNode(return_inputs);
  MS_EXCEPTION_IF_NULL(return_node);

  return return_node;
}

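// Build the lookup service subgraph for one cached parameter. The subgraph takes the embedding table and the
// indices as parameters, performs the lookup on the server and sends the result back to the workers:
//   (table, indices) -> EmbeddingLookup/MapTensorGet -> RpcSend -> Return.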
FuncGraphPtr PsEmbeddingCacheInserter::ConstructEmbeddingLookupSubGraph(const AnfNodePtr &node,
                                                                        const ParameterPtr &param,
                                                                        int32_t param_key) const {
  MS_EXCEPTION_IF_NULL(param);
  MS_EXCEPTION_IF_NULL(node);

  // 1. Create subgraph and parameters.
  auto graph = std::make_shared<FuncGraph>();
  ParameterPtr input_param = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(input_param);
  MS_EXCEPTION_IF_NULL(param->abstract());
  input_param->set_abstract(param->abstract()->Clone());
  ParameterPtr input_indices = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(input_indices);
  input_indices->set_abstract(
    std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape)));

  // 2. Create embedding lookup node.
  auto embedding_cache_lookup_node = CreateEmbeddingLookupKernel(graph, input_param, input_indices, node);
  MS_EXCEPTION_IF_NULL(embedding_cache_lookup_node);

  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);
  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), embedding_cache_lookup_node);

  if (embedding_storage_manager.Exists(param_key)) {
    common::AnfAlgo::SetNodeAttr(kAttrEnableEmbeddingStorage, MakeValue(true), embedding_cache_lookup_node);
    common::AnfAlgo::SetNodeAttr(kAttrParameterKey, MakeValue(param_key), embedding_cache_lookup_node);
  }

  // 3. Create RpcSend node.
  std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
  send_inputs.push_back(embedding_cache_lookup_node);
  CNodePtr send_node = graph->NewCNode(send_inputs);
  MS_EXCEPTION_IF_NULL(send_node);
  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), send_node);
  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), send_node);
  SetSendNodeAttr(send_node, param_key, distributed::kLookupEmbeddingCache);

  // 4. Create return node.
  CNodePtr return_node = CreateReturnNode(graph, send_node);
  MS_EXCEPTION_IF_NULL(return_node);
  graph->set_return(return_node);

  MS_EXCEPTION_IF_NULL(root_graph_);
  auto manager = root_graph_->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(graph);
  return graph;
}

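// Build the update service subgraph for one cached parameter. The subgraph takes the embedding table, the indices
// and the new values as parameters and writes the values into the table on the server:
//   (table, indices, values) -> ScatterUpdate/MapTensorPut -> Return.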
FuncGraphPtr PsEmbeddingCacheInserter::ConstructUpdateEmbeddingSubGraph(const ParameterPtr &param,
                                                                        const AnfNodePtr &node,
                                                                        int32_t param_key) const {
  MS_EXCEPTION_IF_NULL(param);
  MS_EXCEPTION_IF_NULL(node);

  // 1. Create subgraph and parameters.
  auto graph = std::make_shared<FuncGraph>();
  ParameterPtr input_param = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(input_param);
  MS_EXCEPTION_IF_NULL(param->abstract());
  input_param->set_abstract(param->abstract()->Clone());

  ParameterPtr input_indices = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(input_indices);
  input_indices->set_abstract(
    std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape)));

  ParameterPtr update_values = graph->add_parameter();
  MS_EXCEPTION_IF_NULL(update_values);
  auto emb_shape = common::AnfAlgo::GetOutputInferShape(param, 0);
  if (emb_shape.empty()) {
    MS_LOG(EXCEPTION) << "Embedding table shape is empty.";
  }
  ShapeVector update_values_shape = emb_shape;
  const int64_t dynamic_dim = -1;
  update_values_shape[0] = dynamic_dim;
  update_values->set_abstract(
    std::make_shared<abstract::AbstractTensor>(kFloat32, std::make_shared<abstract::Shape>(update_values_shape)));

  // 2. Create embedding update node.
  auto embedding_cache_update_node = CreateEmbeddingUpdateKernel(graph, input_param, input_indices, update_values);
  MS_EXCEPTION_IF_NULL(embedding_cache_update_node);
  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), embedding_cache_update_node);

  if (embedding_storage_manager.Exists(param_key)) {
    common::AnfAlgo::SetNodeAttr(kAttrEnableEmbeddingStorage, MakeValue(true), embedding_cache_update_node);
    common::AnfAlgo::SetNodeAttr(kAttrParameterKey, MakeValue(param_key), embedding_cache_update_node);
  }

  // 3. Create return node.
  CNodePtr return_node = CreateReturnNode(graph, embedding_cache_update_node);
  MS_EXCEPTION_IF_NULL(return_node);
  graph->set_return(return_node);

  MS_EXCEPTION_IF_NULL(root_graph_);
  auto manager = root_graph_->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(graph);
  return graph;
}

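// Create the lookup kernel for the subgraph: EmbeddingLookup (with the offset taken from the origin node) for the
// dense tensor format, or MapTensorGet for the sparse hash format.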
CNodePtr PsEmbeddingCacheInserter::CreateEmbeddingLookupKernel(const FuncGraphPtr &graph,
                                                               const ParameterPtr &input_param,
                                                               const ParameterPtr &input_indices,
                                                               const AnfNodePtr &origin_embedding_lookup_node) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(input_param);
  MS_EXCEPTION_IF_NULL(input_indices);
  MS_EXCEPTION_IF_NULL(origin_embedding_lookup_node);

  std::vector<AnfNodePtr> embedding_lookup_inputs;
  // Sparse format being true means the embedding table is implemented as a hash table; false means a plain tensor.
  if (!distributed::EmbeddingCacheTableManager::GetInstance().is_sparse_format()) {
    if (!common::AnfAlgo::HasNodeAttr(kAttrOffset, dyn_cast<CNode>(origin_embedding_lookup_node))) {
      MS_LOG(EXCEPTION) << "Cannot find the offset attr of kernel: "
                        << origin_embedding_lookup_node->fullname_with_scope();
    }
    int64_t offset = common::AnfAlgo::GetNodeAttr<int64_t>(origin_embedding_lookup_node, kAttrOffset);
    ValueNodePtr offset_value_node = NewValueNode(offset);
    MS_EXCEPTION_IF_NULL(offset_value_node);

    PrimitivePtr embedding_lookup_primitive = std::make_shared<Primitive>(kEmbeddingLookupOpName);
    embedding_lookup_inputs = {NewValueNode(embedding_lookup_primitive), input_param, input_indices, offset_value_node};
  } else {
    PrimitivePtr embedding_lookup_primitive = std::make_shared<Primitive>(kMapTensorGetOpName);
    embedding_lookup_primitive->set_attr(kAttrInsertDefaultValue, MakeValue(false));
    embedding_lookup_inputs = {NewValueNode(embedding_lookup_primitive), input_param, input_indices};
  }

  return graph->NewCNode(embedding_lookup_inputs);
}

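// Create the update kernel for the subgraph: MapTensorPut for the sparse hash format, otherwise ScatterUpdate for
// the dense tensor format.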
CNodePtr PsEmbeddingCacheInserter::CreateEmbeddingUpdateKernel(const FuncGraphPtr &graph,
                                                               const ParameterPtr &input_param,
                                                               const ParameterPtr &input_indices,
                                                               const ParameterPtr &update_values) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(input_param);
  MS_EXCEPTION_IF_NULL(input_indices);
  MS_EXCEPTION_IF_NULL(update_values);

  // Sparse format being true means the embedding table is implemented as a hash table; false means a plain tensor.
  bool is_sparse_format = distributed::EmbeddingCacheTableManager::GetInstance().is_sparse_format();
  PrimitivePtr embedding_update_primitive = is_sparse_format ? std::make_shared<Primitive>(kMapTensorPutOpName)
                                                             : std::make_shared<Primitive>(kScatterUpdateOpName);
  std::vector<AnfNodePtr> embedding_update_inputs{NewValueNode(embedding_update_primitive), input_param, input_indices,
                                                  update_values};
  return graph->NewCNode(embedding_update_inputs);
}

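// Create the server's RpcRecv node together with its three inputs: the indices to look up or update, the update
// values, and the service id that selects which service subgraph to execute.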
CNodePtr PsEmbeddingCacheInserter::CreateRecvNode() const {
  // 1. Create input parameters for the RpcRecv node.
  // The indices input.
  MS_EXCEPTION_IF_NULL(root_graph_);
  ParameterPtr input_indices = root_graph_->add_parameter();
  MS_EXCEPTION_IF_NULL(input_indices);
  input_indices->set_abstract(
    std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(kOneDimDynamicShape)));
  auto fake_input_indices_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, kOneDimShape);
  input_indices->set_default_param(fake_input_indices_tensor);

  // The update values input.
  ParameterPtr update_values = root_graph_->add_parameter();
  MS_EXCEPTION_IF_NULL(update_values);
  update_values->set_abstract(
    std::make_shared<abstract::AbstractTensor>(kFloat32, std::make_shared<abstract::Shape>(kTwoDimsDynamicShape)));
  auto fake_update_values_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, kTwoDimsShape);
  update_values->set_default_param(fake_update_values_tensor);

  // The service id input, used to choose which service to execute.
  ParameterPtr service_id = root_graph_->add_parameter();
  MS_EXCEPTION_IF_NULL(service_id);
  service_id->set_abstract(std::make_shared<abstract::AbstractTensor>(kInt32, kOneDimShape));
  auto fake_id_tensor = std::make_shared<tensor::Tensor>(kNumberTypeInt32, kOneDimDynamicShape);
  service_id->set_default_param(fake_id_tensor);

  // 2. Create an RpcRecv node.
  std::vector<AnfNodePtr> recv_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcRecvOpName))};
  recv_inputs.push_back(input_indices);
  recv_inputs.push_back(update_values);
  recv_inputs.push_back(service_id);
  MS_EXCEPTION_IF_NULL(root_graph_);
  CNodePtr recv_node = root_graph_->NewCNode(recv_inputs);
  MS_EXCEPTION_IF_NULL(recv_node);

  SetRecvNodeAttr(recv_node);
  common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), recv_node);
  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), recv_node);

  return recv_node;
}

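// For every cached parameter, build the lookup and update service subgraphs and wrap each of them into a Partial
// node that binds the parameter and the matching recv outputs; the partials are appended to make_tuple_inputs in
// service id order.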
bool PsEmbeddingCacheInserter::ConstructEmbeddingCacheServicesSubGraphs(
  const std::vector<CNodePtr> &recv_outputs, std::vector<AnfNodePtr> *make_tuple_inputs) const {
  MS_EXCEPTION_IF_NULL(root_graph_);
  MS_EXCEPTION_IF_NULL(make_tuple_inputs);
  if (recv_outputs.size() != kRecvNodeOutputNum) {
    MS_LOG(ERROR) << "The output tensor number of recv node is not equal to " << kRecvNodeOutputNum;
    return false;
  }

  for (const auto &item : keys_to_params_) {
    int32_t key = item.first;
    ParameterPtr param = item.second;
    MS_EXCEPTION_IF_NULL(param);
    auto shape = common::AnfAlgo::GetOutputInferShape(param, 0);
    auto iter = shapes_to_nodes_.find(shape);
    if (iter == shapes_to_nodes_.end()) {
      MS_LOG(ERROR) << "Can not find cnode for parameter(key[" << key << "]) with shape: " << shape;
      return false;
    }
    AnfNodePtr node = iter->second;

    // 1. Construct the embedding lookup service subgraph.
    auto emb_lookup_sub_graph = ConstructEmbeddingLookupSubGraph(node, param, key);
    MS_EXCEPTION_IF_NULL(emb_lookup_sub_graph);
    auto emb_lookup_graph_value = NewValueNode(emb_lookup_sub_graph);
    MS_EXCEPTION_IF_NULL(emb_lookup_graph_value);
    auto emb_lookup_graph_value_abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(
      emb_lookup_sub_graph, abstract::AnalysisContext::DummyContext());
    emb_lookup_graph_value->set_abstract(emb_lookup_graph_value_abstract);

    CNodePtr emb_lookup_partial_node =
      root_graph_->NewCNode({NewValueNode(prim::kPrimPartial), emb_lookup_graph_value, param, recv_outputs[0]});
    MS_EXCEPTION_IF_NULL(emb_lookup_partial_node);
    AbstractBasePtrList lookup_partial_args_spec_list = {param->abstract(), recv_outputs[0]->abstract()};
    emb_lookup_partial_node->set_abstract(std::make_shared<abstract::PartialAbstractClosure>(
      emb_lookup_graph_value_abstract, lookup_partial_args_spec_list, emb_lookup_partial_node));

    make_tuple_inputs->push_back(emb_lookup_partial_node);

    // 2. Construct the embedding update service subgraph.
    auto update_emb_sub_graph = ConstructUpdateEmbeddingSubGraph(param, node, key);
    MS_EXCEPTION_IF_NULL(update_emb_sub_graph);
    auto update_emb_graph_value = NewValueNode(update_emb_sub_graph);
    MS_EXCEPTION_IF_NULL(update_emb_graph_value);
    auto update_emb_graph_value_abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(
      update_emb_sub_graph, abstract::AnalysisContext::DummyContext());
    update_emb_graph_value->set_abstract(update_emb_graph_value_abstract);

    CNodePtr update_emb_partial_node = root_graph_->NewCNode(
      {NewValueNode(prim::kPrimPartial), update_emb_graph_value, param, recv_outputs[0], recv_outputs[1]});
    MS_EXCEPTION_IF_NULL(update_emb_partial_node);
    AbstractBasePtrList update_partial_args_spec_list = {param->abstract(), recv_outputs[0]->abstract(),
                                                         recv_outputs[1]->abstract()};
    update_emb_partial_node->set_abstract(std::make_shared<abstract::PartialAbstractClosure>(
      update_emb_graph_value_abstract, update_partial_args_spec_list, update_emb_partial_node));

    make_tuple_inputs->push_back(update_emb_partial_node);
  }

  return true;
}

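// Assemble the server graph: RpcRecv produces (indices, update values, service id); the partials of all service
// subgraphs are gathered into a tuple; SwitchLayer indexed by the service id picks the subgraph to call; finally the
// origin output is replaced by a fake output that depends on the call, so the service pipeline is always executed.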
bool PsEmbeddingCacheInserter::ConstructEmbeddingCacheGraph() const {
  // 1. Create the recv node for the server.
  CNodePtr recv_node = CreateRecvNode();
  MS_EXCEPTION_IF_NULL(recv_node);
  auto value_node_0 = NewValueNode(static_cast<int64_t>(0));
  auto value_node_1 = NewValueNode(static_cast<int64_t>(1));
  auto value_node_2 = NewValueNode(static_cast<int64_t>(2));
  std::vector<AnfNodePtr> getitem_input0{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_0};
  std::vector<AnfNodePtr> getitem_input1{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_1};
  std::vector<AnfNodePtr> getitem_input2{NewValueNode(prim::kPrimTupleGetItem), recv_node, value_node_2};

  MS_EXCEPTION_IF_NULL(root_graph_);
  auto getitem_0 = root_graph_->NewCNode(getitem_input0);
  auto getitem_1 = root_graph_->NewCNode(getitem_input1);
  auto getitem_2 = root_graph_->NewCNode(getitem_input2);
  // The tuple_getitem nodes extract the outputs of the recv node: indices, update values and service id.
  std::vector<CNodePtr> getitems = {getitem_0, getitem_1, getitem_2};

  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};

  // 2. Construct the embedding cache services subgraphs, including embedding lookup and update operations, and
  // package the subgraphs corresponding to the related operations into partials.
  RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheServicesSubGraphs(getitems, &make_tuple_inputs),
                           "Construct embedding cache services sub graphs failed.");

  auto make_tuple_node = root_graph_->NewCNode(make_tuple_inputs);
  MS_EXCEPTION_IF_NULL(make_tuple_node);

  // 3. Create the switch layer and call node, used to select and execute the subgraph corresponding to the requested
  // service.
  std::vector<AnfNodePtr> switch_layer_inputs = {NewValueNode(prim::kPrimSwitchLayer), getitem_2, make_tuple_node};
  auto switch_layer_node = root_graph_->NewCNode(switch_layer_inputs);
  MS_EXCEPTION_IF_NULL(switch_layer_node);

  CNodePtr call_node = root_graph_->NewCNode({switch_layer_node});
  MS_EXCEPTION_IF_NULL(call_node);

  // 4. Replace the origin output and useless nodes of the origin function graph.
  AnfNodePtr old_output = root_graph_->output();
  AnfNodePtr new_output = CreateOutputNode(root_graph_, old_output);
  auto final_output_node = root_graph_->NewCNode({NewValueNode(prim::kPrimDepend), new_output, call_node});
  MS_EXCEPTION_IF_NULL(final_output_node);

  auto graph_manager = root_graph_->manager();
  MS_EXCEPTION_IF_NULL(graph_manager);
  return graph_manager->Replace(root_graph_->output(), final_output_node);
}

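// Create a persistent embedding storage for every dense-format cached parameter that enables persistent storage,
// sized from the sliced parameter shape (capacity x embedding dim).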
void PsEmbeddingCacheInserter::BuildDenseEmbeddingStorages() {
  for (const auto &item : keys_to_params_) {
    int32_t key = item.first;
    ParameterPtr param = item.second;
    MS_EXCEPTION_IF_NULL(param);

    auto param_info = param->param_info();
    MS_EXCEPTION_IF_NULL(param_info);
    if (!param_info->use_persistent_storage()) {
      MS_LOG(INFO) << "No need to use embedding storage for this parameter(key): " << key;
      continue;
    }

    const std::vector<int64_t> &origin_shape = param_info->origin_shape();
    size_t origin_capacity = LongToSize(origin_shape.front());
    size_t origin_emb_dim = LongToSize(origin_shape.back());
    MS_LOG(INFO) << "Get a parameter for embedding storage: " << param->name()
                 << ", origin emb_dim: " << origin_emb_dim << ", origin capacity: " << origin_capacity;

    const std::vector<int64_t> &slice_shape = param_info->parameter_shape();
    if (slice_shape.size() != kEmbeddingTableDims) {
      MS_LOG(EXCEPTION)
        << "When building embedding storage, the embedding table should be 2 dims for embedding cache mode, but got: "
        << slice_shape.size() << " dims, param name: " << param->name() << ", param key: " << key;
    }
    size_t capacity = LongToSize(slice_shape.front());
    size_t emb_dim = LongToSize(slice_shape.back());

    auto shape = common::AnfAlgo::GetOutputInferShape(param, 0);
    auto iter = shapes_to_nodes_.find(shape);
    if (iter == shapes_to_nodes_.end()) {
      MS_LOG(EXCEPTION) << "Can not find cnode for parameter(key[" << key << "]) with shape: " << shape;
    }
    AnfNodePtr node = iter->second;
    const size_t output_index0 = 0;
    const size_t key_index = 1;

    TypeId key_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, key_index);
    TypeId value_type = common::AnfAlgo::GetOutputInferDataType(node, output_index0);
    // Create a dense embedding storage and add it into the embedding storage manager.
    distributed::CreateEmbeddingStorage(std::make_pair(key_type, value_type), key, emb_dim, capacity);
    MS_LOG(INFO) << "Add a new dense embedding storage, key: " << key << ", emb_dim: " << emb_dim
                 << ", capacity: " << capacity << ", origin emb_dim:" << origin_emb_dim
                 << ", origin capacity: " << origin_capacity;
  }
}

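// Create a persistent embedding storage for every sparse-format (map tensor) cached parameter. The capacity is
// derived from the memory budget given by MS_EMBEDDING_REMOTE_CACHE_MEMORY_SIZE divided by the per-element size.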
void PsEmbeddingCacheInserter::BuildSparseEmbeddingStorages() {
  if (common::GetEnv(kEnvEmbeddingCacheMemSizeInGBytes).empty()) {
    return;
  }
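  // For example, MS_EMBEDDING_REMOTE_CACHE_MEMORY_SIZE=2 yields a budget of 2 GiB (2 << 30 bytes) below.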
  const size_t cache_size_in_gbytes = std::stoul(common::GetEnv(kEnvEmbeddingCacheMemSizeInGBytes));
  const size_t cache_size_in_bytes = cache_size_in_gbytes << 30;

  for (const auto &item : keys_to_params_) {
    int32_t key = item.first;
    ParameterPtr param = item.second;
    MS_EXCEPTION_IF_NULL(param);

    auto param_info = param->param_info();
    MS_EXCEPTION_IF_NULL(param_info);
    param_info->set_use_persistent_storage(true);

    const auto &abstract_base = common::AnfAlgo::FrontendGetNodeAbstractByIndex(param, 0);
    MS_EXCEPTION_IF_NULL(abstract_base);
    if (!abstract_base->isa<abstract::AbstractMapTensor>()) {
      MS_LOG(EXCEPTION) << "Parameter:" << param->DebugString() << " is not a map tensor type.";
    }
    const auto &abstract = abstract_base->cast<abstract::AbstractMapTensorPtr>();
    MS_EXCEPTION_IF_NULL(abstract);

    MS_EXCEPTION_IF_NULL(abstract->value_shape());
    const auto &value_shape = abstract->value_shape()->shape();
    size_t emb_dim = LongToSize(
      std::accumulate(value_shape.begin(), value_shape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()));

    const auto &map_tensor_type = abstract->map_tensor_type();
    MS_EXCEPTION_IF_NULL(map_tensor_type);
    auto key_type = map_tensor_type->key_dtype();
    auto value_type = map_tensor_type->value_dtype();
    MS_EXCEPTION_IF_NULL(key_type);
    MS_EXCEPTION_IF_NULL(value_type);
    size_t map_element_size = GetTypeByte(key_type) + (GetTypeByte(value_type) * emb_dim);
    MS_EXCEPTION_IF_ZERO("map_element_size", map_element_size);

    size_t capacity = cache_size_in_bytes / map_element_size;
    TypeId key_type_id = key_type->type_id();
    TypeId value_type_id = value_type->type_id();

    // Create a sparse embedding storage and add it into the embedding storage manager.
    distributed::CreateEmbeddingStorage(std::make_pair(key_type_id, value_type_id), key, emb_dim, capacity);
    MS_LOG(INFO) << "Add a new sparse embedding storage, key: " << key << ", emb_dim: " << emb_dim
                 << ", capacity: " << capacity;
  }
}

void PsEmbeddingCacheInserter::BuildEmbeddingStorages() {
  if (embedding_cache_table_manager.is_sparse_format()) {
    BuildSparseEmbeddingStorages();
  } else {
    BuildDenseEmbeddingStorages();
  }
}

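// Entry point of the pass: collect the server-side lookup nodes and cache-enabled parameters, build the embedding
// storages, construct the server's embedding cache service graph, and label all nodes for graph splitting.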
bool PsEmbeddingCacheInserter::Run() {
  // Get the EmbeddingLookup nodes that are executed on the server from the origin function graph.
  GetEmbeddingLookupNodes();

  // Get the parameters with embedding cache enabled from the origin function graph.
  GetCacheEnableParameters();

  // Build embedding storages for the parameters with embedding cache enabled, to read/write the embedding tables
  // from/to persistent storage.
  BuildEmbeddingStorages();

  // Construct the embedding cache graph of the server.
  RETURN_IF_FALSE_WITH_LOG(ConstructEmbeddingCacheGraph(), "Construct embedding cache graph failed.");

  // Set attrs (device target attr and graph split label) for all CNodes.
  SetAttrForAllNodes();

  MS_EXCEPTION_IF_NULL(root_graph_);
  // Need renormalize to infer shapes and set abstracts.
  root_graph_->set_flag(kFlagNeedRenormalize, true);
  return true;
}
}  // namespace parallel
}  // namespace mindspore