• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/step_parallel_utils.h"
18 
19 #include <inttypes.h>
20 #include <sys/time.h>
21 #include <algorithm>
22 
23 #include <map>
24 #include <set>
25 #include <string>
26 #include <unordered_map>
27 #include <utility>
28 
29 #include "base/core_ops.h"
30 #include "frontend/operator/ops.h"
31 #include "frontend/optimizer/optimizer.h"
32 #include "frontend/parallel/context.h"
33 #include "frontend/parallel/device_manager.h"
34 #include "frontend/parallel/graph_util/generate_graph.h"
35 #include "frontend/parallel/graph_util/graph_info.h"
36 #include "frontend/parallel/graph_util/node_info.h"
37 #include "frontend/parallel/node_check.h"
38 #include "ir/param_info.h"
39 #include "ir/tensor.h"
40 #include "utils/trace_base.h"
41 #include "utils/comm_manager.h"
42 #include "utils/ms_context.h"
43 #include "utils/symbolic.h"
44 #include "mindspore/core/utils/parallel_node_check.h"
45 
46 namespace mindspore {
47 namespace parallel {
IsSomePrimitive(const CNodePtr & cnode,const std::string & name)48 bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
49   if (!cnode) {
50     return false;
51   }
52   ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>();
53   MS_EXCEPTION_IF_NULL(anf_node);
54   PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
55   return (prim->name() == name);
56 }
57 
IsParallelCareNode(const CNodePtr & cnode)58 bool IsParallelCareNode(const CNodePtr &cnode) {
59   MS_EXCEPTION_IF_NULL(cnode);
60   ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
61   if (prim_node == nullptr) {
62     return false;
63   }
64   PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
65   if (prim == nullptr) {
66     return false;
67   }
68   if (IsInParallelBlackList(prim)) {
69     MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
70     return false;
71   }
72   // get_next is not in the forward graph, we need mark the get_next as the forward node
73   if (prim->name() == GET_NEXT || prim->name() == VIRTUAL_OUTPUT) {
74     return true;
75   }
76   if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
77     return false;
78   }
79 
80   return cnode->in_forward_flag();
81 }
82 
GetValueListShape(const AnfNodePtr & node)83 Shapes GetValueListShape(const AnfNodePtr &node) {
84   Shapes shapes;
85   std::vector<ValuePtr> inputs_seq;
86   if (IsValueNode<ValueList>(node)) {
87     inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
88   } else if (IsValueNode<ValueTuple>(node)) {
89     inputs_seq = node->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
90   } else {
91     MS_LOG(EXCEPTION) << "node is eigther ValueList or ValueTuple";
92   }
93   for (auto &ele : inputs_seq) {
94     auto tensor = ele->cast<tensor::TensorPtr>();
95     if (tensor == nullptr) {
96       MS_LOG(WARNING) << "The value node is not a tensor";
97       break;
98     }
99     auto one_shape = tensor->shape();
100     shapes.push_back(one_shape);
101   }
102   return shapes;
103 }
104 
// Return the output shape(s) of `node`. A node with a sequence (tuple) shape
// yields one entry per element; a single-output node yields exactly one entry.
Shapes GetNodeShape(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  Shapes shapes;
  // Constant list/tuple of tensors: read the shapes directly from the values.
  if (IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
    return GetValueListShape(node);
  }
  BaseShapePtr base_shape_ptr = node->Shape();
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    if (IsValueNode<Primitive>(cnode->input(0))) {
      PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      MS_EXCEPTION_IF_NULL(prim);
      // MakeRef wraps another node; resolve the shape through the ref key
      // instead of this node's own abstract.
      if (prim->name() == MAKEREF) {
        AnfNodePtr ref_node = cnode->input(1);
        auto func_graph = cnode->func_graph();
        MS_EXCEPTION_IF_NULL(ref_node);
        MS_EXCEPTION_IF_NULL(func_graph);
        return GetRefKeyNodeShape(ref_node, func_graph);
      }
    }
    // When input 0 is itself a CNode (the op is computed, presumably a
    // closure/partial call — TODO confirm), fall back to the shape of the
    // first data input.
    if (cnode->input(0)->isa<CNode>()) {
      if (cnode->inputs().size() < 2) {
        MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
      }
      base_shape_ptr = cnode->input(1)->Shape();
    }
  }
  if (base_shape_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
                      << node->fullname_with_scope();
  }
  // NOTE: "SequeueShape" is the upstream (misspelled) class name for sequence
  // shapes; do not "fix" it here.
  auto tuple_shape_ptr = dyn_cast<abstract::SequeueShape>(base_shape_ptr);
  if (tuple_shape_ptr != nullptr) {
    // Multi-output node: every element must be a concrete abstract::Shape.
    auto tuple_shape = tuple_shape_ptr->shape();
    for (auto &shape : tuple_shape) {
      auto each_shape = dyn_cast<abstract::Shape>(shape);
      MS_EXCEPTION_IF_NULL(each_shape);
      shapes.push_back(each_shape->shape());
    }
  } else {
    // Single-output node: one concrete shape.
    auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
    MS_EXCEPTION_IF_NULL(shape_ptr);
    shapes.push_back(shape_ptr->shape());
  }
  return shapes;
}
151 
CreateInstanceName(const CNodePtr & node,size_t index)152 std::string CreateInstanceName(const CNodePtr &node, size_t index) {
153   MS_EXCEPTION_IF_NULL(node);
154   if (!IsValueNode<Primitive>(node->input(0))) {
155     MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
156   }
157   std::string name_base = node->fullname_with_scope();
158   std::string name = name_base + "_" + std::to_string(index);
159   std::string instance_name = HashInstanceName(name);
160   return instance_name;
161 }
162 
SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input)163 void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
164   if (new_node_input.empty()) {
165     return;
166   }
167 
168   auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
169   auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
170   MS_EXCEPTION_IF_NULL(prim);
171 
172   auto attrs = prim->attrs();
173   auto iter = attrs.find(GROUP);
174   if (iter != attrs.end()) {
175     auto value = iter->second;
176     MS_EXCEPTION_IF_NULL(value);
177     if (value->isa<StringImm>()) {
178       std::string hash_name = value->cast<StringImmPtr>()->value();
179       MS_EXCEPTION_IF_NULL(g_device_manager);
180       std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
181       (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
182     }
183   }
184 }
185 
// Build the input vector for a CNode that replaces `node` with `replace_op`:
// input 0 is a freshly instantiated primitive value node (named
// `instance_name`), followed by the data inputs the replacement expects and
// any operator params spliced in at their declared positions.
std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
                                       const CNodePtr &node) {
  OperatorArgs arg_replace_op = replace_op.second;
  ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
  if (pyop_instance == nullptr) {
    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
  }
  OperatorParams params = arg_replace_op.second;
  if (node->inputs().size() < 2) {
    // The GetNext operator does not have a data input, so the replacement
    // consists of the primitive alone.
    if (node->inputs().size() == 1) {
      return {NewValueNode(pyop_instance)};
    }
    MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
  }
  std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};

  // EmbeddingLookup keeps its second data input as well.
  if (replace_op.first == EMBEDDING_LOOKUP) {
    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
  }

  if (!params.empty()) {
    // If the first declared param occupies position 1, it supersedes the data
    // input copied above, so drop that input before inserting params.
    Param param_first = *(params.begin());
    int64_t first_position = param_first.second;
    if (first_position == 1) {
      replace_input.pop_back();
    }
    for (auto &param : params) {
      AnfNodePtr val = NewValueNode(param.first.second);
      if (val == nullptr) {
        MS_LOG(EXCEPTION) << "Failure:val is nullptr";
      }
      // param.second is the target index within replace_input; insertion order
      // matters since each insert shifts subsequent positions.
      int64_t position = param.second;
      (void)replace_input.insert(replace_input.begin() + position, val);
    }
  } else if (replace_op.first == SYNC_BATCH_NORM) {
    // SyncBatchNorm forwards all remaining original inputs unchanged.
    for (size_t i = 2; i < node->inputs().size(); ++i) {
      replace_input.push_back(node->input(i));
    }
  }
  // Tag communication ops with their human-readable rank list.
  SetCommunicationOpGroupLabel(replace_input);
  return replace_input;
}
229 }  // namespace parallel
230 }  // namespace mindspore
231