• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PS_WORKER_H_
18 #define MINDSPORE_CCSRC_PS_WORKER_H_
19 
20 #include <utility>
21 #include <memory>
22 #include <vector>
23 #include <string>
24 #include <numeric>
25 #include <functional>
26 #include <algorithm>
27 #include <map>
28 #include <mutex>
29 #include <unordered_set>
30 #include <unordered_map>
31 
32 #include "utils/log_adapter.h"
33 #include "ir/tensor.h"
34 #include "ps/util.h"
35 #include "ps/constants.h"
36 #include "utils/shape_utils.h"
37 #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
38 #include "ps/core/worker_node.h"
39 #include "ps/embedding_table_shard_metadata.h"
40 #include "proto/comm.pb.h"
41 #include "proto/ps.pb.h"
42 #include "ps/ps_context.h"
43 
44 namespace mindspore {
45 namespace ps {
46 class Worker {
47  public:
GetInstance()48   static Worker &GetInstance() {
49     static Worker instance;
50     return instance;
51   }
52   using Callback = std::function<void()>;
53   using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>;
54   using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>;
55 
56   using EmbeddingPartitioner = std::function<void(
57     const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
58   using KVPartitioner =
59     std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>;
60 
61   void Run();
62   void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes);
63   void Pull(const size_t key, void *dev_addr, const size_t size);
64   size_t SetParamKey(const std::string &param_name);
65   size_t GetParamKey(const std::string &param_name);
66   void SetParamInitInServer(const std::string &param_name, bool init_in_server);
67   bool GetParamInitInServer(const std::string &param_name);
68   void SetKeyOptimId(size_t key, const std::string &optimizer_name);
69   void SetOptimInputShapes(size_t key, const ShapeVector &shape);
70   void AddEmbeddingTable(const Key &key, const size_t &row_count);
71   void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
72                             const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape,
73                             const ParamInitInfoMessage &info);
74   void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor);
75   void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
76                            int64_t cmd);
77   void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
78                             const std::vector<float> &vals);
79 
running()80   bool running() { return running_; }
81   void Finalize();
82 
83  private:
Worker()84   Worker() : server_num_(-1), running_(false), key_cnt_(0) {}
85   ~Worker() = default;
86   Worker(const Worker &) = delete;
87   Worker &operator=(const Worker &) = delete;
88 
89   void Initialize();
90   bool IsKeyInit(const size_t key);
91   void AddKeyToServerId(const Key &key);
92   void AddKeyByHashMod(const Key &key);
93   void InitPSOptimId(const size_t param_key);
94   void InitPSOptimInputShapes(const size_t key);
95   void InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size);
96   bool IsReadyForPush(const Key &key);
97   bool IsReadyForPull(const Key &key);
98   void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
99                              const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
100                              const size_t segment_size, float *gradient, int *indices);
101   void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
102                         const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data);
103 
104   void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {},
105                 int command = 0, int64_t priority = 0);
106   void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
107                       size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
108   void PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens = nullptr,
109                 int cmd = 0, int64_t priority = 0);
110 
111   void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
112                            const std::map<int64_t, int64_t> &attrs);
113 
114   void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
115                          const std::map<int64_t, int64_t> &attrs);
116   void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
117                              const std::map<int64_t, int64_t> &attrs);
118   void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
119                                       const std::map<int64_t, int64_t> &attrs);
120   void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
121                                   const std::map<int64_t, int64_t> &attrs);
122   void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
123                             const std::map<int64_t, int64_t> &attrs);
124   void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
125                    const std::map<int64_t, int64_t> &attrs);
126   void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
127                    const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens);
128 
129   int64_t server_num_;
130   bool running_;
131   std::mutex running_mutex_;
132   size_t key_cnt_;
133   std::map<std::string, size_t> param_to_key_;
134   std::map<size_t, bool> init_keys_;
135   std::map<size_t, int64_t> key_to_optimId_;
136   std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_;
137   std::map<std::string, bool> param_to_init_in_server_;
138   core::WorkerNode worker_node_;
139 
140   EmbeddingPartitioner lookup_partitioner_;
141   KVPartitioner sparse_partitioner_;
142   KVPartitioner round_robin_partitioner_;
143   KVPartitioner worker_init_embedding_partitioner_;
144   KVPartitioner update_embedding_partitioner_;
145   KVPartitioner broadcast_partitioner_;
146   std::unordered_map<Key, int64_t> key_to_server_id_;
147   std::unordered_map<Key, size_t> embedding_row_cnt_;
148 
149   std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_;
150 };
151 }  // namespace ps
152 }  // namespace mindspore
153 #endif  // MINDSPORE_CCSRC_PS_WORKER_H_
154