1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_WORKER_H_ 18 #define MINDSPORE_CCSRC_PS_WORKER_H_ 19 20 #include <utility> 21 #include <memory> 22 #include <vector> 23 #include <string> 24 #include <numeric> 25 #include <functional> 26 #include <algorithm> 27 #include <map> 28 #include <mutex> 29 #include <unordered_set> 30 #include <unordered_map> 31 32 #include "utils/log_adapter.h" 33 #include "ir/tensor.h" 34 #include "ps/util.h" 35 #include "ps/constants.h" 36 #include "utils/shape_utils.h" 37 #include "ps/ps_cache/ps_data/ps_data_prefetch.h" 38 #include "ps/core/worker_node.h" 39 #include "ps/embedding_table_shard_metadata.h" 40 #include "proto/comm.pb.h" 41 #include "proto/ps.pb.h" 42 #include "ps/ps_context.h" 43 44 namespace mindspore { 45 namespace ps { 46 class Worker { 47 public: GetInstance()48 static Worker &GetInstance() { 49 static Worker instance; 50 return instance; 51 } 52 using Callback = std::function<void()>; 53 using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>; 54 using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>; 55 56 using EmbeddingPartitioner = std::function<void( 57 const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>; 58 using KVPartitioner = 59 std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>; 60 61 void Run(); 62 void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes); 63 void Pull(const size_t key, void *dev_addr, const size_t size); 64 size_t SetParamKey(const std::string ¶m_name); 65 size_t GetParamKey(const std::string ¶m_name); 66 void SetParamInitInServer(const std::string ¶m_name, bool init_in_server); 67 bool GetParamInitInServer(const std::string ¶m_name); 68 void SetKeyOptimId(size_t key, const std::string &optimizer_name); 69 void SetOptimInputShapes(size_t key, const ShapeVector &shape); 70 void AddEmbeddingTable(const Key &key, const size_t &row_count); 71 void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape, 72 const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape, 73 const ParamInitInfoMessage &info); 74 void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor); 75 void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result, 76 int64_t cmd); 77 void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids, 78 const std::vector<float> &vals); 79 running()80 bool running() { return running_; } 81 void Finalize(); 82 83 private: Worker()84 Worker() : server_num_(-1), running_(false), key_cnt_(0) {} 85 ~Worker() = default; 86 Worker(const Worker &) = delete; 87 Worker &operator=(const Worker &) = delete; 88 89 void Initialize(); 90 bool IsKeyInit(const size_t key); 91 void AddKeyToServerId(const Key &key); 92 void AddKeyByHashMod(const Key &key); 93 void InitPSOptimId(const size_t param_key); 94 void InitPSOptimInputShapes(const size_t key); 95 void InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size); 96 bool IsReadyForPush(const Key &key); 97 bool IsReadyForPull(const Key &key); 98 void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids, 99 const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice, 100 const size_t segment_size, float *gradient, int *indices); 101 void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index, 102 const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data); 103 104 void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {}, 105 int command = 0, int64_t priority = 0); 106 void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens, 107 size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); 108 void PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens = nullptr, 109 int cmd = 0, int64_t priority = 0); 110 111 void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, 112 const std::map<int64_t, int64_t> &attrs); 113 114 void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition, 115 const std::map<int64_t, int64_t> &attrs); 116 void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition, 117 const std::map<int64_t, int64_t> &attrs); 118 void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition, 119 const std::map<int64_t, int64_t> &attrs); 120 void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition, 121 const std::map<int64_t, int64_t> &attrs); 122 void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition, 123 const std::map<int64_t, int64_t> &attrs); 124 void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner, 125 const std::map<int64_t, int64_t> &attrs); 126 void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner, 127 const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens); 128 129 int64_t server_num_; 130 bool running_; 131 std::mutex running_mutex_; 132 size_t key_cnt_; 133 std::map<std::string, size_t> param_to_key_; 134 std::map<size_t, bool> init_keys_; 135 std::map<size_t, int64_t> key_to_optimId_; 136 std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_; 137 std::map<std::string, bool> param_to_init_in_server_; 138 core::WorkerNode worker_node_; 139 140 EmbeddingPartitioner lookup_partitioner_; 141 KVPartitioner sparse_partitioner_; 142 KVPartitioner round_robin_partitioner_; 143 KVPartitioner worker_init_embedding_partitioner_; 144 KVPartitioner update_embedding_partitioner_; 145 KVPartitioner broadcast_partitioner_; 146 std::unordered_map<Key, int64_t> key_to_server_id_; 147 std::unordered_map<Key, size_t> embedding_row_cnt_; 148 149 std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_; 150 }; 151 } // namespace ps 152 } // namespace mindspore 153 #endif // MINDSPORE_CCSRC_PS_WORKER_H_ 154