1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_ANFALGO_H
18 #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_ANFALGO_H
19
20 #include <functional>
21 #include <iostream>
22 #include <map>
23 #include <memory>
24 #include <optional>
25 #include <set>
26 #include <string>
27 #include <tuple>
28 #include <utility>
29 #include <vector>
30 #include "base/base.h"
31 #include "include/common/utils/contract.h"
32 #include "include/common/utils/utils.h"
33 #include "include/common/visible.h"
34 #include "ir/anf.h"
35 #include "ir/dtype.h"
36 #include "ir/func_graph.h"
37 #include "ir/kernel_info_dev.h"
38 #include "ir/primitive.h"
39 #include "ops/array_op_name.h"
40 #include "ops/other_op_name.h"
41 #include "ops/sequence_ops.h"
42 #include "utils/anf_utils.h"
43
44 namespace mindspore {
45 namespace common {
46 using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
47
48 class COMMON_EXPORT AnfAlgo {
49 public:
50 // get real input node of tuple_get_item
51 static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
52 static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
53 // get input_anf_node's real kernel by recurse
54 static KernelWithIndex VisitKernel(const AnfNodePtr &anf_node, size_t index);
55 static KernelWithIndex VisitKernelWithReturnType(
56 const AnfNodePtr &anf_node, size_t index, bool skip_nop_node = false,
57 const std::vector<PrimitivePtr> &return_types = {prim::kPrimMakeTuple},
58 abstract::AbstractBasePtr *abstract = nullptr, bool is_index_valid = false);
59
60 // Skip the monad node to get the real node.
61 static KernelWithIndex FetchRealNodeSkipMonadControl(const KernelWithIndex &node_with_index);
62
63 static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
64 const std::vector<PrimitivePtr> &return_types = {});
65 static std::vector<KernelWithIndex> GetAllOutputIndexByReturnTypes(const AnfNodePtr &node,
66 const std::vector<PrimitivePtr> &return_types = {},
67 bool need_make_tuple = false);
68 static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node,
69 const std::vector<PrimitivePtr> &return_types = {});
70 static std::vector<KernelWithIndex> GetAllOutputWithOutMonadAndParameter(const AnfNodePtr &node);
71 // get cnode primitive
72 static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
73 static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
74 static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
75 // Get cnode primitive attr.
GetCNodePrimitiveAttr(const AnfNodePtr & node,const std::string & key)76 static ValuePtr GetCNodePrimitiveAttr(const AnfNodePtr &node, const std::string &key) {
77 const auto &primitive = GetCNodePrimitive(node);
78 return primitive != nullptr ? primitive->GetAttr(key) : nullptr;
79 }
80 // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple
81 static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
// get cnode func graph
83 static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node);
84 // get kernel_name of anf node
85 static std::string GetCNodeName(const AnfNodePtr &node);
86 static bool IsGetNextNode(const AnfNodePtr &node);
87 // get detail info of anf node
88 static std::string GetNodeDebugString(const AnfNodePtr &node);
89 // get attr of anf node
90 template <typename T>
GetNodeAttr(const AnfNodePtr & node,const std::string & key)91 static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) {
92 MS_EXCEPTION_IF_NULL(node);
93 if (!node->isa<CNode>()) {
94 std::string node_debug_log = node->DebugString();
95 MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
96 }
97 // single op cnode.
98 if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) {
99 return GetValue<T>(primitive->GetAttr(key));
100 }
101 // graph kernel cnode.
102 auto fg = GetCNodeFuncGraphPtr(node);
103 MS_EXCEPTION_IF_NULL(fg);
104 return GetValue<T>(fg->get_attr(key));
105 }
106 static bool IsTupleOutput(const AnfNodePtr &anf);
107 // set attr of anf node
108 static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
109 // set attr of anf node safely(use a copy of primitive)
110 static void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
111 // set attr of key from 'from' node to 'to' node
112 static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to);
113 // set a new key for attr from 'from' node to 'to' node
114 static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
115 const AnfNodePtr &to);
116 // set all attrs from 'from' node to 'to' node
117 static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to);
118 // check whether a cnode has the specified attr.
119 static bool HasNodeAttr(const std::string &key, const CNodePtr &node);
120 // delete attr of anf node
121 static void EraseNodeAttr(const std::string &key, const AnfNodePtr &node);
122 // get the num of inputs include monads for a cnode
123 static size_t GetInputNum(const CNodePtr &cnode);
124 // get the num of inputs exclude monads for real_kernel (which can be build and run in device)
125 static size_t GetInputTensorNum(const AnfNodePtr &node);
// check whether prev node output with output index has tuplegetitem
127 static bool IsPrevNodeHasTupleGetItem(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
// get prev node output with output index
129 static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
130 // get all the untuple real prev_nodes output
131 static std::vector<KernelWithIndex> GetRealPrevNodesOutput(const AnfNodePtr &anf_node, size_t input_idx,
132 bool skip_nop_node = false);
133
134 // get output shapes inferred by ME from input nodes.
135 static ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t output_idx,
136 bool is_real_squence_output = false);
137 // get input shapes inferred by ME from input nodes.
138 static ShapeVector GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
139 // get output data type inferred by ME of anf node
140 static TypePtr GetOutputInferType(const AnfNodePtr &node, size_t output_idx, bool is_real_tuple = false);
141 static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
142 static TypeId GetOutputInferDataType(const TypePtr &type, size_t output_idx);
143 // get output original data type from prev node,input_index is the input index of current node related to prev node
144 static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
145 static TypePtr GetPrevNodeOutputInferType(const AnfNodePtr &node, size_t input_idx);
146 // for tuple condition
147 static std::vector<TypeId> GetRealPrevNodesOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
148 // set infer shapes and types of anf node
149 static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, const std::vector<ShapeVector> &shapes,
150 AnfNode *node, bool disable_dynamic_len = false);
151 // set output shape ptr
152 static void SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
153 const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node);
154
155 static void SetSingleOutputTypeAndDetailShape(const std::vector<TypeId> &types,
156 const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node);
157
158 static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
// check whether the anf node is a graph kernel.
160 static bool IsGraphKernel(const AnfNodePtr &node);
// check whether the anf node is an inner node of graph kernel.
162 static bool IsNodeInGraphKernel(const AnfNodePtr &node);
163 // get the real output of GraphKernel.
164 static AnfNodePtr GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index);
165 // check parameter is weight or data
166 static bool IsParameterWeight(const ParameterPtr &node);
// check whether the anf node includes the label_index.
168 static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index);
169 // Check whether the cnode update parameter
170 static bool IsUpdateParameterKernel(const CNodePtr &node);
171 static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
172 static bool IsCommunicationOp(const AnfNodePtr &node);
173 static bool IsDtypeFormatSensitiveOp(const AnfNodePtr &node);
174 static bool IsFusedCommunicationOp(const AnfNodePtr &node);
175 static bool IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &type);
176 static bool IsGetNext(const NotNull<AnfNodePtr> &node);
177 static bool IsNeedSkipNopOpAddr(const AnfNodePtr &node);
178 static bool IsNeedSkipNopOpExecution(const AnfNodePtr &node);
179 static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
180 static bool IsSwitchCall(const CNodePtr &call_node);
181 static bool IsScalarInput(const CNodePtr &cnode, size_t index);
182 static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
183 static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list);
184 static void ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> *> node_list);
185 // get fix output precision of cnode.
186 static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
187 // get fix output precision from prev node, input_idx is the input index of current node related to prev node.
188 static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
189 static bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr);
190 static bool IsNodeOutputDynamicShape(const AnfNodePtr &node);
191 static bool IsDynamicShape(const AnfNodePtr &node);
192 static bool IsDynamicRankNode(const AnfNodePtr &node);
193 static bool IsDynamicValue(const AnfNodePtr &node);
194 static bool IsNodeInputDynamicRank(const CNodePtr &anf_node_ptr);
195 static bool IsNodeOutputDynamicRank(const AnfNodePtr &node);
196 static bool IsInputAnchorDynamicRank(const AnfNodePtr &node, size_t idx);
197 static bool IsOutputAnchorDynamicRank(const AnfNodePtr &node, size_t idx);
198 static bool IsCondControlKernel(const CNodePtr &node);
199 static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
200 static std::optional<string> GetDumpFlag(const AnfNodePtr &node);
201 static void GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape);
202 static std::vector<int64_t> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index);
203 static bool IsHostKernel(const CNodePtr &kernel_node);
204 static void AddArgList(AbstractBasePtrList *args_spec_list, const AnfNodePtr &real_input, size_t real_input_index);
205 // Find real input nodes.
206 static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
207 std::set<AnfNodePtr> *visited);
208 static void GetAllVisitedCNode(const CNodePtr &node, std::vector<AnfNodePtr> *used_kernels,
209 std::set<AnfNodePtr> *visited);
210 static std::string GetGraphSplitGroup(const AnfNodePtr &node);
211 static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
212 // Get node real inputs, skip `MakeTuple`, `TupleGetItem`, `Depend`, `Load`, `UpdateState` etc.
213 static void GetRealInputs(const AnfNodePtr &node, std::vector<KernelWithIndex> *inputs);
214 // Check whether tensors need broadcast or not.
template <typename T>
static inline bool IsTensorBroadcast(const std::vector<T> &lhs, const std::vector<T> &rhs) {
  // A differing number of entries is reported as "needs broadcast".
  if (lhs.size() != rhs.size()) {
    return true;
  }
  // Same length: walk both vectors in lockstep and report the first mismatching entry.
  auto rhs_it = rhs.begin();
  for (const auto &lhs_item : lhs) {
    if (lhs_item != *rhs_it) {
      return true;
    }
    ++rhs_it;
  }
  return false;
}
227
228 // Calc tensor size in byte.
229 template <typename T>
TensorSizeInByte(const std::vector<int64_t> & shape)230 static size_t TensorSizeInByte(const std::vector<int64_t> &shape) {
231 return sizeof(T) * SizeOf(shape);
232 }
233
// Calc tensor size in byte for a shape given as unsigned dims: sizeof(T) multiplied by every entry of `shape`.
// An empty shape yields sizeof(T).
template <typename T>
static size_t TensorSizeInByte(const std::vector<size_t> &shape) {
  size_t total_bytes = sizeof(T);
  for (const size_t dim : shape) {
    total_bytes *= dim;
  }
  return total_bytes;
}
241
242 // Judge a control operator need be compiled into kernel graph rather than be cut into single op and
243 // executed in vm. For example, the operator "bprop_cut" will be compiled into kernel graph and be launch
244 // in backend in PyNative mode.
245 static bool IsBpropCutOpExecInBackend(const AnfNodePtr &node);
246
247 static bool IsNodeInputContainMonad(const AnfNodePtr &node);
248 // Check whether a cnode has a monad input.
249 static bool HasMonadInput(const AnfNodePtr &node);
250
251 // Check if node is non-task op.
252 static bool IsNonTaskOp(const CNodePtr &node);
253 // Check if node has none input after IR fusion.
254 static bool IsNoneInput(const AnfNodePtr &node, size_t index);
255 // Check whether node is a call node, call nodes are those cnodes whose first input is not primitive node.
256 static bool IsCallNode(const AnfNodePtr &node);
257 // Get the output number according to abstract, when there is a tuple in abstract, it needs to get recursively.
258 static size_t GetOutputNumByAbstract(const AbstractBasePtr &node_abstract);
259 // Get attr groups
260 static int64_t GetAttrGroups(const AnfNodePtr &node, size_t index);
261
IsAllgather(const CNodePtr & cnode)262 static inline bool IsAllgather(const CNodePtr &cnode) { return GetCNodeName(cnode) == kAllGatherOpName; }
263
IsFusion(const CNodePtr & cnode)264 static inline bool IsFusion(const CNodePtr &cnode) {
265 return HasNodeAttr(kAttrFusion, cnode) && GetNodeAttr<int64_t>(cnode, kAttrFusion) > 0;
266 }
267
IsFromParallelOptimizer(const CNodePtr & cnode)268 static inline bool IsFromParallelOptimizer(const CNodePtr &cnode) {
269 auto primitive = GetCNodePrimitive(cnode);
270 return (primitive != nullptr) && primitive->instance_name().find("parallel_optimizer") != std::string::npos;
271 }
272
IsRecompute(const CNodePtr & cnode)273 static inline bool IsRecompute(const CNodePtr &cnode) {
274 auto attr_dup = cnode->GetAttr(kAttrDuplicated);
275 return attr_dup != nullptr && GetValue<bool>(attr_dup);
276 }
277
278 // Check whether the node has Ref abstract.
HasAbstractRef(const AnfNodePtr & node)279 static inline bool HasAbstractRef(const AnfNodePtr &node) {
280 MS_EXCEPTION_IF_NULL(node);
281 auto &abs = node->abstract();
282 return (abs != nullptr) && abs->isa<abstract::AbstractRefTensor>();
283 }
284
285 // Check whether the sequence node has Ref abstract.
SequenceHasAbstractRef(const AnfNodePtr & node)286 static inline bool SequenceHasAbstractRef(const AnfNodePtr &node) {
287 MS_EXCEPTION_IF_NULL(node);
288 auto &abs = node->abstract();
289 if ((abs != nullptr) && (abs->isa<abstract::AbstractSequence>())) {
290 auto abs_seq = abs->cast_ptr<abstract::AbstractSequence>();
291 const auto &elements = abs_seq->elements();
292 return std::any_of(elements.begin(), elements.end(), [](const AbstractBasePtr &element) {
293 return (element != nullptr) && element->isa<abstract::AbstractRefTensor>();
294 });
295 }
296 return false;
297 }
298
299 // Get the real output node and indexes of get item, make tuple, depend, load.
300 static AnfNodePtr GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *const index_stack);
301 static bool IsNopNode(const AnfNodePtr &node);
302
303 template <typename T>
304 static bool CheckAbsType(const AnfNodePtr &node);
305 static bool CheckAbsSparseTensor(const AnfNodePtr &node);
306 static bool CheckAbsSparseTensor(const abstract::AbstractBasePtr &abs);
307 static TypeId GetSparseTypeIdAt(const AnfNodePtr &node, size_t idx);
308
309 static std::string GetTensorValueString(const tensor::BaseTensorPtr &tensor);
310 static abstract::AbstractBasePtr FrontendGetNodeAbstractByIndex(const AnfNodePtr &node, size_t index);
311
312 // Get jit level from func_graph
313 static std::string GetJitLevel(const FuncGraphPtr &func_graph);
314
315 static bool IsNodeMutableScalar(const AnfNodePtr &node);
316 static bool IsDynamicSequence(const AnfNodePtr &node);
317 static bool IsAnyTypeOutput(const AnfNodePtr &node);
318 static bool IsAnyTypeInput(const std::vector<AnfNodePtr> &inputs);
319 static bool HasTupleInput(const CNodePtr &node);
320 static bool HasDynamicTupleInput(const CNodePtr &node);
321 static bool IsReduceOp(const std::string &op_name);
322 static bool IsTypeTransformOp(const std::string &op_name);
323 // Get the element shape of dynamic sequence shape.
324 static abstract::BaseShapePtr GetDynamicSequenceShape(const AnfNodePtr &node, size_t output_idx);
325 // Fetch the sub abstract from the top abstract by the index.
326 static abstract::AbstractBasePtr FetchAbstractByIndex(const AbstractBasePtr &abstract, size_t index);
327
328 static std::string GetInputName(const CNodePtr &origin_op, size_t input_index);
329 static bool IsNoOuputNode(const AnfNodePtr &node);
330 static ValuePtr ValueToScalar(const ValuePtr &value, TypeId type_id);
331 static std::vector<ValuePtr> TransformVectorRefToMultiValue(const VectorRef &base_ref);
332 static bool HasIncorporateCallNode(const CNodePtr &cnode);
333 static bool IsDynamicGraph(const FuncGraphPtr &func_graph);
334 };
335
CreateShapeVectorNode(const ShapeVector & value)336 inline AnfNodePtr CreateShapeVectorNode(const ShapeVector &value) {
337 auto value_node = NewValueNode(value);
338 ShapeVector value_node_shape = {SizeToLong(value.size())};
339 common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeInt64}, {value_node_shape}, value_node.get());
340 return value_node;
341 }
342
CreateReshapeNode(const FuncGraphPtr & graph,const AnfNodePtr & input_node,const ShapeVector & shape)343 inline CNodePtr CreateReshapeNode(const FuncGraphPtr &graph, const AnfNodePtr &input_node, const ShapeVector &shape) {
344 MS_EXCEPTION_IF_NULL(input_node);
345
346 auto shape_node = CreateShapeVectorNode(shape);
347 AnfNodePtrList reshape_inputs = {NewValueNode(std::make_shared<Primitive>(kReshapeOpName)), input_node, shape_node};
348 auto reshape_node = NewCNode(reshape_inputs, graph);
349 MS_EXCEPTION_IF_NULL(reshape_node);
350 reshape_node->set_scope(input_node->scope());
351 common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape_node);
352 common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(shape), reshape_node);
353 auto data_type = common::AnfAlgo::GetOutputInferDataType(input_node, kIndex0);
354 common::AnfAlgo::SetOutputInferTypeAndShape({data_type}, {shape}, reshape_node.get());
355
356 return reshape_node;
357 }
358 } // namespace common
359 } // namespace mindspore
360 #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_ANFALGO_H
361