• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_ANFALGO_H
18 #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_ANFALGO_H
19 
#include <algorithm>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <numeric>
#include <optional>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "base/base.h"
#include "include/common/utils/contract.h"
#include "include/common/utils/utils.h"
#include "include/common/visible.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/func_graph.h"
#include "ir/kernel_info_dev.h"
#include "ir/primitive.h"
#include "ops/array_op_name.h"
#include "ops/other_op_name.h"
#include "ops/sequence_ops.h"
#include "utils/anf_utils.h"
43 
44 namespace mindspore {
45 namespace common {
46 using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
47 
48 class COMMON_EXPORT AnfAlgo {
49  public:
50   // get real input node of tuple_get_item
51   static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
52   static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
53   // get input_anf_node's real kernel by recurse
54   static KernelWithIndex VisitKernel(const AnfNodePtr &anf_node, size_t index);
55   static KernelWithIndex VisitKernelWithReturnType(
56     const AnfNodePtr &anf_node, size_t index, bool skip_nop_node = false,
57     const std::vector<PrimitivePtr> &return_types = {prim::kPrimMakeTuple},
58     abstract::AbstractBasePtr *abstract = nullptr, bool is_index_valid = false);
59 
60   // Skip the monad node to get the real node.
61   static KernelWithIndex FetchRealNodeSkipMonadControl(const KernelWithIndex &node_with_index);
62 
63   static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
64                                               const std::vector<PrimitivePtr> &return_types = {});
65   static std::vector<KernelWithIndex> GetAllOutputIndexByReturnTypes(const AnfNodePtr &node,
66                                                                      const std::vector<PrimitivePtr> &return_types = {},
67                                                                      bool need_make_tuple = false);
68   static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node,
69                                                             const std::vector<PrimitivePtr> &return_types = {});
70   static std::vector<KernelWithIndex> GetAllOutputWithOutMonadAndParameter(const AnfNodePtr &node);
71   // get cnode primitive
72   static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
73   static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
74   static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
75   // Get cnode primitive attr.
GetCNodePrimitiveAttr(const AnfNodePtr & node,const std::string & key)76   static ValuePtr GetCNodePrimitiveAttr(const AnfNodePtr &node, const std::string &key) {
77     const auto &primitive = GetCNodePrimitive(node);
78     return primitive != nullptr ? primitive->GetAttr(key) : nullptr;
79   }
80   // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple
81   static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
82   // get cnode primitive
83   static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node);
84   // get kernel_name of anf node
85   static std::string GetCNodeName(const AnfNodePtr &node);
86   static bool IsGetNextNode(const AnfNodePtr &node);
87   // get detail info of anf node
88   static std::string GetNodeDebugString(const AnfNodePtr &node);
89   // get attr of anf node
90   template <typename T>
GetNodeAttr(const AnfNodePtr & node,const std::string & key)91   static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) {
92     MS_EXCEPTION_IF_NULL(node);
93     if (!node->isa<CNode>()) {
94       std::string node_debug_log = node->DebugString();
95       MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
96     }
97     // single op cnode.
98     if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) {
99       return GetValue<T>(primitive->GetAttr(key));
100     }
101     // graph kernel cnode.
102     auto fg = GetCNodeFuncGraphPtr(node);
103     MS_EXCEPTION_IF_NULL(fg);
104     return GetValue<T>(fg->get_attr(key));
105   }
106   static bool IsTupleOutput(const AnfNodePtr &anf);
107   // set attr of anf node
108   static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
109   // set attr of anf node safely(use a copy of primitive)
110   static void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
111   // set attr of key from 'from' node to 'to' node
112   static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to);
113   // set a new key for attr from 'from' node to 'to' node
114   static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
115                            const AnfNodePtr &to);
116   // set all attrs from 'from' node to 'to' node
117   static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to);
118   // check whether a cnode has the specified attr.
119   static bool HasNodeAttr(const std::string &key, const CNodePtr &node);
120   // delete attr of anf node
121   static void EraseNodeAttr(const std::string &key, const AnfNodePtr &node);
122   // get the num of inputs include monads for a cnode
123   static size_t GetInputNum(const CNodePtr &cnode);
124   // get the num of inputs exclude monads for real_kernel (which can be build and run in device)
125   static size_t GetInputTensorNum(const AnfNodePtr &node);
126   // get prev node output width output index has tuplegetitem
127   static bool IsPrevNodeHasTupleGetItem(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
128   // get prev node output width output index
129   static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
130   // get all the untuple real prev_nodes output
131   static std::vector<KernelWithIndex> GetRealPrevNodesOutput(const AnfNodePtr &anf_node, size_t input_idx,
132                                                              bool skip_nop_node = false);
133 
134   // get output shapes inferred by ME from input nodes.
135   static ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t output_idx,
136                                          bool is_real_squence_output = false);
137   // get input shapes inferred by ME from input nodes.
138   static ShapeVector GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
139   // get output data type inferred by ME of anf node
140   static TypePtr GetOutputInferType(const AnfNodePtr &node, size_t output_idx, bool is_real_tuple = false);
141   static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
142   static TypeId GetOutputInferDataType(const TypePtr &type, size_t output_idx);
143   // get output original data type from prev node,input_index is the input index of current node related to prev node
144   static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
145   static TypePtr GetPrevNodeOutputInferType(const AnfNodePtr &node, size_t input_idx);
146   // for tuple condition
147   static std::vector<TypeId> GetRealPrevNodesOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
148   // set infer shapes and types of anf node
149   static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, const std::vector<ShapeVector> &shapes,
150                                          AnfNode *node, bool disable_dynamic_len = false);
151   // set output shape ptr
152   static void SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
153                                           const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node);
154 
155   static void SetSingleOutputTypeAndDetailShape(const std::vector<TypeId> &types,
156                                                 const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node);
157 
158   static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
159   // checkout whether the anf node is a graph kernel.
160   static bool IsGraphKernel(const AnfNodePtr &node);
161   // checkout whether the anf node is an inner node of graph kernel.
162   static bool IsNodeInGraphKernel(const AnfNodePtr &node);
163   // get the real output of GraphKernel.
164   static AnfNodePtr GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index);
165   // check parameter is weight or data
166   static bool IsParameterWeight(const ParameterPtr &node);
167   // checkout whether the anf node is include the label_index.
168   static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index);
169   // Check whether the cnode update parameter
170   static bool IsUpdateParameterKernel(const CNodePtr &node);
171   static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
172   static bool IsCommunicationOp(const AnfNodePtr &node);
173   static bool IsDtypeFormatSensitiveOp(const AnfNodePtr &node);
174   static bool IsFusedCommunicationOp(const AnfNodePtr &node);
175   static bool IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &type);
176   static bool IsGetNext(const NotNull<AnfNodePtr> &node);
177   static bool IsNeedSkipNopOpAddr(const AnfNodePtr &node);
178   static bool IsNeedSkipNopOpExecution(const AnfNodePtr &node);
179   static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
180   static bool IsSwitchCall(const CNodePtr &call_node);
181   static bool IsScalarInput(const CNodePtr &cnode, size_t index);
182   static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
183   static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list);
184   static void ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> *> node_list);
185   // get fix output precision of cnode.
186   static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
187   // get fix output precision from prev node, input_idx is the input index of current node related to prev node.
188   static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
189   static bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr);
190   static bool IsNodeOutputDynamicShape(const AnfNodePtr &node);
191   static bool IsDynamicShape(const AnfNodePtr &node);
192   static bool IsDynamicRankNode(const AnfNodePtr &node);
193   static bool IsDynamicValue(const AnfNodePtr &node);
194   static bool IsNodeInputDynamicRank(const CNodePtr &anf_node_ptr);
195   static bool IsNodeOutputDynamicRank(const AnfNodePtr &node);
196   static bool IsInputAnchorDynamicRank(const AnfNodePtr &node, size_t idx);
197   static bool IsOutputAnchorDynamicRank(const AnfNodePtr &node, size_t idx);
198   static bool IsCondControlKernel(const CNodePtr &node);
199   static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
200   static std::optional<string> GetDumpFlag(const AnfNodePtr &node);
201   static void GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape);
202   static std::vector<int64_t> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index);
203   static bool IsHostKernel(const CNodePtr &kernel_node);
204   static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &real_input, size_t real_input_index);
205   // Find real input nodes.
206   static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
207                                    std::set<AnfNodePtr> *visited);
208   static void GetAllVisitedCNode(const CNodePtr &node, std::vector<AnfNodePtr> *used_kernels,
209                                  std::set<AnfNodePtr> *visited);
210   static std::string GetGraphSplitGroup(const AnfNodePtr &node);
211   static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
212   // Get node real inputs, skip `MakeTuple`, `TupleGetItem`, `Depend`, `Load`, `UpdateState` etc.
213   static void GetRealInputs(const AnfNodePtr &node, std::vector<KernelWithIndex> *inputs);
214   // Check whether tensors need broadcast or not.
215   template <typename T>
IsTensorBroadcast(const std::vector<T> & lhs,const std::vector<T> & rhs)216   static inline bool IsTensorBroadcast(const std::vector<T> &lhs, const std::vector<T> &rhs) {
217     if (lhs.size() != rhs.size()) {
218       return true;
219     }
220     for (size_t i = 0; i < lhs.size(); i++) {
221       if (lhs[i] != rhs[i]) {
222         return true;
223       }
224     }
225     return false;
226   }
227 
228   // Calc tensor size in byte.
229   template <typename T>
TensorSizeInByte(const std::vector<int64_t> & shape)230   static size_t TensorSizeInByte(const std::vector<int64_t> &shape) {
231     return sizeof(T) * SizeOf(shape);
232   }
233 
234   template <typename T>
TensorSizeInByte(const std::vector<size_t> & shape)235   static size_t TensorSizeInByte(const std::vector<size_t> &shape) {
236     size_t res = sizeof(T);
237     res = std::accumulate(shape.begin(), shape.end(), res, std::multiplies<size_t>());
238 
239     return res;
240   }
241 
242   // Judge a control operator need be compiled into kernel graph rather than be cut into single op and
243   // executed in vm. For example, the operator "bprop_cut" will be compiled into kernel graph and be launch
244   // in backend in PyNative mode.
245   static bool IsBpropCutOpExecInBackend(const AnfNodePtr &node);
246 
247   static bool IsNodeInputContainMonad(const AnfNodePtr &node);
248   // Check whether a cnode has a monad input.
249   static bool HasMonadInput(const AnfNodePtr &node);
250 
251   // Check if node is non-task op.
252   static bool IsNonTaskOp(const CNodePtr &node);
253   // Check if node has none input after IR fusion.
254   static bool IsNoneInput(const AnfNodePtr &node, size_t index);
255   // Check whether node is a call node, call nodes are those cnodes whose first input is not primitive node.
256   static bool IsCallNode(const AnfNodePtr &node);
257   // Get the output number according to abstract, when there is a tuple in abstract, it needs to get recursively.
258   static size_t GetOutputNumByAbstract(const AbstractBasePtr &node_abstract);
259   // Get attr groups
260   static int64_t GetAttrGroups(const AnfNodePtr &node, size_t index);
261 
IsAllgather(const CNodePtr & cnode)262   static inline bool IsAllgather(const CNodePtr &cnode) { return GetCNodeName(cnode) == kAllGatherOpName; }
263 
IsFusion(const CNodePtr & cnode)264   static inline bool IsFusion(const CNodePtr &cnode) {
265     return HasNodeAttr(kAttrFusion, cnode) && GetNodeAttr<int64_t>(cnode, kAttrFusion) > 0;
266   }
267 
IsFromParallelOptimizer(const CNodePtr & cnode)268   static inline bool IsFromParallelOptimizer(const CNodePtr &cnode) {
269     auto primitive = GetCNodePrimitive(cnode);
270     return (primitive != nullptr) && primitive->instance_name().find("parallel_optimizer") != std::string::npos;
271   }
272 
IsRecompute(const CNodePtr & cnode)273   static inline bool IsRecompute(const CNodePtr &cnode) {
274     auto attr_dup = cnode->GetAttr(kAttrDuplicated);
275     return attr_dup != nullptr && GetValue<bool>(attr_dup);
276   }
277 
278   // Check whether the node has Ref abstract.
HasAbstractRef(const AnfNodePtr & node)279   static inline bool HasAbstractRef(const AnfNodePtr &node) {
280     MS_EXCEPTION_IF_NULL(node);
281     auto &abs = node->abstract();
282     return (abs != nullptr) && abs->isa<abstract::AbstractRefTensor>();
283   }
284 
285   // Check whether the sequence node has Ref abstract.
SequenceHasAbstractRef(const AnfNodePtr & node)286   static inline bool SequenceHasAbstractRef(const AnfNodePtr &node) {
287     MS_EXCEPTION_IF_NULL(node);
288     auto &abs = node->abstract();
289     if ((abs != nullptr) && (abs->isa<abstract::AbstractSequence>())) {
290       auto abs_seq = abs->cast_ptr<abstract::AbstractSequence>();
291       const auto &elements = abs_seq->elements();
292       return std::any_of(elements.begin(), elements.end(), [](const AbstractBasePtr &element) {
293         return (element != nullptr) && element->isa<abstract::AbstractRefTensor>();
294       });
295     }
296     return false;
297   }
298 
299   // Get the real output node and indexes of get item, make tuple, depend, load.
300   static AnfNodePtr GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *const index_stack);
301   static bool IsNopNode(const AnfNodePtr &node);
302 
303   template <typename T>
304   static bool CheckAbsType(const AnfNodePtr &node);
305   static bool CheckAbsSparseTensor(const AnfNodePtr &node);
306   static bool CheckAbsSparseTensor(const abstract::AbstractBasePtr &abs);
307   static TypeId GetSparseTypeIdAt(const AnfNodePtr &node, size_t idx);
308 
309   static std::string GetTensorValueString(const tensor::BaseTensorPtr &tensor);
310   static abstract::AbstractBasePtr FrontendGetNodeAbstractByIndex(const AnfNodePtr &node, size_t index);
311 
312   // Get jit level from func_graph
313   static std::string GetJitLevel(const FuncGraphPtr &func_graph);
314 
315   static bool IsNodeMutableScalar(const AnfNodePtr &node);
316   static bool IsDynamicSequence(const AnfNodePtr &node);
317   static bool IsAnyTypeOutput(const AnfNodePtr &node);
318   static bool IsAnyTypeInput(const std::vector<AnfNodePtr> &inputs);
319   static bool HasTupleInput(const CNodePtr &node);
320   static bool HasDynamicTupleInput(const CNodePtr &node);
321   static bool IsReduceOp(const std::string &op_name);
322   static bool IsTypeTransformOp(const std::string &op_name);
323   // Get the element shape of dynamic sequence shape.
324   static abstract::BaseShapePtr GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx);
325   // Fetch the sub abstract from the top abstract by the index.
326   static abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
327 
328   static std::string GetInputName(const CNodePtr &origin_op, size_t input_index);
329   static bool IsNoOuputNode(const AnfNodePtr &node);
330   static ValuePtr ValueToScalar(const ValuePtr &value, TypeId type_id);
331   static std::vector<ValuePtr> TransformVectorRefToMultiValue(const VectorRef &base_ref);
332   static bool HasIncorporateCallNode(const CNodePtr &cnode);
333   static bool IsDynamicGraph(const FuncGraphPtr &func_graph);
334 };
335 
CreateShapeVectorNode(const ShapeVector & value)336 inline AnfNodePtr CreateShapeVectorNode(const ShapeVector &value) {
337   auto value_node = NewValueNode(value);
338   ShapeVector value_node_shape = {SizeToLong(value.size())};
339   common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeInt64}, {value_node_shape}, value_node.get());
340   return value_node;
341 }
342 
CreateReshapeNode(const FuncGraphPtr & graph,const AnfNodePtr & input_node,const ShapeVector & shape)343 inline CNodePtr CreateReshapeNode(const FuncGraphPtr &graph, const AnfNodePtr &input_node, const ShapeVector &shape) {
344   MS_EXCEPTION_IF_NULL(input_node);
345 
346   auto shape_node = CreateShapeVectorNode(shape);
347   AnfNodePtrList reshape_inputs = {NewValueNode(std::make_shared<Primitive>(kReshapeOpName)), input_node, shape_node};
348   auto reshape_node = NewCNode(reshape_inputs, graph);
349   MS_EXCEPTION_IF_NULL(reshape_node);
350   reshape_node->set_scope(input_node->scope());
351   common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape_node);
352   common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(shape), reshape_node);
353   auto data_type = common::AnfAlgo::GetOutputInferDataType(input_node, kIndex0);
354   common::AnfAlgo::SetOutputInferTypeAndShape({data_type}, {shape}, reshape_node.get());
355 
356   return reshape_node;
357 }
358 }  // namespace common
359 }  // namespace mindspore
360 #endif  // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_ANFALGO_H
361