• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/graph_util/node_info.h"
18 
19 #include <string>
20 
21 #include "base/core_ops.h"
22 #include "ir/param_info.h"
23 #include "ir/meta_tensor.h"
24 #include "pipeline/jit/parse/python_adapter.h"
25 #include "frontend/parallel/ops_info/ops_utils.h"
26 #include "frontend/parallel/step_parallel.h"
27 #include "frontend/parallel/step_parallel_utils.h"
28 
29 namespace mindspore {
30 namespace parallel {
// Attribute names that SetUserAttrs forwards from an original primitive's
// attribute map onto a newly created primitive (recompute flag, target device).
const std::vector<std::string> filter_attrs = {RECOMPUTE, TARGET};
ParameterName(const AnfNodePtr & node_ptr)32 std::string ParameterName(const AnfNodePtr &node_ptr) {
33   auto para_ptr = node_ptr->cast<ParameterPtr>();
34   MS_EXCEPTION_IF_NULL(para_ptr);
35   return para_ptr->name();
36 }
37 
ParameterRequireGrad(const AnfNodePtr & node_ptr)38 bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
39   auto para_ptr = node_ptr->cast<ParameterPtr>();
40   if (para_ptr == nullptr) {
41     return false;
42   }
43   if (!para_ptr->has_default()) {
44     return false;
45   }
46   auto param_value = para_ptr->param_info();
47   if (param_value == nullptr) {
48     return false;
49   }
50   return param_value->requires_grad();
51 }
52 
// Strip a Load wrapper: for Load(x, u) return x, otherwise return the input unchanged.
AnfNodePtr GetRealInput(const AnfNodePtr &input) {
  if (!IsPrimitiveCNode(input, prim::kPrimLoad)) {
    return input;
  }
  return input->cast<CNodePtr>()->input(1);
}
59 
60 // Given the node, return whether each input is a parameter or a output of a operator.
61 // The returned boolean vector should be the same order of the inputs, thus its implementation
62 // is closely consistent with ExtractShape() in step_parallel.cc
ExtractInputParameterByNode(const CNodePtr & node)63 std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
64   std::vector<bool> is_parameter;
65   std::vector<AnfNodePtr> node_inputs{node->inputs()};
66   // input is a ValueList or ValueTuple, then all inputs are not parameter.
67   if ((node_inputs.size() == 2) &&
68       (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
69     std::vector<ValuePtr> inputs_seq;
70     if (IsValueNode<ValueList>(node_inputs[1])) {
71       inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
72     } else {
73       inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
74     }
75     return std::vector<bool>(inputs_seq.size(), false);
76   }
77   if ((node_inputs.size() == 2) &&
78       (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
79     node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
80   }
81   for (size_t i = 1; i < node_inputs.size(); ++i) {
82     auto input = GetRealInput(node_inputs[i]);
83     if (HasAbstractMonad(input)) {
84       continue;
85     }
86     if (input->isa<Parameter>()) {
87       auto input_parameter = input->cast<ParameterPtr>();
88       is_parameter.push_back(ParameterRequireGrad(input_parameter));
89     } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
90       is_parameter.push_back(false);
91     }
92   }
93   return is_parameter;
94 }
95 
96 // Given the type, return the number of bytes to represent this type
GetLengthOfDataType(const TypePtr & type)97 size_t GetLengthOfDataType(const TypePtr &type) {
98   switch (type->type_id()) {
99     case kNumberTypeBool:
100       return sizeof(bool);
101     case kNumberTypeInt8:
102       return sizeof(int8_t);
103     case kNumberTypeInt16:
104       return sizeof(int16_t);
105     case kNumberTypeInt32:
106       return sizeof(int32_t);
107     case kNumberTypeInt64:
108       return sizeof(int64_t);
109     case kNumberTypeUInt8:
110       return sizeof(uint8_t);
111     case kNumberTypeUInt16:
112       return sizeof(uint16_t);
113     case kNumberTypeUInt32:
114       return sizeof(uint32_t);
115     case kNumberTypeUInt64:
116       return sizeof(uint64_t);
117     case kNumberTypeFloat16:
118       return sizeof(float) / 2;
119     case kNumberTypeFloat32:
120       return sizeof(float);
121     case kNumberTypeFloat64:
122       return sizeof(double);
123     case kNumberTypeInt:
124       return sizeof(int64_t);
125     case kNumberTypeUInt:
126       return sizeof(unsigned);
127     case kNumberTypeFloat:
128       return sizeof(float);
129     default:
130       MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
131   }
132 }
133 
GetInputsTypeLen(const AnfNodePtr & input)134 size_t GetInputsTypeLen(const AnfNodePtr &input) {
135   MS_EXCEPTION_IF_NULL(input);
136   if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
137     MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
138   }
139 
140   size_t input_type_len = 0;
141   auto type = input->Type();
142   MS_EXCEPTION_IF_NULL(type);
143   if (type->isa<mindspore::TensorType>()) {
144     auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
145     input_type_len = GetLengthOfDataType(input_element_type);
146   } else {
147     MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
148   }
149   return input_type_len;
150 }
151 
ExtractInputTypeLengthByNode(const CNodePtr & node)152 std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
153   MS_EXCEPTION_IF_NULL(node);
154   std::vector<size_t> inputs_type_len;
155   std::vector<AnfNodePtr> node_inputs{node->inputs()};
156 
157   if ((node_inputs.size() == 2) &&
158       (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
159     std::vector<ValuePtr> inputs_seq;
160     if (IsValueNode<ValueList>(node_inputs[1])) {
161       inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
162     } else {
163       inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
164     }
165     for (auto &ele : inputs_seq) {
166       auto tensor = ele->cast<tensor::TensorPtr>();
167       MS_EXCEPTION_IF_NULL(tensor);
168       inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
169     }
170     return inputs_type_len;
171   }
172 
173   if ((node_inputs.size() == 2) &&
174       (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
175     node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
176   }
177 
178   // extract input element length
179   for (auto &input : node_inputs) {
180     if (HasAbstractMonad(input)) {
181       continue;
182     }
183     if (IsValueNode<RefKey>(input)) {
184       auto func_graph = node->func_graph();
185       MS_EXCEPTION_IF_NULL(func_graph);
186       std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
187       if (parameters.size() != 1) {
188         MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
189       }
190       inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
191     } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
192       // extract input shape from parameter and apply node
193       inputs_type_len.push_back(GetInputsTypeLen(input));
194     }
195   }
196   return inputs_type_len;
197 }
198 
ExtractOutputTypeByNode(const CNodePtr & node)199 std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
200   MS_EXCEPTION_IF_NULL(node);
201   std::vector<TypePtr> outputs_type;
202   // extract output element type
203   auto primary_output_type = node->Type();
204   MS_EXCEPTION_IF_NULL(primary_output_type);
205   if (primary_output_type->isa<mindspore::Tuple>()) {
206     // in this case, the output is a tuple
207     auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
208     auto elements = tuple_output_type->elements();
209     for (auto &ele : elements) {
210       if (ele->isa<mindspore::TensorType>()) {
211         auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
212         outputs_type.push_back(ele_element_type);
213       } else {
214         MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
215       }
216     }
217   } else {
218     // in this case, the output is a single tensor
219     if (primary_output_type->isa<mindspore::TensorType>()) {
220       auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
221       outputs_type.push_back(element_type);
222     } else {
223       MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
224     }
225   }
226   return outputs_type;
227 }
228 
FindParameterByRefKeyNode(const AnfNodePtr & node,const FuncGraphPtr & func_graph)229 std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
230   MS_EXCEPTION_IF_NULL(node);
231   MS_EXCEPTION_IF_NULL(func_graph);
232   std::vector<AnfNodePtr> parameters;
233   if (!IsValueNode<RefKey>(node)) {
234     MS_LOG(ERROR) << "The node is not a ref key";
235     return parameters;
236   }
237 
238   auto ref_key = GetValueNode<RefKeyPtr>(node);
239   MS_EXCEPTION_IF_NULL(ref_key);
240   auto name = ref_key->tag();
241 
242   auto manager = func_graph->manager();
243   MS_EXCEPTION_IF_NULL(manager);
244   auto roots = manager->roots();
245   if (roots.size() != 1) {
246     MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1";
247     return parameters;
248   }
249 
250   FuncGraphPtr root_g = roots.back();
251   MS_EXCEPTION_IF_NULL(root_g);
252   for (auto &param_node : root_g->parameters()) {
253     auto param = param_node->cast<ParameterPtr>();
254     if (param && (name == param->name())) {
255       parameters.push_back(param_node);
256       MS_LOG(INFO) << "The name of ref key is: " << name;
257       return parameters;
258     }
259   }
260 
261   MS_LOG(ERROR) << "The name of ref key is: " << name << ", but have not found the parameter";
262   return parameters;
263 }
264 
AnfNodeIsPrimitive(const AnfNodePtr & anf_node,const std::string & prim_name)265 bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
266   MS_EXCEPTION_IF_NULL(anf_node);
267   auto cnode = anf_node->cast<CNodePtr>();
268   if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
269     return false;
270   }
271 
272   auto value_node = cnode->input(0)->cast<ValueNodePtr>();
273   auto prim = GetValueNode<PrimitivePtr>(value_node);
274   MS_EXCEPTION_IF_NULL(prim);
275   if (prim->name() == prim_name) {
276     return true;
277   }
278   return false;
279 }
280 
FindReshape(const CNodePtr & cnode,std::unordered_set<std::string> * op_cache)281 bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
282   if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
283     return false;
284   }
285   if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
286     return false;
287   }
288   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
289   PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
290   MS_EXCEPTION_IF_NULL(prim);
291   if (prim->name() == RESHAPE) {
292     auto operator_info = cnode->user_data<OperatorInfo>();
293     std::string op_info_name = operator_info->name();
294     if (op_cache->find(op_info_name) != op_cache->end()) {
295       return false;
296     }
297     op_cache->insert(op_info_name);
298     return true;
299   }
300   return false;
301 }
302 
303 // Find previous node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
FindReshapePreNodeStraCosts(const AnfNodePtr & node,OperatorInfoPtr * pre_operator_info,int64_t * out_index,size_t curr_depth)304 bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index,
305                                  size_t curr_depth) {
306   if (curr_depth > MAX_RECURSIVE_DEPTH) {
307     MS_LOG(WARNING) << "When finding Reshape's previous node, exceeded the max recursive depth: "
308                     << MAX_RECURSIVE_DEPTH;
309     return false;
310   }
311   // if previous node is a parameter, handle it in the outsize.
312   if (node->isa<Parameter>()) {
313     return false;
314   }
315   if (!node->isa<CNode>()) {
316     return false;
317   }
318   CNodePtr cnode = node->cast<CNodePtr>();
319   if (!IsValueNode<Primitive>(cnode->input(0))) {
320     return false;
321   }
322   auto node_op_info = cnode->user_data<OperatorInfo>();
323   if (IsParallelCareNode(cnode) && (node_op_info != nullptr) && !IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
324     *pre_operator_info = node_op_info;
325     *out_index = 0;
326     return true;
327   }
328   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
329   PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
330   if (prim->name() == prim::kTupleGetItem) {
331     *out_index = GetTupleGetItemIndex(cnode);
332     // find tuple_get_item's previous node
333     auto pre_node = cnode->input(1);
334     if (!pre_node->isa<CNode>()) {
335       MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
336     }
337     CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
338     auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
339     if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
340       *pre_operator_info = pre_op_info;
341       return true;
342     }
343     return false;
344   }
345   for (size_t index = 0; index < cnode->inputs().size(); ++index) {
346     if (prim->name() == DEPEND && index != 1) {
347       continue;
348     }
349     if (!FindReshapePreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index, ++curr_depth)) {
350       continue;
351     }
352     return true;
353   }
354   MS_LOG(WARNING)
355     << "FindReshapePreNodeStraCosts failed, if reshape is not the first primitive, there must be some error";
356   return false;
357 }
358 
359 // Find next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
360 // if reshape's output connect to several primitive, return the first layout found
FindReshapeNextNodeStraCosts(const CNodePtr & cnode,OperatorInfoPtr * next_operator_info,int64_t * in_index,bool * is_next_reshape,size_t curr_depth)361 bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index,
362                                   bool *is_next_reshape, size_t curr_depth) {
363   if (curr_depth > MAX_RECURSIVE_DEPTH) {
364     MS_LOG(WARNING) << "When finding Reshape's next node, exceeded the max recursive depth: " << MAX_RECURSIVE_DEPTH;
365     return false;
366   }
367   MS_EXCEPTION_IF_NULL(cnode);
368   MS_EXCEPTION_IF_NULL(cnode->func_graph());
369   FuncGraphManagerPtr manager = cnode->func_graph()->manager();
370   MS_EXCEPTION_IF_NULL(manager);
371   AnfNodeIndexSet node_set = manager->node_users()[cnode];
372   for (auto &node_pair : node_set) {
373     CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
374     if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
375       continue;
376     }
377     if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
378       *is_next_reshape = true;
379       continue;
380     }
381     ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
382     MS_EXCEPTION_IF_NULL(prim_anf_node);
383     PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
384     MS_EXCEPTION_IF_NULL(node_prim);
385     MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
386     if (node_prim->name() == DEPEND && node_pair.second != 1) {
387       continue;
388     }
389     auto op_info = use_apply->user_data<OperatorInfo>();
390     if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
391       MS_LOG(INFO) << "FindReshapeNextNodeStraCosts success prim " << node_prim->name();
392       *is_next_reshape = false;
393       *next_operator_info = op_info;
394       *in_index = node_pair.second - 1;
395       return true;
396     }
397     MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << "  "
398                   << IsParallelCareNode(use_apply) << "   " << (op_info != nullptr);
399 
400     if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index, is_next_reshape, ++curr_depth)) {
401       return true;
402     }
403   }
404   return false;
405 }
406 
SetUserAttrs(const std::unordered_map<std::string,ValuePtr> & origin_prim_attrs,const PrimitivePtr & self_prim)407 void SetUserAttrs(const std::unordered_map<std::string, ValuePtr> &origin_prim_attrs, const PrimitivePtr &self_prim) {
408   MS_EXCEPTION_IF_NULL(self_prim);
409   for (auto attr_name : filter_attrs) {
410     auto iter = origin_prim_attrs.find(attr_name);
411     if (iter != origin_prim_attrs.cend()) {
412       self_prim->set_attr(attr_name, iter->second);
413       MS_LOG(INFO) << "The new prim " << self_prim << " add attr " << attr_name;
414     }
415   }
416 }
417 }  // namespace parallel
418 }  // namespace mindspore
419