1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ANF_RUNTIME_ALGORITHM_H 18 #define MINDSPORE_CCSRC_BACKEND_SESSION_ANF_RUNTIME_ALGORITHM_H 19 #include <iostream> 20 #include <string> 21 #include <vector> 22 #include <set> 23 #include <tuple> 24 #include <utility> 25 #include <memory> 26 #include <unordered_set> 27 #include <map> 28 #include "ir/anf.h" 29 #include "ir/dtype.h" 30 #include "base/base.h" 31 #include "ir/primitive.h" 32 #include "ir/kernel_info_dev.h" 33 #include "runtime/device/device_address.h" 34 #include "backend/kernel_compiler/kernel.h" 35 #include "backend/kernel_compiler/kernel_build_info.h" 36 #include "base/core_ops.h" 37 #include "utils/contract.h" 38 #include "backend/session/kernel_graph.h" 39 40 namespace mindspore { 41 namespace session { 42 using PrimitiveSet = std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>; 43 using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>; 44 using DeviceAddress = device::DeviceAddress; 45 using DeviceAddressPtr = device::DeviceAddressPtr; 46 using Address = kernel::Address; 47 using AddressPtr = kernel::AddressPtr; 48 49 class OpRuntimeInfo { 50 public: OpRuntimeInfo(std::vector<std::string> output_format,std::vector<TypeId> output_type,std::vector<size_t> output_tensor_size)51 OpRuntimeInfo(std::vector<std::string> output_format, std::vector<TypeId> output_type, 52 std::vector<size_t> output_tensor_size) 53 : output_format_(std::move(output_format)), 54 output_type_(std::move(output_type)), 55 output_tensor_size_(std::move(output_tensor_size)) {} 56 ~OpRuntimeInfo() = default; 57 58 // Key for user data. 59 constexpr static char key[] = "OpRuntimeInfo"; 60 61 std::string output_format(size_t index) const; 62 TypeId output_type(size_t index) const; 63 size_t output_tensor_size(size_t index) const; 64 65 private: 66 std::vector<std::string> output_format_; 67 std::vector<TypeId> output_type_; 68 std::vector<size_t> output_tensor_size_; 69 }; 70 71 class AnfRuntimeAlgorithm { 72 public: 73 static AnfNodePtr MakeMonadValueNode(const KernelGraphPtr &kg); 74 static void KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter); 75 // get real input node of tuple_get_item 76 static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item); 77 static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); 78 // get input_anf_node's real kernel by recurse 79 static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index); 80 static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index, 81 bool skip_nop_node = false, 82 const std::vector<PrimitivePtr> &return_types = { 83 prim::kPrimMakeTuple}); 84 static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node, 85 const std::vector<PrimitivePtr> &return_types = {}); 86 static std::vector<KernelWithIndex> GetAllOutputWithIndex(const AnfNodePtr &node); 87 // get cnode primitive 88 static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node); 89 static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index); 90 static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); 91 // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple 92 static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type); 93 // get cnode primitive 94 static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node); 95 // get kernel_name of anf node 96 static std::string GetCNodeName(const AnfNodePtr &node); 97 // get detail info of anf node 98 static std::string GetNodeDebugString(const AnfNodePtr &node); 99 // get attr of anf node 100 template <typename T> GetNodeAttr(const AnfNodePtr & node,const std::string & key)101 static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) { 102 MS_EXCEPTION_IF_NULL(node); 103 if (!node->isa<CNode>()) { 104 std::string node_debug_log = node->DebugString(); 105 MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str(); 106 } 107 // single op cnode. 108 if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) { 109 return GetValue<T>(primitive->GetAttr(key)); 110 } 111 // graph kernel cnode. 112 auto fg = GetCNodeFuncGraphPtr(node); 113 MS_EXCEPTION_IF_NULL(fg); 114 return GetValue<T>(fg->get_attr(key)); 115 } 116 static bool IsTupleOutput(const AnfNodePtr &anf); 117 // set attr of anf node 118 static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); 119 // set attr of key from 'from' node to 'to' node 120 static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to); 121 // set a new key for attr from 'from' node to 'to' node 122 static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from, 123 const AnfNodePtr &to); 124 // set all attrs from 'from' node to 'to' node 125 static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to); 126 // check whether a cnode has the specified attr. 127 static bool HasNodeAttr(const std::string &key, const CNodePtr &node); 128 // delete attr of anf node 129 static void EraseNodeAttr(const std::string &key, AnfNodePtr node); 130 // get the num of inputs include monads for a cnode 131 static size_t GetInputNum(const CNodePtr &cnode); 132 // get the num of inputs exclude monads for real_kernel (which can be build and run in device) 133 static size_t GetInputTensorNum(const AnfNodePtr &node); 134 // get the num of output real_kernel(which can be build and run in device) 135 static size_t GetOutputTensorNum(const AnfNodePtr &node); 136 // Get the memory size of output tensor of node. 137 static size_t GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index); 138 // get all outputs format select of anf node 139 static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node); 140 // get all inputs format select of anf node 141 static std::vector<std::string> GetAllInputFormats(const AnfNodePtr &node); 142 // get all inputs type select of anf node 143 static std::vector<TypeId> GetAllInputDeviceTypes(const AnfNodePtr &node); 144 // get all outputs type select of anf node 145 static std::vector<TypeId> GetAllOutputDeviceTypes(const AnfNodePtr &node); 146 // get origin data format select of anf node 147 static std::string GetOriginDataFormat(const AnfNodePtr &node); 148 // get output format select of anf node 149 static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx); 150 // get input format select of anf node 151 static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx); 152 // get prev node output width output index 153 static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool visit_nop_node = false); 154 // get output format from prev node,input_index is the input index of current node related to prev node 155 static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); 156 // get reshape_type of from the output of input node. 157 static std::string GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx); 158 // get output shapes inferred by ME from input nodes. 159 static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx); 160 static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, const abstract::BaseShapePtr &base_shape, 161 size_t output_idx); 162 // get input shapes inferred by ME from input nodes. 163 static std::vector<size_t> GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx); 164 // get output shapes which will built and run in device 165 static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); 166 // get input shapes which will built and run in device 167 static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); 168 // get output shapes for tbe build 169 static std::vector<int64_t> GetOutputDeviceShapeForTbeBuild(const AnfNodePtr &node, const size_t output_idx, 170 const std::string &format); 171 // Get Input Padding Axis 172 static std::string GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); 173 // Get Output Padding Axis 174 static std::string GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); 175 // get output data type inferred by ME of anf node 176 static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); 177 static TypeId GetOutputInferDataType(const TypePtr &type_ptr, size_t output_idx); 178 // get output original data type from prev node,input_index is the input index of current node related to prev node 179 static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx); 180 // get output select data type of anf node 181 static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx); 182 // get input select data type of anf node 183 static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx); 184 // get output select data type from prev node,input_index is the input index of current node related to prev node 185 static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx); 186 // get output device addr of anf_node 187 static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); 188 // get mutable output device addr of anf_node 189 static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); 190 // check whether output addr is exist or not 191 static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = false); 192 // check whether workspace addr is exist or not 193 static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx); 194 // get address from prev node,input_index is the input index of current node related to prev node 195 static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, 196 bool visit_nop_node = true); 197 static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, 198 bool visit_nop_node = true); 199 // set output device addr of anf_node 200 static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); 201 // set workspace device addr of anf_node 202 static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); 203 // get workspace device addr of anf_node 204 static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx); 205 // get workspace device mutable addr of anf_node 206 static DeviceAddressPtr GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index); 207 // set infer shapes and types of anf node 208 static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types, 209 const std::vector<std::vector<size_t>> &shapes, AnfNode *node); 210 // get and set output shape ptr 211 static abstract::BaseShapePtr GetOutputDetailShape(const AnfNodePtr &node, size_t output_idx); 212 static abstract::BaseShapePtr GetPrevNodeOutputDetailShape(const AnfNodePtr &node, size_t input_idx); 213 static void SetOutputTypeAndDetailShape(const std::vector<TypeId> &types, 214 const std::vector<abstract::BaseShapePtr> &shapes, AnfNode *node); 215 static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); 216 // get op pattern of the node 217 static kernel::OpPattern GetOpPattern(const AnfNodePtr &node); 218 // get KernelBuildType of node ,such as ATT,RT,FWK and so on 219 static KernelType GetKernelType(const AnfNodePtr &node); 220 // get processor type:AICORE,AICPU... 221 static kernel::Processor GetProcessor(const AnfNodePtr &node); 222 // get fusion type:AICORE,AICPU... 223 static kernel::FusionType GetFusionType(const AnfNodePtr &node); 224 static void SetFusionType(const AnfNodePtr &node, const kernel::FusionType &type); 225 static void SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc); 226 static std::vector<nlohmann::json> GetOutputDataDesc(const AnfNodePtr &node); 227 // set select kernel_build_info 228 static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node); 229 // get select kernel_build_info 230 static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node); 231 // get kernelMode 232 static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node); 233 // set kernel mod 234 static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node); 235 // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too 236 static bool IsRealKernel(const AnfNodePtr &node); 237 // checkout whether the anf node is a real kernel that is a cnode and can run on device 238 static bool IsRealCNodeKernel(const AnfNodePtr &node); 239 // checkout whether the anf node is a graph kernel. 240 static bool IsGraphKernel(const AnfNodePtr &node); 241 // checkout whether the anf node is an inner node of graph kernel. 242 static bool IsNodeInGraphKernel(const AnfNodePtr &node); 243 // get the real output of GraphKernel. 244 static AnfNodePtr GetOutputOfGraphkernel(const KernelWithIndex &kernel_with_index); 245 // check parameter is weight or data 246 static bool IsParameterWeight(const ParameterPtr &node); 247 // checkout whether the anf node is include the label_index. 248 static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index); 249 // set stream id of kernel,which will be set in stream assign and be used in stream generate 250 static void SetStreamId(uint32_t stream_id, AnfNode *node); 251 // get stream id 252 static uint32_t GetStreamId(const AnfNodePtr &node); 253 // set stream distinction label to distinguish different ops in different streams 254 static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node); 255 // get stream distinction label 256 static uint32_t GetStreamDistinctionLabel(const AnfNode *node); 257 // set graph id 258 static void SetGraphId(uint32_t graph_id, AnfNode *node); 259 // get graph id 260 static uint32_t GetGraphId(const AnfNode *node); 261 static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index); 262 // charge if the node's output is a feature map output 263 static bool IsFeatureMapOutput(const AnfNodePtr &node); 264 // charge if the node's input is from a feature map output 265 static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index); 266 // get real input index for some tbe ops which input order is different between me and tbe impl 267 static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); 268 // get me input index for some tbe ops which input order is different between me and tbe impl 269 static size_t GetOriginalInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); 270 static bool IsCommunicationOp(const AnfNodePtr &node); 271 static bool IsFusedCommunicationOp(const AnfNodePtr &node); 272 static bool IsInplaceNode(const AnfNodePtr &node, const string &type); 273 static bool IsGetNext(const NotNull<AnfNodePtr> &node); 274 static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); 275 static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode); 276 static bool IsSwitchCall(const CNodePtr &call_node); 277 static bool IsScalarInput(const CNodePtr &cnode, size_t index); 278 static bool IsScalarOutput(const CNodePtr &cnode, size_t index); 279 static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list); 280 static void ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> *> node_list); 281 // get fix output precision of cnode. 282 static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); 283 // get fix output precision from prev node, input_idx is the input index of current node related to prev node. 284 static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); 285 static bool IsDynamicShape(const AnfNodePtr &node); 286 static bool HasDynamicShapeFlag(const PrimitivePtr &prim); 287 static bool IsCondControlKernel(const CNodePtr &node); 288 static bool IsIndependentNode(const CNodePtr &node); 289 static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr); 290 static void GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape); 291 static std::vector<int64_t> GetInputMaxShape(const AnfNodePtr &anf_node, size_t index); 292 static std::vector<int64_t> GetInputMinShape(const AnfNodePtr &anf_node, size_t index); 293 static std::vector<int64_t> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index); 294 static std::vector<int64_t> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index); 295 static bool IsNodeDynamicShape(const AnfNodePtr &node); 296 static void InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors = nullptr); 297 static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); 298 static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); 299 // Find real input nodes. 300 static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, 301 std::set<AnfNodePtr> *visited); 302 static void GetAllVisitedCNode(const CNodePtr &cnode, std::vector<AnfNodePtr> *used_kernels, 303 std::set<AnfNodePtr> *visited); 304 static void InsertMakeTupleForOutput(const NotNull<KernelGraphPtr> &root_graph); 305 static AnfNodeIndexSet GetUpdateStateUsers(const FuncGraphManagerPtr &manager, const AnfNodePtr &node); 306 // Get node real inputs, skip `MakeTuple`, `TupleGetItem`, `Depend`, `Load`, `UpdateState` etc. 307 static void GetRealInputs(const AnfNodePtr &anf_node, std::vector<session::KernelWithIndex> *inputs); 308 // Check whether tensors need broadcast or not. 309 static bool IsTensorBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs); 310 // Calc tensor size in byte. 311 template <typename T> TensorSizeInByte(const std::vector<size_t> & shape)312 static size_t TensorSizeInByte(const std::vector<size_t> &shape) { 313 size_t result = sizeof(T); 314 for (size_t i = 0; i < shape.size(); i++) { 315 result *= shape[i]; 316 } 317 return result; 318 } 319 static bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set); 320 321 // Judge a control operator need be compiled into kernel graph rather than be cut into single op and 322 // executed in vm. For example, the operator "bprop_cut" will be compiled into kernel graph and be launch 323 // in backend in PyNative mode. 324 static bool IsControlOpExecInBackend(const AnfNodePtr &node); 325 326 static bool IsNodeInputContainMonad(const AnfNodePtr &node); 327 // Save inputs/outputs/workspace address in kernel_mod. 328 static void CacheAddrForGraph(const KernelGraphPtr &kernel_graph); 329 static void CacheAddrForKernel(const AnfNodePtr &node, kernel::KernelMod *kernel_mod); 330 static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod); 331 }; 332 } // namespace session 333 using AnfAlgo = session::AnfRuntimeAlgorithm; 334 } // namespace mindspore 335 #endif // MINDSPORE_CCSRC_BACKEND_SESSION_ANF_RUNTIME_ALGORITHM_H 336