• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ANF_RUNTIME_ALGORITHM_H
18 #define MINDSPORE_CCSRC_BACKEND_SESSION_ANF_RUNTIME_ALGORITHM_H
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "base/base.h"
#include "ir/primitive.h"
#include "ir/kernel_info_dev.h"
#include "runtime/device/device_address.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "base/core_ops.h"
#include "utils/contract.h"
#include "backend/session/kernel_graph.h"
39 
40 namespace mindspore {
41 namespace session {
42 using PrimitiveSet = std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;
43 using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
44 using DeviceAddress = device::DeviceAddress;
45 using DeviceAddressPtr = device::DeviceAddressPtr;
46 using Address = kernel::Address;
47 using AddressPtr = kernel::AddressPtr;
48 
49 class OpRuntimeInfo {
50  public:
OpRuntimeInfo(std::vector<std::string> output_format,std::vector<TypeId> output_type,std::vector<size_t> output_tensor_size)51   OpRuntimeInfo(std::vector<std::string> output_format, std::vector<TypeId> output_type,
52                 std::vector<size_t> output_tensor_size)
53       : output_format_(std::move(output_format)),
54         output_type_(std::move(output_type)),
55         output_tensor_size_(std::move(output_tensor_size)) {}
56   ~OpRuntimeInfo() = default;
57 
58   // Key for user data.
59   constexpr static char key[] = "OpRuntimeInfo";
60 
61   std::string output_format(size_t index) const;
62   TypeId output_type(size_t index) const;
63   size_t output_tensor_size(size_t index) const;
64 
65  private:
66   std::vector<std::string> output_format_;
67   std::vector<TypeId> output_type_;
68   std::vector<size_t> output_tensor_size_;
69 };
70 
71 class AnfRuntimeAlgorithm {
72  public:
73   static AnfNodePtr MakeMonadValueNode(const KernelGraphPtr &kg);
74   static void KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter);
75   // get real input node of tuple_get_item
76   static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
77   static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
78   // get input_anf_node's real kernel by recurse
79   static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
80   static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
81                                                    bool skip_nop_node = false,
82                                                    const std::vector<PrimitivePtr> &return_types = {
83                                                      prim::kPrimMakeTuple});
84   static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
85                                               const std::vector<PrimitivePtr> &return_types = {});
86   static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node);
87   // get cnode primitive
88   static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
89   static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
90   static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
91   // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple
92   static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
93   // get cnode primitive
94   static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node);
95   // get kernel_name of anf node
96   static std::string GetCNodeName(const AnfNodePtr &node);
97   // get detail info of anf node
98   static std::string GetNodeDebugString(const AnfNodePtr &node);
99   // get attr of anf node
100   template <typename T>
GetNodeAttr(const AnfNodePtr & node,const std::string & key)101   static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) {
102     MS_EXCEPTION_IF_NULL(node);
103     if (!node->isa<CNode>()) {
104       std::string node_debug_log = node->DebugString();
105       MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
106     }
107     // single op cnode.
108     if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) {
109       return GetValue<T>(primitive->GetAttr(key));
110     }
111     // graph kernel cnode.
112     auto fg = GetCNodeFuncGraphPtr(node);
113     MS_EXCEPTION_IF_NULL(fg);
114     return GetValue<T>(fg->get_attr(key));
115   }
116   static bool IsTupleOutput(const AnfNodePtr &anf);
117   // set attr of anf node
118   static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
119   // set attr of key from 'from' node to 'to' node
120   static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to);
121   // set a new key for attr from 'from' node to 'to' node
122   static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
123                            const AnfNodePtr &to);
124   // set all attrs from 'from' node to 'to' node
125   static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to);
126   // check whether a cnode has the specified attr.
127   static bool HasNodeAttr(const std::string &key, const CNodePtr &node);
128   // delete attr of anf node
129   static void EraseNodeAttr(const std::string &key, AnfNodePtr node);
130   // get the num of inputs include monads for a cnode
131   static size_t GetInputNum(const CNodePtr &cnode);
132   // get the num of inputs exclude monads for real_kernel (which can be build and run in device)
133   static size_t GetInputTensorNum(const AnfNodePtr &node);
134   // get the num of output real_kernel(which can be build and run in device)
135   static size_t GetOutputTensorNum(const AnfNodePtr &node);
136   // Get the memory size of output tensor of node.
137   static size_t GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index);
138   // get all outputs format select of anf node
139   static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node);
140   // get all inputs format select of anf node
141   static std::vector<std::string> GetAllInputFormats(const AnfNodePtr &node);
142   // get all inputs type select of anf node
143   static std::vector<TypeId> GetAllInputDeviceTypes(const AnfNodePtr &node);
144   // get all outputs type select of anf node
145   static std::vector<TypeId> GetAllOutputDeviceTypes(const AnfNodePtr &node);
146   // get origin data format select of anf node
147   static std::string GetOriginDataFormat(const AnfNodePtr &node);
148   // get output format select of anf node
149   static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
150   // get input format select of anf node
151   static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
152   // get prev node output width output index
153   static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool visit_nop_node = false);
154   // get output format from prev node,input_index is the input index of current node related to prev node
155   static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
156   // get reshape_type of from the output of input node.
157   static std::string GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
158   // get output shapes inferred by ME from input nodes.
159   static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
160   static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, const abstract::BaseShapePtr &base_shape,
161                                                  size_t output_idx);
162   // get input shapes inferred by ME from input nodes.
163   static std::vector<size_t> GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
164   // get output shapes which will built and run in device
165   static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
166   // get input shapes which will built and run in device
167   static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
168   // get output shapes for tbe build
169   static std::vector<int64_t> GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node, const size_t output_idx,
170                                                               const std::string &format);
171   // Get Input Padding Axis
172   static std::string GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
173   // Get Output Padding Axis
174   static std::string GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
175   // get output data type inferred by ME of anf node
176   static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
177   static TypeId GetOutputInferDataType(const TypePtr &type_ptr, size_t output_idx);
178   // get output original data type from prev node,input_index is the input index of current node related to prev node
179   static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
180   // get output select data type of anf node
181   static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx);
182   // get input select data type of anf node
183   static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
184   // get output select data type from prev node,input_index is the input index of current node related to prev node
185   static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
186   // get output device addr of anf_node
187   static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
188   // get mutable output device addr of anf_node
189   static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
190   // check whether output addr is exist or not
191   static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = false);
192   // check whether workspace addr is exist or not
193   static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx);
194   // get address from prev node,input_index is the input index of current node related to prev node
195   static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx,
196                                                     bool visit_nop_node = true);
197   static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
198                                                        bool visit_nop_node = true);
199   // set output device addr of anf_node
200   static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
201   // set workspace device addr of anf_node
202   static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
203   // get workspace device addr of anf_node
204   static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx);
205   // get workspace device mutable addr of anf_node
206   static DeviceAddressPtr GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index);
207   // set infer shapes and types of anf node
208   static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
209                                          const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
210   // get and set output shape ptr
211   static abstract::BaseShapePtr GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx);
212   static abstract::BaseShapePtr GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx);
213   static void SetOutputTypeAndDetailShape(const std::vector<TypeId> &types,
214                                           const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node);
215   static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
216   // get op pattern of the node
217   static kernel::OpPattern GetOpPattern(const AnfNodePtr &node);
218   // get KernelBuildType of node ,such as ATT,RT,FWK and so on
219   static KernelType GetKernelType(const AnfNodePtr &node);
220   // get processor type:AICORE,AICPU...
221   static kernel::Processor GetProcessor(const AnfNodePtr &node);
222   // get fusion type:AICORE,AICPU...
223   static kernel::FusionType GetFusionType(const AnfNodePtr &node);
224   static void SetFusionType(const AnfNodePtr &node, const kernel::FusionType &type);
225   static void SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc);
226   static std::vector<nlohmann::json> GetOutputDataDesc(const AnfNodePtr &node);
227   // set select kernel_build_info
228   static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node);
229   // get select kernel_build_info
230   static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node);
231   // get kernelMode
232   static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node);
233   // set kernel mod
234   static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node);
235   // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
236   static bool IsRealKernel(const AnfNodePtr &node);
237   // checkout whether the anf node is a real kernel that is a cnode and can run on device
238   static bool IsRealCNodeKernel(const AnfNodePtr &node);
239   // checkout whether the anf node is a graph kernel.
240   static bool IsGraphKernel(const AnfNodePtr &node);
241   // checkout whether the anf node is an inner node of graph kernel.
242   static bool IsNodeInGraphKernel(const AnfNodePtr &node);
243   // get the real output of GraphKernel.
244   static AnfNodePtr GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index);
245   // check parameter is weight or data
246   static bool IsParameterWeight(const ParameterPtr &node);
247   // checkout whether the anf node is include the label_index.
248   static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index);
249   // set stream id of kernel,which will be set in stream assign and be used in stream generate
250   static void SetStreamId(uint32_t stream_id, AnfNode *node);
251   // get stream id
252   static uint32_t GetStreamId(const AnfNodePtr &node);
253   // set stream distinction label to distinguish different ops in different streams
254   static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node);
255   // get stream distinction label
256   static uint32_t GetStreamDistinctionLabel(const AnfNode *node);
257   // set graph id
258   static void SetGraphId(uint32_t graph_id, AnfNode *node);
259   // get graph id
260   static uint32_t GetGraphId(const AnfNode *node);
261   static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
262   // charge if the node's output is a feature map output
263   static bool IsFeatureMapOutput(const AnfNodePtr &node);
264   // charge if the node's input is from a feature map output
265   static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index);
266   // get real input index for some tbe ops which input order is different between me and tbe impl
267   static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
268   // get me input index for some tbe ops which input order is different between me and tbe impl
269   static size_t GetOriginalInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
270   static bool IsCommunicationOp(const AnfNodePtr &node);
271   static bool IsFusedCommunicationOp(const AnfNodePtr &node);
272   static bool IsInplaceNode(const AnfNodePtr &node, const string &type);
273   static bool IsGetNext(const NotNull<AnfNodePtr> &node);
274   static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
275   static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode);
276   static bool IsSwitchCall(const CNodePtr &call_node);
277   static bool IsScalarInput(const CNodePtr &cnode, size_t index);
278   static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
279   static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list);
280   static void ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> *> node_list);
281   // get fix output precision of cnode.
282   static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
283   // get fix output precision from prev node, input_idx is the input index of current node related to prev node.
284   static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx);
285   static bool IsDynamicShape(const AnfNodePtr &node);
286   static bool HasDynamicShapeFlag(const PrimitivePtr &prim);
287   static bool IsCondControlKernel(const CNodePtr &node);
288   static bool IsIndependentNode(const CNodePtr &node);
289   static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
290   static void GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape);
291   static std::vector<int64_t> GetInputMaxShape(const AnfNodePtr &anf_node, size_t index);
292   static std::vector<int64_t> GetInputMinShape(const AnfNodePtr &anf_node, size_t index);
293   static std::vector<int64_t> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index);
294   static std::vector<int64_t> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index);
295   static bool IsNodeDynamicShape(const AnfNodePtr &node);
296   static void InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors = nullptr);
297   static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
298   static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
299   // Find real input nodes.
300   static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
301                                    std::set<AnfNodePtr> *visited);
302   static void GetAllVisitedCNode(const CNodePtr &cnode, std::vector<AnfNodePtr> *used_kernels,
303                                  std::set<AnfNodePtr> *visited);
304   static void InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph);
305   static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node);
306   // Get node real inputs, skip `MakeTuple`, `TupleGetItem`, `Depend`, `Load`, `UpdateState` etc.
307   static void GetRealInputs(const AnfNodePtr &anf_node, std::vector<session::KernelWithIndex> *inputs);
308   // Check whether tensors need broadcast or not.
309   static bool IsTensorBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs);
310   // Calc tensor size in byte.
311   template <typename T>
TensorSizeInByte(const std::vector<size_t> & shape)312   static size_t TensorSizeInByte(const std::vector<size_t> &shape) {
313     size_t result = sizeof(T);
314     for (size_t i = 0; i < shape.size(); i++) {
315       result *= shape[i];
316     }
317     return result;
318   }
319   static bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set);
320 
321   // Judge a control operator need be compiled into kernel graph rather than be cut into single op and
322   // executed in vm. For example, the operator "bprop_cut" will be compiled into kernel graph and be launch
323   // in backend in PyNative mode.
324   static bool IsControlOpExecInBackend(const AnfNodePtr &node);
325 
326   static bool IsNodeInputContainMonad(const AnfNodePtr &node);
327   // Save inputs/outputs/workspace address in kernel_mod.
328   static void CacheAddrForGraph(const KernelGraphPtr &kernel_graph);
329   static void CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
330   static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
331 };
332 }  // namespace session
333 using AnfAlgo = session::AnfRuntimeAlgorithm;
334 }  // namespace mindspore
335 #endif  // MINDSPORE_CCSRC_BACKEND_SESSION_ANF_RUNTIME_ALGORITHM_H
336