• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_UTILS_ANF_UTILS_H_
18 #define MINDSPORE_CORE_UTILS_ANF_UTILS_H_
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "ir/anf.h"
#include "ir/dtype.h"
#include "base/base.h"
#include "ir/primitive.h"
#include "utils/hash_map.h"
30 
31 namespace mindspore {
// Dynamic-shape ("DS_") stage names.
// NOTE(review): presumably used as custom actor type names -- confirm at use sites.
constexpr const char *kInfer = "DS_Infer";
constexpr const char *kInit = "DS_Init";
constexpr const char *kUpdate = "DS_Update";
// Other string keys shared across the code base.
constexpr const char *kSkipCheckInputNum = "skip_check_input_num";
constexpr const char *kInputRealTuple = "input_real_tuple";
constexpr const char *kOutputRealTuple = "output_real_tuple";

// Named size values, used instead of magic numbers in size comparisons.
constexpr size_t kSizeZero{0};
constexpr size_t kSizeOne{1};
constexpr size_t kSizeTwo{2};
constexpr size_t kSizeThree{3};
constexpr size_t kSizeFour{4};
constexpr size_t kSizeFive{5};
constexpr size_t kSizeEight{8};

// Named index values, used instead of magic numbers when indexing.
constexpr size_t kIndexZero{0};
constexpr size_t kIndexOne{1};
constexpr size_t kIndexTwo{2};
constexpr size_t kIndexThree{3};
constexpr size_t kIndexFour{4};
constexpr size_t kIndexFive{5};
55 
56 class MS_CORE_API AbstractScope {
57  public:
58   explicit AbstractScope(std::recursive_mutex *mu);
59   AbstractScope(const AbstractScope &other) = delete;
60   AbstractScope operator=(const AbstractScope &other) = delete;
61   AbstractScope(AbstractScope &&other);
62   AbstractScope &operator=(AbstractScope &&other);
63   ~AbstractScope();
64 
65  private:
66   std::recursive_mutex *mu_;
67 };
68 
69 class MS_CORE_API AnfUtils {
70  public:
71   using CustomActorCallback = std::function<void(void *args)>;
72   static bool IsNodeOutputShapeDynamic(const AnfNodePtr &node);
73   // check whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
74   static bool IsRealKernel(const AnfNodePtr &node);
75   // check whether the anf node is a real kernel that is a cnode and can run on device
76   static bool IsRealCNodeKernel(const AnfNodePtr &node);
77   // get kernel name of anf node
78   static std::string GetCNodeName(const AnfNodePtr &node);
79   // get the num of inputs exclude monads for real_kernel (which can be build and run in device)
80   static size_t GetInputTensorNum(const AnfNodePtr &node);
81   // get the num of output real_kernel(which can be build and run in device)
82   static size_t GetOutputTensorNum(const AnfNodePtr &node);
83   // set attr of anf node
84   static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
85   // get int value of value node
86   static int64_t GetIntValue(const AnfNodePtr &anf_node);
87   // get int value of value pointer
88   static int64_t GetIntValue(const ValuePtr &value);
89   // get the node's real kernel recursively
90   static std::pair<AnfNodePtr, size_t> VisitKernel(const AnfNodePtr &anf_node, size_t index);
91   // check whether the node is a GraphKernel node.
92   static bool IsGraphKernel(const AnfNodePtr &node);
93   // check whether the node is a node in GraphKernel's subgraph.
94   static bool IsNodeInGraphKernel(const AnfNodePtr &node);
95   // Set dump flag to CNode's primitive.
96   static void SetDumpFlag(const AnfNodePtr &node);
97   // Get dump flag from CNode's primitive.
98   static bool GetDumpFlag(const AnfNodePtr &node);
99   // Check whether the node has dump flag or not.
100   static bool HasDumpFlag(const AnfNodePtr &node);
101   static AbstractScope GetAbstractLock(const AnfNode *node);
102   static void OpenAbstractLock();
103   static void CloseAbstractLock();
104 
105   // Custom actor node is for dynamic shape.
106   // Generate a Init custom actor node.
107   static AnfNodePtr NewInitActorNode(CustomActorCallback f, const CNodePtr &base_cnode);
108   // Generate a Infer custom actor node.
109   static AnfNodePtr NewInferActorNode(CustomActorCallback f, const CNodePtr &base_cnode);
110   static bool IsCustomActorNode(const AnfNodePtr &node);
111   static std::string GetCustomActorType(const AnfNodePtr &node);
112   static std::string GetCustomActorName(const AnfNodePtr &node);
113   static CNodePtr GetCustomActorBaseNode(const AnfNodePtr &node);
114   static CustomActorCallback GetCustomFunc(const AnfNodePtr &node);
115   static bool IsCutomActorNodeSame(const AnfNodePtr &node1, const AnfNodePtr &node2);
116   // set the inferop,initop to base_node's user_data
117   static void SetCustomInfoToBaseNode(const AnfNodePtr &base_cnode, const AnfNodePtr &inferop,
118                                       const AnfNodePtr &initop);
119   static AnfNodePtr GetCustomInferopNode(const AnfNodePtr &base_cnode);
120   static mindspore::HashMap<size_t, std::pair<AnfNodeWeakPtr, size_t>> &GetRealInputNodes(const CNodePtr &cnode);
121   static std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape);
122   // Judge whether node's monad output should be skipped. Currently this method returns true in one scenarios:
123   // 1. The node is a RpcRecv node with monad type output.
124   static bool NeedJumpMonadOutput(const AnfNodePtr &node);
125 };
126 
127 //
128 // FlatParameterFinder finds flat parameters from parameters.
129 //
130 class MS_CORE_API FlatParameterFinder {
131  public:
132   FlatParameterFinder() = default;
133   ~FlatParameterFinder() = default;
134 
135   // Add a parameter for search.
136   void AddParameter(const ParameterPtr &param);
137 
138   // Add nodes for search, parameter nodes will be added.
139   void AddNodes(const std::vector<AnfNodePtr> &nodes);
140 
141   // Get the flat parameter and data offset for the given parameter.
142   // return (nullptr, 0) when flat parameter not found.
143   std::pair<ParameterPtr, size_t> FindFlatParameter(const ParameterPtr &param);
144 
145   // Get all flat parameters.
146   const std::set<ParameterPtr> &GetFlatParameters();
147 
148  private:
149   struct FlatParamInfo {
150     ParameterPtr flat_param = nullptr;
151     void *chunk = nullptr;
152     size_t offset = 0;
153   };
154 
155   void UpdateFlatParameters();
156 
157   mindspore::HashMap<void *, ParameterPtr> candidate_flat_params_;
158   mindspore::HashMap<ParameterPtr, FlatParamInfo> param_to_flat_param_;
159   std::set<ParameterPtr> flat_params_;
160 };
161 }  // namespace mindspore
162 #endif  // MINDSPORE_CORE_UTILS_ANF_UTILS_H_
163