1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/util.h"
18 #include <unordered_map>
19 #include <vector>
20 #include <memory>
21 #include "ps/constants.h"
22 #include "ps/ps_context.h"
23 #include "utils/ms_utils.h"
24
25 namespace mindspore {
26 namespace ps {
27 std::unordered_map<std::string, int64_t> Util::optimizer_to_ids{
28 {kApplyMomentum, 0},
29 {kSparseAdam, 1},
30 {kSparseLazyAdam, 2},
31 {kSparseFtrl, 3},
32 };
33
34 std::unordered_map<int64_t, std::string> Util::id_to_optimizers{
35 {0, kApplyMomentum},
36 {1, kSparseAdam},
37 {2, kSparseLazyAdam},
38 {3, kSparseFtrl},
39 };
40
41 std::unordered_map<int64_t, std::string> Util::id_to_optimizer_nodes{
42 {0, kApplyMomentumOp},
43 {1, kSparseAdamOp},
44 {2, kSparseLazyAdamOp},
45 {3, kSparseFtrlOp},
46 };
47
IsRoleOfPServer()48 bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }
49
IsRoleOfScheduler()50 bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }
51
optimizer_id(const std::string & name)52 int64_t Util::optimizer_id(const std::string &name) {
53 if (optimizer_to_ids.count(name) > 0) {
54 return optimizer_to_ids[name];
55 }
56 return -1;
57 }
58
optimizer_name(int64_t id)59 std::string Util::optimizer_name(int64_t id) {
60 if (id_to_optimizers.count(id) > 0) {
61 return id_to_optimizers[id];
62 }
63 return "";
64 }
65
optimizer_node_name(int64_t id)66 std::string Util::optimizer_node_name(int64_t id) {
67 if (id_to_optimizer_nodes.count(id) > 0) {
68 return id_to_optimizer_nodes[id];
69 }
70 return "";
71 }
72
is_optimizer(const std::string & name)73 bool Util::is_optimizer(const std::string &name) { return optimizer_to_ids.count(name) > 0; }
74
LocalShard(int64_t first_dim,int64_t rank_id,int64_t server_num)75 int64_t Util::LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
76 std::map<int64_t, int64_t> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
77 if (shard_dims.count(rank_id) == 0) {
78 MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
79 }
80 return shard_dims[rank_id];
81 }
82
AllRankLocalShard(int64_t first_dim,int64_t rank_id,int64_t server_num)83 std::map<int64_t, int64_t> Util::AllRankLocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
84 if (first_dim <= 0 || server_num <= 0 || rank_id < 0) {
85 MS_LOG(EXCEPTION) << "Input values are invalid.";
86 }
87 if (rank_id >= server_num) {
88 MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
89 }
90 std::map<int64_t, int64_t> shard_dims;
91 for (int64_t i = 0; i < server_num; i++) {
92 shard_dims[i] = 0;
93 }
94 if (server_num != static_cast<int64_t>(shard_dims.size())) {
95 MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size();
96 }
97 int64_t server_index = -1;
98 for (int64_t i = 0; i < first_dim; i++) {
99 server_index = (server_index + 1) % server_num;
100 shard_dims[server_index] = shard_dims[server_index] + 1;
101 }
102 if (shard_dims.count(rank_id) == 0) {
103 MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num;
104 }
105 return shard_dims;
106 }
107
ReduceSparseGradient(float * gradients,int * indices,const size_t indices_size,size_t segment_size,const size_t first_dim_size,const size_t outer_dim_size,mindspore::kernel::SparseGradient<int> * unique_sparse_grad)108 void Util::ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
109 const size_t first_dim_size, const size_t outer_dim_size,
110 mindspore::kernel::SparseGradient<int> *unique_sparse_grad) {
111 size_t slice_segment_size = indices_size * segment_size;
112 std::vector<float> workspace_grad(slice_segment_size);
113 std::vector<int> workspace_indices(indices_size);
114
115 MS_EXCEPTION_IF_NULL(gradients);
116 MS_EXCEPTION_IF_NULL(indices);
117
118 mindspore::kernel::SparseGradient<int> workspace_sparse_grad(
119 {workspace_grad.data(), workspace_indices.data(), indices_size});
120 mindspore::kernel::SparseGradient<int> input_sparse_grad({gradients, indices, indices_size});
121 mindspore::kernel::ReduceSparseGradientParam<int> param;
122 param.input_grad_ = &input_sparse_grad;
123 param.workspace_grad_ = &workspace_sparse_grad;
124 param.output_grad_ = unique_sparse_grad;
125 param.max_index_ = first_dim_size;
126 param.value_stride_ = outer_dim_size;
127
128 mindspore::kernel::SparseOptimizerCPUKernel::BucketReduceSparseGradient(param);
129 }
130
FuseServerCommOps(const pipeline::ResourcePtr & res)131 bool Util::FuseServerCommOps(const pipeline::ResourcePtr &res) {
132 FuncGraphPtr func_graph = res->func_graph();
133 MS_EXCEPTION_IF_NULL(func_graph);
134 DoFusion(func_graph, kPullWeightOpName, kFusedPullWeightOpName);
135 DoFusion(func_graph, kPushWeightOpName, kFusedPushWeightOpName);
136 return true;
137 }
138
DoFusion(const FuncGraphPtr & func_graph,const std::string & cnode_name,const std::string & fused_cnode_name)139 void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
140 const std::string &fused_cnode_name) {
141 MS_EXCEPTION_IF_NULL(func_graph);
142 std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
143
144 std::vector<AnfNodePtr> single_nodes;
145 std::vector<std::string> weight_names;
146 std::vector<int64_t> indices;
147 for (const AnfNodePtr &node : node_list) {
148 if (node != nullptr && node->isa<CNode>()) {
149 if (AnfAlgo::GetCNodeName(node) == cnode_name) {
150 single_nodes.push_back(node);
151
152 auto weight_name_value_node =
153 AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightNameOffset)->cast<ValueNodePtr>();
154 const std::string &weight_name = GetValue<std::string>(weight_name_value_node->value());
155 weight_names.push_back(weight_name);
156
157 auto weight_index_value_node =
158 AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightIndexOffset)->cast<ValueNodePtr>();
159 int64_t weight_index = GetValue<int64_t>(weight_index_value_node->value());
160 indices.push_back(weight_index);
161 }
162 }
163 }
164
165 auto prim = std::make_shared<Primitive>(fused_cnode_name);
166 MS_EXCEPTION_IF_NULL(prim);
167 std::vector<AnfNodePtr> fused_node_inputs = {};
168 fused_node_inputs.push_back(NewValueNode(prim));
169 (void)std::for_each(single_nodes.begin(), single_nodes.end(), [&](const AnfNodePtr &node) {
170 fused_node_inputs.push_back(AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0));
171 });
172
173 auto fused_cnode = func_graph->NewCNode(fused_node_inputs);
174 MS_EXCEPTION_IF_NULL(fused_cnode);
175 AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(weight_names), fused_cnode);
176 AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(indices), fused_cnode);
177 AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), fused_cnode);
178
179 auto kernel_info = std::make_shared<device::KernelInfo>();
180 MS_EXCEPTION_IF_NULL(kernel_info);
181 fused_cnode->set_kernel_info(kernel_info);
182 auto kernel_build_info = GenerateKernelBuildInfo(single_nodes);
183 AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_cnode.get());
184
185 AbstractBasePtrList abstract_list;
186 for (const auto &node : single_nodes) {
187 auto cnode = node->cast<CNodePtr>();
188 MS_EXCEPTION_IF_NULL(cnode);
189 abstract_list.push_back(cnode->abstract());
190 }
191 auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
192 MS_EXCEPTION_IF_NULL(abstract_tuple);
193 fused_cnode->set_abstract(abstract_tuple);
194
195 auto manager = func_graph->manager();
196 MS_EXCEPTION_IF_NULL(manager);
197 for (const auto &node : single_nodes) {
198 if (!manager->Replace(node, fused_cnode)) {
199 MS_LOG(EXCEPTION) << "manager replace node failed";
200 }
201 }
202 return;
203 }
204
GenerateKernelBuildInfo(const std::vector<AnfNodePtr> & node_list)205 kernel::KernelBuildInfoPtr Util::GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
206 std::vector<std::string> inputs_device_format;
207 std::vector<std::string> outputs_device_format;
208 std::vector<TypeId> inputs_device_type;
209 std::vector<TypeId> outputs_device_type;
210 std::vector<std::vector<size_t>> outputs_shape;
211 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
212 for (size_t idx = 0; idx < node_list.size(); ++idx) {
213 auto cnode = utils::cast<CNodePtr>(node_list[idx]);
214 MS_EXCEPTION_IF_NULL(cnode);
215 size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
216 for (size_t input_index = 0; input_index < input_num; ++input_index) {
217 (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
218 inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
219 }
220 size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
221 for (size_t output_index = 0; output_index < output_num; ++output_index) {
222 (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
223 outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
224 outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
225 }
226 }
227 builder.SetInputsFormat(inputs_device_format);
228 builder.SetOutputsFormat(outputs_device_format);
229 builder.SetInputsDeviceType(inputs_device_type);
230 builder.SetOutputsDeviceType(outputs_device_type);
231 return builder.Build();
232 }
233 } // namespace ps
234 } // namespace mindspore
235