• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "ps/util.h"

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ps/constants.h"
#include "ps/ps_context.h"
#include "utils/ms_utils.h"
24 
25 namespace mindspore {
26 namespace ps {
// Maps each supported optimizer name to the integer ID used to identify it in
// parameter-server messages. Must stay consistent with id_to_optimizers below.
std::unordered_map<std::string, int64_t> Util::optimizer_to_ids{
  {kApplyMomentum, 0},
  {kSparseAdam, 1},
  {kSparseLazyAdam, 2},
  {kSparseFtrl, 3},
};
33 
// Reverse mapping: optimizer ID back to the optimizer name. Must stay
// consistent with optimizer_to_ids above.
std::unordered_map<int64_t, std::string> Util::id_to_optimizers{
  {0, kApplyMomentum},
  {1, kSparseAdam},
  {2, kSparseLazyAdam},
  {3, kSparseFtrl},
};
40 
// Maps an optimizer ID to the name of the corresponding optimizer kernel/op
// node. IDs follow the same numbering as optimizer_to_ids.
std::unordered_map<int64_t, std::string> Util::id_to_optimizer_nodes{
  {0, kApplyMomentumOp},
  {1, kSparseAdamOp},
  {2, kSparseLazyAdamOp},
  {3, kSparseFtrlOp},
};
47 
// Returns true when this process runs in the parameter-server role.
bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }
49 
// Returns true when this process runs in the scheduler role.
bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }
51 
optimizer_id(const std::string & name)52 int64_t Util::optimizer_id(const std::string &name) {
53   if (optimizer_to_ids.count(name) > 0) {
54     return optimizer_to_ids[name];
55   }
56   return -1;
57 }
58 
optimizer_name(int64_t id)59 std::string Util::optimizer_name(int64_t id) {
60   if (id_to_optimizers.count(id) > 0) {
61     return id_to_optimizers[id];
62   }
63   return "";
64 }
65 
optimizer_node_name(int64_t id)66 std::string Util::optimizer_node_name(int64_t id) {
67   if (id_to_optimizer_nodes.count(id) > 0) {
68     return id_to_optimizer_nodes[id];
69   }
70   return "";
71 }
72 
is_optimizer(const std::string & name)73 bool Util::is_optimizer(const std::string &name) { return optimizer_to_ids.count(name) > 0; }
74 
LocalShard(int64_t first_dim,int64_t rank_id,int64_t server_num)75 int64_t Util::LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
76   std::map<int64_t, int64_t> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
77   if (shard_dims.count(rank_id) == 0) {
78     MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
79   }
80   return shard_dims[rank_id];
81 }
82 
AllRankLocalShard(int64_t first_dim,int64_t rank_id,int64_t server_num)83 std::map<int64_t, int64_t> Util::AllRankLocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
84   if (first_dim <= 0 || server_num <= 0 || rank_id < 0) {
85     MS_LOG(EXCEPTION) << "Input values are invalid.";
86   }
87   if (rank_id >= server_num) {
88     MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
89   }
90   std::map<int64_t, int64_t> shard_dims;
91   for (int64_t i = 0; i < server_num; i++) {
92     shard_dims[i] = 0;
93   }
94   if (server_num != static_cast<int64_t>(shard_dims.size())) {
95     MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size();
96   }
97   int64_t server_index = -1;
98   for (int64_t i = 0; i < first_dim; i++) {
99     server_index = (server_index + 1) % server_num;
100     shard_dims[server_index] = shard_dims[server_index] + 1;
101   }
102   if (shard_dims.count(rank_id) == 0) {
103     MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num;
104   }
105   return shard_dims;
106 }
107 
ReduceSparseGradient(float * gradients,int * indices,const size_t indices_size,size_t segment_size,const size_t first_dim_size,const size_t outer_dim_size,mindspore::kernel::SparseGradient<int> * unique_sparse_grad)108 void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
109                                 const size_t first_dim_size, const size_t outer_dim_size,
110                                 mindspore::kernel::SparseGradient<int> *unique_sparse_grad) {
111   size_t slice_segment_size = indices_size * segment_size;
112   std::vector<float> workspace_grad(slice_segment_size);
113   std::vector<int> workspace_indices(indices_size);
114 
115   MS_EXCEPTION_IF_NULL(gradients);
116   MS_EXCEPTION_IF_NULL(indices);
117 
118   mindspore::kernel::SparseGradient<int> workspace_sparse_grad(
119     {workspace_grad.data(), workspace_indices.data(), indices_size});
120   mindspore::kernel::SparseGradient<int> input_sparse_grad({gradients, indices, indices_size});
121   mindspore::kernel::ReduceSparseGradientParam<int> param;
122   param.input_grad_ = &input_sparse_grad;
123   param.workspace_grad_ = &workspace_sparse_grad;
124   param.output_grad_ = unique_sparse_grad;
125   param.max_index_ = first_dim_size;
126   param.value_stride_ = outer_dim_size;
127 
128   mindspore::kernel::SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
129 }
130 
FuseServerCommOps(const pipeline::ResourcePtr & res)131 bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) {
132   FuncGraphPtr func_graph = res->func_graph();
133   MS_EXCEPTION_IF_NULL(func_graph);
134   DoFusion(func_graph, kPullWeightOpName, kFusedPullWeightOpName);
135   DoFusion(func_graph, kPushWeightOpName, kFusedPushWeightOpName);
136   return true;
137 }
138 
DoFusion(const FuncGraphPtr & func_graph,const std::string & cnode_name,const std::string & fused_cnode_name)139 void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
140                     const std::string &fused_cnode_name) {
141   MS_EXCEPTION_IF_NULL(func_graph);
142   std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
143 
144   std::vector<AnfNodePtr> single_nodes;
145   std::vector<std::string> weight_names;
146   std::vector<int64_t> indices;
147   for (const AnfNodePtr &node : node_list) {
148     if (node != nullptr && node->isa<CNode>()) {
149       if (AnfAlgo::GetCNodeName(node) == cnode_name) {
150         single_nodes.push_back(node);
151 
152         auto weight_name_value_node =
153           AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightNameOffset)->cast<ValueNodePtr>();
154         const std::string &weight_name = GetValue<std::string>(weight_name_value_node->value());
155         weight_names.push_back(weight_name);
156 
157         auto weight_index_value_node =
158           AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightIndexOffset)->cast<ValueNodePtr>();
159         int64_t weight_index = GetValue<int64_t>(weight_index_value_node->value());
160         indices.push_back(weight_index);
161       }
162     }
163   }
164 
165   auto prim = std::make_shared<Primitive>(fused_cnode_name);
166   MS_EXCEPTION_IF_NULL(prim);
167   std::vector<AnfNodePtr> fused_node_inputs = {};
168   fused_node_inputs.push_back(NewValueNode(prim));
169   (void)std::for_each(single_nodes.begin(), single_nodes.end(), [&](const AnfNodePtr &node) {
170     fused_node_inputs.push_back(AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0));
171   });
172 
173   auto fused_cnode = func_graph->NewCNode(fused_node_inputs);
174   MS_EXCEPTION_IF_NULL(fused_cnode);
175   AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(weight_names), fused_cnode);
176   AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(indices), fused_cnode);
177   AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), fused_cnode);
178 
179   auto kernel_info = std::make_shared<device::KernelInfo>();
180   MS_EXCEPTION_IF_NULL(kernel_info);
181   fused_cnode->set_kernel_info(kernel_info);
182   auto kernel_build_info = GenerateKernelBuildInfo(single_nodes);
183   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_cnode.get());
184 
185   AbstractBasePtrList abstract_list;
186   for (const auto &node : single_nodes) {
187     auto cnode = node->cast<CNodePtr>();
188     MS_EXCEPTION_IF_NULL(cnode);
189     abstract_list.push_back(cnode->abstract());
190   }
191   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
192   MS_EXCEPTION_IF_NULL(abstract_tuple);
193   fused_cnode->set_abstract(abstract_tuple);
194 
195   auto manager = func_graph->manager();
196   MS_EXCEPTION_IF_NULL(manager);
197   for (const auto &node : single_nodes) {
198     if (!manager->Replace(node, fused_cnode)) {
199       MS_LOG(EXCEPTION) << "manager replace node failed";
200     }
201   }
202   return;
203 }
204 
GenerateKernelBuildInfo(const std::vector<AnfNodePtr> & node_list)205 kernel::KernelBuildInfoPtr Util::GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
206   std::vector<std::string> inputs_device_format;
207   std::vector<std::string> outputs_device_format;
208   std::vector<TypeId> inputs_device_type;
209   std::vector<TypeId> outputs_device_type;
210   std::vector<std::vector<size_t>> outputs_shape;
211   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
212   for (size_t idx = 0; idx < node_list.size(); ++idx) {
213     auto cnode = utils::cast<CNodePtr>(node_list[idx]);
214     MS_EXCEPTION_IF_NULL(cnode);
215     size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
216     for (size_t input_index = 0; input_index < input_num; ++input_index) {
217       (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
218       inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
219     }
220     size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
221     for (size_t output_index = 0; output_index < output_num; ++output_index) {
222       (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
223       outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
224       outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
225     }
226   }
227   builder.SetInputsFormat(inputs_device_format);
228   builder.SetOutputsFormat(outputs_device_format);
229   builder.SetInputsDeviceType(inputs_device_type);
230   builder.SetOutputsDeviceType(outputs_device_type);
231   return builder.Build();
232 }
233 }  // namespace ps
234 }  // namespace mindspore
235