1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/graph_util/node_info.h"
18
19 #include <string>
20
21 #include "base/core_ops.h"
22 #include "ir/param_info.h"
23 #include "ir/meta_tensor.h"
24 #include "pipeline/jit/parse/python_adapter.h"
25 #include "frontend/parallel/ops_info/ops_utils.h"
26 #include "frontend/parallel/step_parallel.h"
27 #include "frontend/parallel/step_parallel_utils.h"
28
29 namespace mindspore {
30 namespace parallel {
31 const std::vector<std::string> filter_attrs = {RECOMPUTE, TARGET};
ParameterName(const AnfNodePtr & node_ptr)32 std::string ParameterName(const AnfNodePtr &node_ptr) {
33 auto para_ptr = node_ptr->cast<ParameterPtr>();
34 MS_EXCEPTION_IF_NULL(para_ptr);
35 return para_ptr->name();
36 }
37
ParameterRequireGrad(const AnfNodePtr & node_ptr)38 bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
39 auto para_ptr = node_ptr->cast<ParameterPtr>();
40 if (para_ptr == nullptr) {
41 return false;
42 }
43 if (!para_ptr->has_default()) {
44 return false;
45 }
46 auto param_value = para_ptr->param_info();
47 if (param_value == nullptr) {
48 return false;
49 }
50 return param_value->requires_grad();
51 }
52
// Look through a Load node: for a Load CNode return its input 1, otherwise return `input` unchanged.
AnfNodePtr GetRealInput(const AnfNodePtr &input) {
  if (!IsPrimitiveCNode(input, prim::kPrimLoad)) {
    return input;
  }
  const auto load_cnode = input->cast<CNodePtr>();
  return load_cnode->input(1);
}
59
60 // Given the node, return whether each input is a parameter or a output of a operator.
61 // The returned boolean vector should be the same order of the inputs, thus its implementation
62 // is closely consistent with ExtractShape() in step_parallel.cc
ExtractInputParameterByNode(const CNodePtr & node)63 std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
64 std::vector<bool> is_parameter;
65 std::vector<AnfNodePtr> node_inputs{node->inputs()};
66 // input is a ValueList or ValueTuple, then all inputs are not parameter.
67 if ((node_inputs.size() == 2) &&
68 (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
69 std::vector<ValuePtr> inputs_seq;
70 if (IsValueNode<ValueList>(node_inputs[1])) {
71 inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
72 } else {
73 inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
74 }
75 return std::vector<bool>(inputs_seq.size(), false);
76 }
77 if ((node_inputs.size() == 2) &&
78 (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
79 node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
80 }
81 for (size_t i = 1; i < node_inputs.size(); ++i) {
82 auto input = GetRealInput(node_inputs[i]);
83 if (HasAbstractMonad(input)) {
84 continue;
85 }
86 if (input->isa<Parameter>()) {
87 auto input_parameter = input->cast<ParameterPtr>();
88 is_parameter.push_back(ParameterRequireGrad(input_parameter));
89 } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
90 is_parameter.push_back(false);
91 }
92 }
93 return is_parameter;
94 }
95
96 // Given the type, return the number of bytes to represent this type
GetLengthOfDataType(const TypePtr & type)97 size_t GetLengthOfDataType(const TypePtr &type) {
98 switch (type->type_id()) {
99 case kNumberTypeBool:
100 return sizeof(bool);
101 case kNumberTypeInt8:
102 return sizeof(int8_t);
103 case kNumberTypeInt16:
104 return sizeof(int16_t);
105 case kNumberTypeInt32:
106 return sizeof(int32_t);
107 case kNumberTypeInt64:
108 return sizeof(int64_t);
109 case kNumberTypeUInt8:
110 return sizeof(uint8_t);
111 case kNumberTypeUInt16:
112 return sizeof(uint16_t);
113 case kNumberTypeUInt32:
114 return sizeof(uint32_t);
115 case kNumberTypeUInt64:
116 return sizeof(uint64_t);
117 case kNumberTypeFloat16:
118 return sizeof(float) / 2;
119 case kNumberTypeFloat32:
120 return sizeof(float);
121 case kNumberTypeFloat64:
122 return sizeof(double);
123 case kNumberTypeInt:
124 return sizeof(int64_t);
125 case kNumberTypeUInt:
126 return sizeof(unsigned);
127 case kNumberTypeFloat:
128 return sizeof(float);
129 default:
130 MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
131 }
132 }
133
GetInputsTypeLen(const AnfNodePtr & input)134 size_t GetInputsTypeLen(const AnfNodePtr &input) {
135 MS_EXCEPTION_IF_NULL(input);
136 if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
137 MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
138 }
139
140 size_t input_type_len = 0;
141 auto type = input->Type();
142 MS_EXCEPTION_IF_NULL(type);
143 if (type->isa<mindspore::TensorType>()) {
144 auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
145 input_type_len = GetLengthOfDataType(input_element_type);
146 } else {
147 MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
148 }
149 return input_type_len;
150 }
151
ExtractInputTypeLengthByNode(const CNodePtr & node)152 std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
153 MS_EXCEPTION_IF_NULL(node);
154 std::vector<size_t> inputs_type_len;
155 std::vector<AnfNodePtr> node_inputs{node->inputs()};
156
157 if ((node_inputs.size() == 2) &&
158 (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
159 std::vector<ValuePtr> inputs_seq;
160 if (IsValueNode<ValueList>(node_inputs[1])) {
161 inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
162 } else {
163 inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
164 }
165 for (auto &ele : inputs_seq) {
166 auto tensor = ele->cast<tensor::TensorPtr>();
167 MS_EXCEPTION_IF_NULL(tensor);
168 inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
169 }
170 return inputs_type_len;
171 }
172
173 if ((node_inputs.size() == 2) &&
174 (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
175 node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
176 }
177
178 // extract input element length
179 for (auto &input : node_inputs) {
180 if (HasAbstractMonad(input)) {
181 continue;
182 }
183 if (IsValueNode<RefKey>(input)) {
184 auto func_graph = node->func_graph();
185 MS_EXCEPTION_IF_NULL(func_graph);
186 std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
187 if (parameters.size() != 1) {
188 MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
189 }
190 inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
191 } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
192 // extract input shape from parameter and apply node
193 inputs_type_len.push_back(GetInputsTypeLen(input));
194 }
195 }
196 return inputs_type_len;
197 }
198
ExtractOutputTypeByNode(const CNodePtr & node)199 std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
200 MS_EXCEPTION_IF_NULL(node);
201 std::vector<TypePtr> outputs_type;
202 // extract output element type
203 auto primary_output_type = node->Type();
204 MS_EXCEPTION_IF_NULL(primary_output_type);
205 if (primary_output_type->isa<mindspore::Tuple>()) {
206 // in this case, the output is a tuple
207 auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
208 auto elements = tuple_output_type->elements();
209 for (auto &ele : elements) {
210 if (ele->isa<mindspore::TensorType>()) {
211 auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
212 outputs_type.push_back(ele_element_type);
213 } else {
214 MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
215 }
216 }
217 } else {
218 // in this case, the output is a single tensor
219 if (primary_output_type->isa<mindspore::TensorType>()) {
220 auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
221 outputs_type.push_back(element_type);
222 } else {
223 MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
224 }
225 }
226 return outputs_type;
227 }
228
FindParameterByRefKeyNode(const AnfNodePtr & node,const FuncGraphPtr & func_graph)229 std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
230 MS_EXCEPTION_IF_NULL(node);
231 MS_EXCEPTION_IF_NULL(func_graph);
232 std::vector<AnfNodePtr> parameters;
233 if (!IsValueNode<RefKey>(node)) {
234 MS_LOG(ERROR) << "The node is not a ref key";
235 return parameters;
236 }
237
238 auto ref_key = GetValueNode<RefKeyPtr>(node);
239 MS_EXCEPTION_IF_NULL(ref_key);
240 auto name = ref_key->tag();
241
242 auto manager = func_graph->manager();
243 MS_EXCEPTION_IF_NULL(manager);
244 auto roots = manager->roots();
245 if (roots.size() != 1) {
246 MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1";
247 return parameters;
248 }
249
250 FuncGraphPtr root_g = roots.back();
251 MS_EXCEPTION_IF_NULL(root_g);
252 for (auto ¶m_node : root_g->parameters()) {
253 auto param = param_node->cast<ParameterPtr>();
254 if (param && (name == param->name())) {
255 parameters.push_back(param_node);
256 MS_LOG(INFO) << "The name of ref key is: " << name;
257 return parameters;
258 }
259 }
260
261 MS_LOG(ERROR) << "The name of ref key is: " << name << ", but have not found the parameter";
262 return parameters;
263 }
264
AnfNodeIsPrimitive(const AnfNodePtr & anf_node,const std::string & prim_name)265 bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
266 MS_EXCEPTION_IF_NULL(anf_node);
267 auto cnode = anf_node->cast<CNodePtr>();
268 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
269 return false;
270 }
271
272 auto value_node = cnode->input(0)->cast<ValueNodePtr>();
273 auto prim = GetValueNode<PrimitivePtr>(value_node);
274 MS_EXCEPTION_IF_NULL(prim);
275 if (prim->name() == prim_name) {
276 return true;
277 }
278 return false;
279 }
280
FindReshape(const CNodePtr & cnode,std::unordered_set<std::string> * op_cache)281 bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
282 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
283 return false;
284 }
285 if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
286 return false;
287 }
288 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
289 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
290 MS_EXCEPTION_IF_NULL(prim);
291 if (prim->name() == RESHAPE) {
292 auto operator_info = cnode->user_data<OperatorInfo>();
293 std::string op_info_name = operator_info->name();
294 if (op_cache->find(op_info_name) != op_cache->end()) {
295 return false;
296 }
297 op_cache->insert(op_info_name);
298 return true;
299 }
300 return false;
301 }
302
303 // Find previous node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
FindReshapePreNodeStraCosts(const AnfNodePtr & node,OperatorInfoPtr * pre_operator_info,int64_t * out_index,size_t curr_depth)304 bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index,
305 size_t curr_depth) {
306 if (curr_depth > MAX_RECURSIVE_DEPTH) {
307 MS_LOG(WARNING) << "When finding Reshape's previous node, exceeded the max recursive depth: "
308 << MAX_RECURSIVE_DEPTH;
309 return false;
310 }
311 // if previous node is a parameter, handle it in the outsize.
312 if (node->isa<Parameter>()) {
313 return false;
314 }
315 if (!node->isa<CNode>()) {
316 return false;
317 }
318 CNodePtr cnode = node->cast<CNodePtr>();
319 if (!IsValueNode<Primitive>(cnode->input(0))) {
320 return false;
321 }
322 auto node_op_info = cnode->user_data<OperatorInfo>();
323 if (IsParallelCareNode(cnode) && (node_op_info != nullptr) && !IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
324 *pre_operator_info = node_op_info;
325 *out_index = 0;
326 return true;
327 }
328 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
329 PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
330 if (prim->name() == prim::kTupleGetItem) {
331 *out_index = GetTupleGetItemIndex(cnode);
332 // find tuple_get_item's previous node
333 auto pre_node = cnode->input(1);
334 if (!pre_node->isa<CNode>()) {
335 MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
336 }
337 CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
338 auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
339 if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
340 *pre_operator_info = pre_op_info;
341 return true;
342 }
343 return false;
344 }
345 for (size_t index = 0; index < cnode->inputs().size(); ++index) {
346 if (prim->name() == DEPEND && index != 1) {
347 continue;
348 }
349 if (!FindReshapePreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index, ++curr_depth)) {
350 continue;
351 }
352 return true;
353 }
354 MS_LOG(WARNING)
355 << "FindReshapePreNodeStraCosts failed, if reshape is not the first primitive, there must be some error";
356 return false;
357 }
358
359 // Find next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
360 // if reshape's output connect to several primitive, return the first layout found
FindReshapeNextNodeStraCosts(const CNodePtr & cnode,OperatorInfoPtr * next_operator_info,int64_t * in_index,bool * is_next_reshape,size_t curr_depth)361 bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index,
362 bool *is_next_reshape, size_t curr_depth) {
363 if (curr_depth > MAX_RECURSIVE_DEPTH) {
364 MS_LOG(WARNING) << "When finding Reshape's next node, exceeded the max recursive depth: " << MAX_RECURSIVE_DEPTH;
365 return false;
366 }
367 MS_EXCEPTION_IF_NULL(cnode);
368 MS_EXCEPTION_IF_NULL(cnode->func_graph());
369 FuncGraphManagerPtr manager = cnode->func_graph()->manager();
370 MS_EXCEPTION_IF_NULL(manager);
371 AnfNodeIndexSet node_set = manager->node_users()[cnode];
372 for (auto &node_pair : node_set) {
373 CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
374 if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
375 continue;
376 }
377 if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
378 *is_next_reshape = true;
379 continue;
380 }
381 ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
382 MS_EXCEPTION_IF_NULL(prim_anf_node);
383 PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
384 MS_EXCEPTION_IF_NULL(node_prim);
385 MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
386 if (node_prim->name() == DEPEND && node_pair.second != 1) {
387 continue;
388 }
389 auto op_info = use_apply->user_data<OperatorInfo>();
390 if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
391 MS_LOG(INFO) << "FindReshapeNextNodeStraCosts success prim " << node_prim->name();
392 *is_next_reshape = false;
393 *next_operator_info = op_info;
394 *in_index = node_pair.second - 1;
395 return true;
396 }
397 MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << " "
398 << IsParallelCareNode(use_apply) << " " << (op_info != nullptr);
399
400 if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index, is_next_reshape, ++curr_depth)) {
401 return true;
402 }
403 }
404 return false;
405 }
406
SetUserAttrs(const std::unordered_map<std::string,ValuePtr> & origin_prim_attrs,const PrimitivePtr & self_prim)407 void SetUserAttrs(const std::unordered_map<std::string, ValuePtr> &origin_prim_attrs, const PrimitivePtr &self_prim) {
408 MS_EXCEPTION_IF_NULL(self_prim);
409 for (auto attr_name : filter_attrs) {
410 auto iter = origin_prim_attrs.find(attr_name);
411 if (iter != origin_prim_attrs.cend()) {
412 self_prim->set_attr(attr_name, iter->second);
413 MS_LOG(INFO) << "The new prim " << self_prim << " add attr " << attr_name;
414 }
415 }
416 }
417 } // namespace parallel
418 } // namespace mindspore
419