/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/backend/distributed/ps/util.h"
#include <vector>
#include <memory>
#include <string>
#include <map>
#include <algorithm>
#include "mindspore/core/ops/ascend_op_name.h"
#include "mindspore/core/ops/other_op_name.h"
#include "utils/hash_map.h"
#include "include/backend/distributed/ps/constants.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "distributed/persistent/data.h"

namespace mindspore {
namespace ps {
namespace {
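// Lookup tables mapping the supported optimizer names to their internal IDs and back, and mapping IDs to the
// corresponding optimizer kernel node names.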
static mindspore::HashMap<std::string, int64_t> optimizer_to_ids = {
  {kApplyMomentum, 0},
  {kSparseAdam, 1},
  {kSparseLazyAdam, 2},
  {kSparseFtrl, 3},
};

static mindspore::HashMap<int64_t, std::string> id_to_optimizers = {
  {0, kApplyMomentum},
  {1, kSparseAdam},
  {2, kSparseLazyAdam},
  {3, kSparseFtrl},
};

static mindspore::HashMap<int64_t, std::string> id_to_optimizer_nodes = {
  {0, kApplyMomentumOp},
  {1, kSparseAdamOp},
  {2, kSparseLazyAdamOp},
  {3, kSparseFtrlOp},
};
}  // namespace

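// Role queries: whether this process runs as a parameter server or as the scheduler, as recorded in PSContext.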
bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }

bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }

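// Returns the ID registered for the given optimizer name, or -1 if the optimizer is not supported.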
int64_t Util::optimizer_id(const std::string &name) {
  if (optimizer_to_ids.count(name) > 0) {
    return optimizer_to_ids[name];
  }
  return -1;
}

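// Returns the optimizer name registered for the given ID, or an empty string if the ID is unknown.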
std::string Util::optimizer_name(int64_t id) {
  if (id_to_optimizers.count(id) > 0) {
    return id_to_optimizers[id];
  }
  return "";
}

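// Returns the optimizer kernel node name registered for the given ID, or an empty string if the ID is unknown.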
std::string Util::optimizer_node_name(int64_t id) {
  if (id_to_optimizer_nodes.count(id) > 0) {
    return id_to_optimizer_nodes[id];
  }
  return "";
}

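// Returns true if the given name is one of the supported optimizers.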
bool Util::is_optimizer(const std::string &name) { return optimizer_to_ids.count(name) > 0; }

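// Returns the share of 'first_dim' assigned to this server (rank_id) after sharding across 'server_num' servers.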
int64_t Util::LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  std::map<int64_t, int64_t> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
  if (shard_dims.count(rank_id) == 0) {
    MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
  }
  return shard_dims[rank_id];
}

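// Distributes 'first_dim' rows across all servers in a round-robin manner and returns the map from each server rank
// to the number of rows it owns. For example, first_dim = 10 and server_num = 3 yields {0: 4, 1: 3, 2: 3}.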
std::map<int64_t, int64_t> Util::AllRankLocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  if (first_dim <= 0 || server_num <= 0 || rank_id < 0) {
    MS_LOG(EXCEPTION) << "Input values are invalid, first_dim: " << first_dim << ", server_num: " << server_num
                      << ", rank_id: " << rank_id;
  }
  if (rank_id >= server_num) {
    MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
  }
  std::map<int64_t, int64_t> shard_dims;
  for (int64_t i = 0; i < server_num; i++) {
    shard_dims[i] = 0;
  }
  if (server_num != static_cast<int64_t>(shard_dims.size())) {
    MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size();
  }
  int64_t server_index = -1;
  for (int64_t i = 0; i < first_dim; i++) {
    server_index = (server_index + 1) % server_num;
    shard_dims[server_index] = shard_dims[server_index] + 1;
  }
  if (shard_dims.count(rank_id) == 0) {
    MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num;
  }
  return shard_dims;
}

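// Fuses all PullWeight ops in the graph into a single FusedPullWeight op, and all PushWeight ops into a single
// FusedPushWeight op.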
bool Util::FuseServerCommOps(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  DoFusion(func_graph, kPullWeightOpName, kFusedPullWeightOpName);
  DoFusion(func_graph, kPushWeightOpName, kFusedPushWeightOpName);
  return true;
}

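// Creates a PersistentWeight from the given data and shape when recovery is enabled, otherwise a plain Weight.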
WeightPtr Util::MakeWeightPtr(const std::shared_ptr<std::vector<float>> &data, bool enable_recovery,
                              const std::shared_ptr<std::vector<int>> &shape) {
  WeightPtr weight_ptr;
  if (!enable_recovery) {
    weight_ptr = std::make_shared<Weight>(data, shape);
  } else {
    weight_ptr = std::make_shared<PersistentWeight>(data, shape);
  }
  return weight_ptr;
}

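// Returns the primitive name of the given CNode, or an empty string if its first input is not a Primitive value node.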
std::string Util::GetPrimitiveName(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto &inputs = cnode->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "Inputs of node " << cnode->fullname_with_scope() << " are empty.";
    return "";
  }
  auto fn = inputs[0];
  if (!IsValueNode<Primitive>(fn)) {
    return "";
  }

  auto node_prim = GetValueNode<PrimitivePtr>(fn);
  MS_EXCEPTION_IF_NULL(node_prim);
  return node_prim->name();
}

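// Collects every CNode named 'cnode_name', records the weight name and index carried by its inputs, then builds a
// single fused CNode named 'fused_cnode_name' (with those names and indices as attributes and the CPU as its
// primitive target) and replaces each of the original nodes with the fused node.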
void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
                    const std::string &fused_cnode_name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());

  std::vector<AnfNodePtr> single_nodes;
  std::vector<std::string> weight_names;
  std::vector<int64_t> indices;
  for (const AnfNodePtr &node : node_list) {
    if (node != nullptr && node->isa<CNode>()) {
      if (GetPrimitiveName(node->cast<CNodePtr>()) == cnode_name) {
        single_nodes.push_back(node);

        auto weight_name_value_node =
          common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightNameOffset)->cast<ValueNodePtr>();
        const std::string &weight_name = GetValue<std::string>(weight_name_value_node->value());
        weight_names.push_back(weight_name);

        auto weight_index_value_node =
          common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightIndexOffset)->cast<ValueNodePtr>();
        int64_t weight_index = GetValue<int64_t>(weight_index_value_node->value());
        indices.push_back(weight_index);
      }
    }
  }

  auto prim = std::make_shared<Primitive>(fused_cnode_name);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> fused_node_inputs = {};
  fused_node_inputs.push_back(NewValueNode(prim));
  (void)std::for_each(single_nodes.begin(), single_nodes.end(), [&](const AnfNodePtr &node) {
    fused_node_inputs.push_back(common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0));
  });

  auto fused_cnode = func_graph->NewCNode(fused_node_inputs);
  MS_EXCEPTION_IF_NULL(fused_cnode);
  common::AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(weight_names), fused_cnode);
  common::AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(indices), fused_cnode);
  common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), fused_cnode);

  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  fused_cnode->set_kernel_info(kernel_info);
  auto kernel_build_info = GenerateKernelBuildInfo(single_nodes);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_cnode.get());

  AbstractBasePtrList abstract_list;
  for (const auto &node : single_nodes) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    abstract_list.push_back(cnode->abstract());
  }
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  fused_cnode->set_abstract(abstract_tuple);

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &node : single_nodes) {
    if (!manager->Replace(node, fused_cnode)) {
      MS_LOG(EXCEPTION) << "manager replace node failed";
    }
  }
  return;
}

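// Builds one KernelBuildInfo for the fused node by concatenating the default formats and inferred data types of all
// inputs and outputs of the nodes being fused.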
kernel::KernelBuildInfoPtr Util::GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
  std::vector<std::string> inputs_device_format;
  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  std::vector<ShapeVector> outputs_shape;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  for (size_t idx = 0; idx < node_list.size(); ++idx) {
    auto cnode = utils::cast<CNodePtr>(node_list[idx]);
    MS_EXCEPTION_IF_NULL(cnode);
    size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
      inputs_device_type.push_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
      outputs_device_type.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, output_index));
      outputs_shape.push_back(common::AnfAlgo::GetOutputInferShape(cnode, output_index));
    }
  }
  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}
}  // namespace ps
}  // namespace mindspore