/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_UTILS_ANF_UTILS_H_
#define MINDSPORE_CORE_UTILS_ANF_UTILS_H_
#include <functional>
#include <set>
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "base/base.h"
#include "ir/primitive.h"
#include "utils/hash_map.h"

namespace mindspore {
constexpr auto kInfer = "DS_Infer";
constexpr auto kInit = "DS_Init";
constexpr auto kUpdate = "DS_Update";
constexpr auto kSkipCheckInputNum = "skip_check_input_num";
constexpr auto kInputRealTuple = "input_real_tuple";
constexpr auto kOutputRealTuple = "output_real_tuple";

// Constants for size numbers.
constexpr size_t kSizeZero = 0;
constexpr size_t kSizeOne = 1;
constexpr size_t kSizeTwo = 2;
constexpr size_t kSizeThree = 3;
constexpr size_t kSizeFour = 4;
constexpr size_t kSizeFive = 5;
constexpr size_t kSizeEight = 8;

// Constants for index numbers.
constexpr size_t kIndexZero = 0;
constexpr size_t kIndexOne = 1;
constexpr size_t kIndexTwo = 2;
constexpr size_t kIndexThree = 3;
constexpr size_t kIndexFour = 4;
constexpr size_t kIndexFive = 5;

class MS_CORE_API AbstractScope {
 public:
  explicit AbstractScope(std::recursive_mutex *mu);
  AbstractScope(const AbstractScope &other) = delete;
  AbstractScope operator=(const AbstractScope &other) = delete;
  AbstractScope(AbstractScope &&other);
  AbstractScope &operator=(AbstractScope &&other);
  ~AbstractScope();

 private:
  std::recursive_mutex *mu_;
};

class MS_CORE_API AnfUtils {
 public:
  using CustomActorCallback = std::function<void(void *args)>;
  static bool IsNodeOutputShapeDynamic(const AnfNodePtr &node);
  // Check whether the anf node is a real kernel that can run on device; parameters and constants are real kernels too.
  static bool IsRealKernel(const AnfNodePtr &node);
  // Check whether the anf node is a real kernel that is a cnode and can run on device.
  static bool IsRealCNodeKernel(const AnfNodePtr &node);
  // Get the kernel name of the anf node.
  static std::string GetCNodeName(const AnfNodePtr &node);
  // Get the number of inputs, excluding monads, of a real kernel (which can be built and run on device).
  static size_t GetInputTensorNum(const AnfNodePtr &node);
  // Get the number of outputs of a real kernel (which can be built and run on device).
  static size_t GetOutputTensorNum(const AnfNodePtr &node);
  // Set an attribute of the anf node.
  static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
  // Get the int value of a value node.
  static int64_t GetIntValue(const AnfNodePtr &anf_node);
  // Get the int value of a value pointer.
  static int64_t GetIntValue(const ValuePtr &value);
  // Get the node's real kernel recursively.
  static std::pair<AnfNodePtr, size_t> VisitKernel(const AnfNodePtr &anf_node, size_t index);
  // Check whether the node is a GraphKernel node.
  static bool IsGraphKernel(const AnfNodePtr &node);
  // Check whether the node is a node in GraphKernel's subgraph.
  static bool IsNodeInGraphKernel(const AnfNodePtr &node);
  // Set the dump flag to the CNode's primitive.
  static void SetDumpFlag(const AnfNodePtr &node);
  // Get the dump flag from the CNode's primitive.
  static bool GetDumpFlag(const AnfNodePtr &node);
  // Check whether the node has the dump flag or not.
  static bool HasDumpFlag(const AnfNodePtr &node);
  static AbstractScope GetAbstractLock(const AnfNode *node);
  static void OpenAbstractLock();
  static void CloseAbstractLock();

  // Custom actor nodes are for dynamic shape.
  // Generate an Init custom actor node.
  static AnfNodePtr NewInitActorNode(CustomActorCallback f, const CNodePtr &base_cnode);
  // Generate an Infer custom actor node.
  static AnfNodePtr NewInferActorNode(CustomActorCallback f, const CNodePtr &base_cnode);
  static bool IsCustomActorNode(const AnfNodePtr &node);
  static std::string GetCustomActorType(const AnfNodePtr &node);
  static std::string GetCustomActorName(const AnfNodePtr &node);
  static CNodePtr GetCustomActorBaseNode(const AnfNodePtr &node);
  static CustomActorCallback GetCustomFunc(const AnfNodePtr &node);
  static bool IsCutomActorNodeSame(const AnfNodePtr &node1, const AnfNodePtr &node2);
  // Set the infer op and init op to the base node's user_data.
  static void SetCustomInfoToBaseNode(const AnfNodePtr &base_cnode, const AnfNodePtr &inferop,
                                      const AnfNodePtr &initop);
  static AnfNodePtr GetCustomInferopNode(const AnfNodePtr &base_cnode);
  static mindspore::HashMap<size_t, std::pair<AnfNodeWeakPtr, size_t>> &GetRealInputNodes(const CNodePtr &cnode);
  static std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape);
  // Judge whether the node's monad output should be skipped. Currently this method returns true in one scenario:
  // 1. The node is a RpcRecv node with a monad type output.
  static bool NeedJumpMonadOutput(const AnfNodePtr &node);
};

//
// FlatParameterFinder finds flat parameters from parameters.
//
class MS_CORE_API FlatParameterFinder {
 public:
  FlatParameterFinder() = default;
  ~FlatParameterFinder() = default;

  // Add a parameter for search.
  void AddParameter(const ParameterPtr &param);

  // Add nodes for search; only parameter nodes will be added.
  void AddNodes(const std::vector<AnfNodePtr> &nodes);

  // Get the flat parameter and data offset for the given parameter.
  // Returns (nullptr, 0) when the flat parameter is not found.
  std::pair<ParameterPtr, size_t> FindFlatParameter(const ParameterPtr &param);

  // Get all flat parameters.
  const std::set<ParameterPtr> &GetFlatParameters();

 private:
  struct FlatParamInfo {
    ParameterPtr flat_param = nullptr;
    void *chunk = nullptr;
    size_t offset = 0;
  };

  void UpdateFlatParameters();

  mindspore::HashMap<void *, ParameterPtr> candidate_flat_params_;
  mindspore::HashMap<ParameterPtr, FlatParamInfo> param_to_flat_param_;
  std::set<ParameterPtr> flat_params_;
};
}  // namespace mindspore
#endif  // MINDSPORE_CORE_UTILS_ANF_UTILS_H_
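
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; kept entirely inside a comment so it does
// not affect compilation of this header). The helper functions `InspectNode`
// and `FindFlat` below are hypothetical, and it is assumed the caller already
// has nodes from an existing FuncGraph; graph construction is out of scope.
//
//   #include "utils/anf_utils.h"
//
//   void InspectNode(const mindspore::AnfNodePtr &node) {
//     using mindspore::AnfUtils;
//     // Only real CNode kernels have a meaningful kernel name and I/O counts.
//     if (AnfUtils::IsRealCNodeKernel(node)) {
//       std::string name = AnfUtils::GetCNodeName(node);
//       size_t input_num = AnfUtils::GetInputTensorNum(node);   // monads excluded
//       size_t output_num = AnfUtils::GetOutputTensorNum(node);
//       // Follow wrapper nodes recursively to the real kernel producing output 0.
//       std::pair<mindspore::AnfNodePtr, size_t> real = AnfUtils::VisitKernel(node, 0);
//     }
//   }
//
//   void FindFlat(const std::vector<mindspore::AnfNodePtr> &nodes,
//                 const mindspore::ParameterPtr &param) {
//     mindspore::FlatParameterFinder finder;
//     finder.AddNodes(nodes);  // only parameter nodes are collected for search
//     auto result = finder.FindFlatParameter(param);
//     // result.first is nullptr when no flat parameter contains `param`.
//   }
// ---------------------------------------------------------------------------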