/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_

#include <string>
#include <map>
#include <vector>

#include "ir/anf.h"
#include "include/backend/distributed/constants.h"
#include "utils/shape_utils.h"

namespace mindspore {
namespace parallel {
// Builds the service-side graph for the embedding distributed cache based on Parameter Server,
// and removes all nodes of the origin func graph.
class PsEmbeddingCacheInserter {
 public:
  // root_graph: the origin func graph this inserter rewrites.
  // rank_id: the rank id of this process; node_role: its role in the cluster (e.g. "MS_PSERVER").
  // worker_num: the number of workers in the cluster.
  PsEmbeddingCacheInserter(const FuncGraphPtr &root_graph, int64_t rank_id, const std::string &node_role,
                           uint32_t worker_num)
      : root_graph_(root_graph), rank_id_(rank_id), node_role_(node_role), worker_num_(worker_num) {}

  // Release the graph reference and clear the bookkeeping maps.
  ~PsEmbeddingCacheInserter() {
    root_graph_ = nullptr;
    keys_to_params_.clear();
    shapes_to_nodes_.clear();
  }

  // Insert embedding cache sub graphs to replace all nodes of the origin func graph.
  bool Run();

 private:
  // Construct the embedding cache graph of the server:
  //   Recv --> SwitchLayer --> Call --> Return
  // The SwitchLayer is used to select the subgraph corresponding to the service requested to be executed.
  bool ConstructEmbeddingCacheGraph() const;

  // Create an RpcRecv node for the server to receive requests.
  CNodePtr CreateRecvNode() const;

  // Build an embedding store for each param which enables cache. An embedding store can read/write embeddings
  // from/to persistent storage.
  void BuildEmbeddingStorages();
  // Build embedding stores for dense mode (Tensor).
  void BuildDenseEmbeddingStorages();
  // Build embedding stores for sparse mode (Hash Table).
  void BuildSparseEmbeddingStorages();

  // Construct the embedding cache services subgraphs, including embedding lookup and update operations, and
  // package the subgraphs corresponding to the related operations into the partial.
  bool ConstructEmbeddingCacheServicesSubGraphs(const std::vector<CNodePtr> &recv_outputs,
                                                std::vector<AnfNodePtr> *make_tuple_inputs) const;

  // Construct the embedding lookup service sub graph:
  //   Input(param, indices) --> EmbeddingLookup/MapTensorGet --> RpcSend --> Return
  // RpcSend is used to send the embeddings back to the service caller.
  FuncGraphPtr ConstructEmbeddingLookupSubGraph(const AnfNodePtr &node, const ParameterPtr &param,
                                                int32_t param_key) const;

  // Construct the updating embedding service sub graph:
  //   Input(param, indices, update_values) --> ScatterUpdate/MapTensorPut --> Return
  // The Sub is used to rectify the id via offset for the embedding slice.
  FuncGraphPtr ConstructUpdateEmbeddingSubGraph(const ParameterPtr &param, const AnfNodePtr &node,
                                                int32_t param_key) const;

  // Create the embedding lookup kernel: 'EmbeddingLookup' for Tensor or 'MapTensorGet' for Hash Table.
  CNodePtr CreateEmbeddingLookupKernel(const FuncGraphPtr &graph, const ParameterPtr &input_param,
                                       const ParameterPtr &input_indices,
                                       const AnfNodePtr &origin_embedding_lookup_node) const;

  // Create the embedding update kernel: 'ScatterUpdate' for Tensor or 'MapTensorPut' for Hash Table.
  CNodePtr CreateEmbeddingUpdateKernel(const FuncGraphPtr &graph, const ParameterPtr &input_param,
                                       const ParameterPtr &input_indices, const ParameterPtr &update_values) const;

  // Create the return node for a subgraph, using a Depend node to return a fake value node to ensure that the
  // output abstract of each subgraph is the same.
  // NOTE(review): `graph` is passed by const value here, unlike the const-ref convention used by the sibling
  // methods above; changing it to `const FuncGraphPtr &` would also require updating the out-of-line definition.
  CNodePtr CreateReturnNode(const FuncGraphPtr graph, const AnfNodePtr &output_node) const;

  // Set attrs (device target attr and graph split label) for all CNodes.
  void SetAttrForAllNodes() const;

  // Set the device target attr to cpu, and set the graph split label (rank id and node role, such as
  // (0, "MS_PSERVER")).
  void SetNodeAttr(const CNodePtr &node, const std::string &node_role = distributed::kEnvRoleOfPServer) const;

  // Set attrs for the send node, such as: inter-process edges, send dst ranks, send dst roles.
  void SetSendNodeAttr(const CNodePtr &send_node, int32_t param_key, const std::string &embedding_cache_op,
                       const std::string &dst_role = distributed::kEnvRoleOfWorker) const;

  // Set attrs for the recv node, such as: inter-process edges, recv src ranks, recv src roles.
  void SetRecvNodeAttr(const CNodePtr &recv_node, const std::string &src_role = distributed::kEnvRoleOfWorker) const;

  // Get the EmbeddingLookup nodes which are executed on the server from the origin function graph.
  void GetEmbeddingLookupNodes();

  // Get the parameters with embedding cache enabled from the origin function graph.
  void GetCacheEnableParameters();

  // Origin root function graph.
  FuncGraphPtr root_graph_;

  // The rank id of this process.
  int64_t rank_id_;
  // The node role of this process.
  std::string node_role_;
  // The number of workers in the cluster.
  uint32_t worker_num_;

  // Record the parameters with embedding cache enabled in the origin function graph.
  // Key: parameter key, Value: ParameterPtr.
  std::map<int32_t, ParameterPtr> keys_to_params_;

  // Record the EmbeddingLookup nodes which are executed on the server from the origin function graph.
  // Key: shape of the EmbeddingLookup node, Value: EmbeddingLookup AnfNodePtr.
  std::map<ShapeVector, AnfNodePtr> shapes_to_nodes_;
};
}  // namespace parallel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_EMBEDDING_CACHE_PS_EMBEDDING_CACHE_INSERTER_H_