/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_

#include <string>
#include <map>
#include <vector>

#include "ir/anf.h"
#include "include/backend/distributed/constants.h"
#include "utils/shape_utils.h"

namespace mindspore {
namespace parallel {
// Build service-side graph for embedding distributed cache based on Parameter Server,
// and remove all nodes of origin func graph.
32 class PsEmbeddingCacheInserter {
33  public:
PsEmbeddingCacheInserter(const FuncGraphPtr & root_graph,int64_t rank_id,const std::string & node_role,uint32_t worker_num)34   PsEmbeddingCacheInserter(const FuncGraphPtr &root_graph, int64_t rank_id, const std::string &node_role,
35                            uint32_t worker_num)
36       : root_graph_(root_graph), rank_id_(rank_id), node_role_(node_role), worker_num_(worker_num) {}
37 
~PsEmbeddingCacheInserter()38   ~PsEmbeddingCacheInserter() {
39     root_graph_ = nullptr;
40     keys_to_params_.clear();
41     shapes_to_nodes_.clear();
42   }
43 
44   // Insert embedding cache sub graphs to replace all nodes of origin func graph.
45   bool Run();
46 
47  private:
48   // Construct the embedding cache graph of server:
49   // Recv --> SwitchLayer --> Call --> Return
50   // the SwitchLayer is used to select the subgraph corresponding to the service requested to be executed.
51   bool ConstructEmbeddingCacheGraph() const;
52 
53   // Create RpcRecv node for server to receive request.
54   CNodePtr CreateRecvNode() const;
55 
56   // Build Embedding store for each param which enable cache. Embedding store can read/write embedding from/to
57   // persistent storage.
58   void BuildEmbeddingStorages();
59   // Build Embedding store for dense mode(Tensor).
60   void BuildDenseEmbeddingStorages();
61   // Build Embedding store for sparse mode(Hash Table).
62   void BuildSparseEmbeddingStorages();
63 
64   // Construct the embedding cache services subgraphs, including embedding lookup and update operations, and package the
65   // subgraphs corresponding to the related operations into the partial.
66   bool ConstructEmbeddingCacheServicesSubGraphs(const std::vector<CNodePtr> &recv_outputs,
67                                                 std::vector<AnfNodePtr> *make_tuple_inputs) const;
68 
69   // Construct embedding lookup service sub graph:
70   // Input(param, indices) --> EmbeddingLookup/MapTensorGet --> RpcSend --> Return
71   // RpcSend is used to send the embeddings to the service caller.
72   FuncGraphPtr ConstructEmbeddingLookupSubGraph(const AnfNodePtr &node, const ParameterPtr &param,
73                                                 int32_t param_key) const;
74 
75   // Construct updating embedding service sub graph:
76   // Input(param, indices, update_values) --> ScatterUpdate/MapTensorPut --> Return
77   // The Sub is used to rectify the id via offset for embedding slice.
78   FuncGraphPtr ConstructUpdateEmbeddingSubGraph(const ParameterPtr &param, const AnfNodePtr &node,
79                                                 int32_t param_key) const;
80 
81   // Create embedding lookup kernel: 'EmbeddingLookup' for Tensor or 'MapTensorGet' for Hash Table.
82   CNodePtr CreateEmbeddingLookupKernel(const FuncGraphPtr &graph, const ParameterPtr &input_param,
83                                        const ParameterPtr &input_indices,
84                                        const AnfNodePtr &origin_embedding_lookup_node) const;
85 
86   // Create embedding update kernel: 'ScatterUpdate' for Tensor or 'MapTensorPut' for Hash Table.
87   CNodePtr CreateEmbeddingUpdateKernel(const FuncGraphPtr &graph, const ParameterPtr &input_param,
88                                        const ParameterPtr &input_indices, const ParameterPtr &update_values) const;
89 
90   // Create return node for subgraph, using depend node to return a fake value node to ensure that the output abstract
91   // of each subgraph is the same.
92   CNodePtr CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const;
93 
94   // Set attr(device target attr and graph split label) for all CNodes.
95   void SetAttrForAllNodes() const;
96 
97   // Set device target attr to cpu, set graph split label(rank id and node role, such as (0, "MS_PSERVER")).
98   void SetNodeAttr(const CNodePtr &node, const std::string &node_role = distributed::kEnvRoleOfPServer) const;
99 
100   // Set attrs for send node, such as:inter process edges, send dst ranks, send dst roles.
101   void SetSendNodeAttr(const CNodePtr &send_node, int32_t param_key, const std::string &embedding_cache_op,
102                        const std::string &dst_role = distributed::kEnvRoleOfWorker) const;
103 
104   // Set attrs for recv node, such as:inter process edges, recv src ranks, recv src roles.
105   void SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role = distributed::kEnvRoleOfWorker) const;
106 
107   // Get EmbeddingLookup nodes which are executed on server from origin function graph.
108   void GetEmbeddingLookupNodes();
109 
110   // Get parameters enabled embedding cache of origin function graph.
111   void GetCacheEnableParameters();
112 
113   // Origin root function graph.
114   FuncGraphPtr root_graph_;
115 
116   // The rank id of this process.
117   int64_t rank_id_;
118   // The node role of this process.
119   std::string node_role_;
120   // The worker number of in cluster.
121   uint32_t worker_num_;
122 
123   // Record parameters enabled embedding cache of origin function graph.
124   // Key: parameter key, Value: ParameterPtr
125   std::map<int32_t, ParameterPtr> keys_to_params_;
126 
127   // Record EmbeddingLookup nodes which are executed on server from origin function graph.
128   // Key: shape of EmbeddingLookup node, Value: EmbeddingLookup AnfNodePtr.
129   std::map<ShapeVector, AnfNodePtr> shapes_to_nodes_;
130 };
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_