/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/backend/distributed/ps/util.h"
#include <vector>
#include <memory>
#include <map>
#include <string>
#include <algorithm>
#include "mindspore/core/ops/ascend_op_name.h"
#include "mindspore/core/ops/other_op_name.h"
#include "utils/hash_map.h"
#include "include/backend/distributed/ps/constants.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "distributed/persistent/data.h"

namespace mindspore {
namespace ps {
namespace {
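// Tables mapping optimizer names to numeric IDs and back, plus the kernel node
// name used for each optimizer. All three must be kept in sync.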
static mindspore::HashMap<std::string, int64_t> optimizer_to_ids = {
  {kApplyMomentum, 0},
  {kSparseAdam, 1},
  {kSparseLazyAdam, 2},
  {kSparseFtrl, 3},
};

static mindspore::HashMap<int64_t, std::string> id_to_optimizers = {
  {0, kApplyMomentum},
  {1, kSparseAdam},
  {2, kSparseLazyAdam},
  {3, kSparseFtrl},
};

static mindspore::HashMap<int64_t, std::string> id_to_optimizer_nodes = {
  {0, kApplyMomentumOp},
  {1, kSparseAdamOp},
  {2, kSparseLazyAdamOp},
  {3, kSparseFtrlOp},
};
}  // namespace

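// Role checks are delegated to the global PSContext singleton.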
bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); }

bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); }

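// Returns the ID registered for an optimizer name, or -1 if the name is unknown.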
int64_t Util::optimizer_id(const std::string &name) {
  if (optimizer_to_ids.count(name) > 0) {
    return optimizer_to_ids[name];
  }
  return -1;
}

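// Returns the optimizer name for an ID, or an empty string if the ID is unknown.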
std::string Util::optimizer_name(int64_t id) {
  if (id_to_optimizers.count(id) > 0) {
    return id_to_optimizers[id];
  }
  return "";
}

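// Returns the optimizer kernel node name for an ID, or an empty string if the
// ID is unknown.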
std::string Util::optimizer_node_name(int64_t id) {
  if (id_to_optimizer_nodes.count(id) > 0) {
    return id_to_optimizer_nodes[id];
  }
  return "";
}

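// True if `name` is one of the optimizers known to the parameter server.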
bool Util::is_optimizer(const std::string &name) { return optimizer_to_ids.count(name) > 0; }

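// Returns the shard of `first_dim` assigned to `rank_id`; a convenience
// wrapper around AllRankLocalShard that extracts a single rank's entry.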
int64_t Util::LocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  std::map<int64_t, int64_t> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
  if (shard_dims.count(rank_id) == 0) {
    MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
  }
  return shard_dims[rank_id];
}

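// Splits `first_dim` rows across `server_num` servers round-robin, so earlier
// ranks receive at most one extra row, and returns the rank -> shard size map.
// For example, first_dim = 10 and server_num = 3 yields {0: 4, 1: 3, 2: 3}.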
std::map<int64_t, int64_t> Util::AllRankLocalShard(int64_t first_dim, int64_t rank_id, int64_t server_num) {
  if (first_dim <= 0 || server_num <= 0 || rank_id < 0) {
    MS_LOG(EXCEPTION) << "Input values are invalid, first_dim: " << first_dim << ", server_num: " << server_num
                      << ", rank_id: " << rank_id;
  }
  if (rank_id >= server_num) {
    MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
  }
  std::map<int64_t, int64_t> shard_dims;
  for (int64_t i = 0; i < server_num; i++) {
    shard_dims[i] = 0;
  }
  if (server_num != static_cast<int64_t>(shard_dims.size())) {
    MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size();
  }
  int64_t server_index = -1;
  for (int64_t i = 0; i < first_dim; i++) {
    server_index = (server_index + 1) % server_num;
    shard_dims[server_index] = shard_dims[server_index] + 1;
  }
  if (shard_dims.count(rank_id) == 0) {
    MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num;
  }
  return shard_dims;
}

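// Fuses all PullWeight nodes into one FusedPullWeight node and all PushWeight
// nodes into one FusedPushWeight node, so each direction of weight traffic
// goes through a single server communication op.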
bool Util::FuseServerCommOps(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  DoFusion(func_graph, kPullWeightOpName, kFusedPullWeightOpName);
  DoFusion(func_graph, kPushWeightOpName, kFusedPushWeightOpName);
  return true;
}

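// Creates the weight holder: a PersistentWeight when recovery is enabled,
// otherwise a plain Weight.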
WeightPtr Util::MakeWeightPtr(const std::shared_ptr<std::vector<float>> &data, bool enable_recovery,
                              const std::shared_ptr<std::vector<int>> &shape) {
  WeightPtr weight_ptr;
  if (!enable_recovery) {
    weight_ptr = std::make_shared<Weight>(data, shape);
  } else {
    weight_ptr = std::make_shared<PersistentWeight>(data, shape);
  }
  return weight_ptr;
}

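// Returns the name of the Primitive a CNode calls, or an empty string when its
// first input is not a Primitive value node.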
std::string Util::GetPrimitiveName(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto &inputs = cnode->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "Inputs of node " << cnode->fullname_with_scope() << " are empty.";
  }
  auto fn = inputs[0];
  if (!IsValueNode<Primitive>(fn)) {
    return "";
  }

  auto node_prim = GetValueNode<PrimitivePtr>(fn);
  MS_EXCEPTION_IF_NULL(node_prim);
  return node_prim->name();
}

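// Finds every CNode whose primitive is `cnode_name`, records the weight name
// and index each one carries as constant inputs, builds a single
// `fused_cnode_name` node over all of their first inputs, and replaces the
// originals with it.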
void Util::DoFusion(const FuncGraphPtr &func_graph, const std::string &cnode_name,
                    const std::string &fused_cnode_name) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());

  // Collect every matching node, along with the weight name and index it
  // carries as constant inputs.
  std::vector<AnfNodePtr> single_nodes;
  std::vector<std::string> weight_names;
  std::vector<int64_t> indices;
  for (const AnfNodePtr &node : node_list) {
    if (node != nullptr && node->isa<CNode>()) {
      if (GetPrimitiveName(node->cast<CNodePtr>()) == cnode_name) {
        single_nodes.push_back(node);

        auto weight_name_value_node =
          common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightNameOffset)->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(weight_name_value_node);
        const std::string &weight_name = GetValue<std::string>(weight_name_value_node->value());
        weight_names.push_back(weight_name);

        auto weight_index_value_node =
          common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), kNodeInputWeightIndexOffset)->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(weight_index_value_node);
        int64_t weight_index = GetValue<int64_t>(weight_index_value_node->value());
        indices.push_back(weight_index);
      }
    }
  }

  // The fused node's inputs are the fused primitive followed by the first
  // input of every collected node.
  auto prim = std::make_shared<Primitive>(fused_cnode_name);
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<AnfNodePtr> fused_node_inputs = {};
  fused_node_inputs.push_back(NewValueNode(prim));
  (void)std::for_each(single_nodes.begin(), single_nodes.end(), [&](const AnfNodePtr &node) {
    fused_node_inputs.push_back(common::AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0));
  });

  auto fused_cnode = func_graph->NewCNode(fused_node_inputs);
  MS_EXCEPTION_IF_NULL(fused_cnode);
  common::AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(weight_names), fused_cnode);
  common::AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(indices), fused_cnode);
  common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(kCPUDevice), fused_cnode);

  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  fused_cnode->set_kernel_info(kernel_info);
  auto kernel_build_info = GenerateKernelBuildInfo(single_nodes);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_cnode.get());

  // The fused node's abstract is the tuple of the original nodes' abstracts.
  AbstractBasePtrList abstract_list;
  for (const auto &node : single_nodes) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    abstract_list.push_back(cnode->abstract());
  }
  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
  MS_EXCEPTION_IF_NULL(abstract_tuple);
  fused_cnode->set_abstract(abstract_tuple);

  // Redirect all uses of the original nodes to the fused node.
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &node : single_nodes) {
    if (!manager->Replace(node, fused_cnode)) {
      MS_LOG(EXCEPTION) << "manager replace node failed";
    }
  }
}

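// Concatenates the input/output device formats and inferred data types of all
// nodes in `node_list` into one kernel build info for the fused node.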
kernel::KernelBuildInfoPtr Util::GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
  std::vector<std::string> inputs_device_format;
  std::vector<std::string> outputs_device_format;
  std::vector<TypeId> inputs_device_type;
  std::vector<TypeId> outputs_device_type;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  for (size_t idx = 0; idx < node_list.size(); ++idx) {
    auto cnode = utils::cast<CNodePtr>(node_list[idx]);
    MS_EXCEPTION_IF_NULL(cnode);
    size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      (void)inputs_device_format.emplace_back(kOpFormat_DEFAULT);
      inputs_device_type.push_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      (void)outputs_device_format.emplace_back(kOpFormat_DEFAULT);
      outputs_device_type.push_back(common::AnfAlgo::GetOutputInferDataType(cnode, output_index));
    }
  }
  builder.SetInputsFormat(inputs_device_format);
  builder.SetOutputsFormat(outputs_device_format);
  builder.SetInputsDeviceType(inputs_device_type);
  builder.SetOutputsDeviceType(outputs_device_type);
  return builder.Build();
}
}  // namespace ps
}  // namespace mindspore