• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
18 #define MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
19 
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "utils/ms_context.h"
26 
27 namespace mindspore {
28 namespace context {
// Graph kernel optimization levels, used by GraphKernelFlags::opt_level.
constexpr unsigned int OptLevel_0 = 0U;  // Disabled
constexpr unsigned int OptLevel_1 = 1U;  // Basic functions
constexpr unsigned int OptLevel_2 = 2U;  // Default functions
constexpr unsigned int OptLevel_3 = 3U;  // Experimental functions
// One past the highest defined optimization level.
constexpr unsigned int OptLevel_MAX = OptLevel_3 + 1U;

// Operator expanding/clustering levels, used by GraphKernelFlags::fusion_ops_level.
constexpr unsigned int OpLevel_0 = 0U;
constexpr unsigned int OpLevel_1 = 1U;
// One past the highest defined operator level.
constexpr unsigned int OpLevel_MAX = OpLevel_1 + 1U;
38 
39 class GraphKernelFlags {
40  public:
GetInstance()41   static const GraphKernelFlags &GetInstance() {
42     static std::unique_ptr<GraphKernelFlags> flags(nullptr);
43     auto contexts = GetGraphKernelContext();
44     if (flags == nullptr || contexts.first != flags->flags_cache_ || contexts.second != flags->enable_graph_kernel_) {
45       flags.reset(new GraphKernelFlags(contexts.first, contexts.second));
46       flags->Refresh();
47     }
48     return *flags;
49   }
50 
51   // Dump all flags to json-format string
52   std::string DumpAllFlags() const;
53 
54   // Check whether graph_kernel is enabled
IsEnableGraphKernel()55   bool IsEnableGraphKernel() const { return opt_level > OptLevel_0; }
56 
57   GraphKernelFlags(const GraphKernelFlags &flags) = delete;
58   ~GraphKernelFlags() = default;
59 
60  public:
61   /**
62    * Dump info as human-readable text.
63    * A directory "graph_kernel_dump" will be created, and all information will be dumped in this directory.
64    */
65   bool dump_as_text{false};
66 
67   /**
68    * Enable stitch fusion in graph kernel fusion strategy.
69    *
70    * Experimental feature, enabled by default when opt_level=3
71    */
72   bool enable_stitch_fusion{false};
73 
74   /**
75    * Enable recompute fusion in graph kernel fusion strategy, enabled when op_level>=2.
76    */
77   bool enable_recompute_fusion{false};
78 
79   /**
80    * Enable parallel fusion in graph kernel fusion strategy.
81    *
82    * Experimental feature, enabled by default when opt_level=3
83    */
84   bool enable_parallel_fusion{false};
85 
86   /**
87    * Enable low precision in data transferring between graph kernel and computing in graph kernel
88    * in graph kernel.
89    * Experimental feature, enabled by the enable_low_precision flag
90    */
91   bool enable_low_precision{false};
92 
93   /**
94    * Expand and cluster AKG's operators by level.
95    */
96   unsigned int fusion_ops_level{OpLevel_0};
97 
98   /**
99    * Enable optimization for transform operators (Transpose/TransData)
100    *
101    * Experimental feature, enabled by default when opt_level=3.
102    */
103   bool enable_trans_op_optimize{false};
104 
105   /**
106    * Optimization level, value from 0 to 3.
107    * 0: Disable GraphKernel
108    * 1: Enable GraphKernel with basic features only.
109    * 2: Enable GraphKernel with all stable features.
110    * 3: Enable GraphKernel with all experimental features.
111    * The default value is OptLevel_2 when the context "enable_graph_kernel" is set,
112    * but if it's also changed in "graph_kernel_flags", then the "graph_kernel_flags" will prevail.
113    */
114   unsigned int opt_level{0};  // defaults 0 or 2
115 
116   /**
117    * Online tuning level, value from 0 to 3.
118    * 0: Disable online tuning
119    * 1-3: The higher level, the larger tuning space, and the more time it takes.
120    */
121   unsigned int online_tuning{0};
122 
123   /**
124    * AKG's operator repository file path.
125    */
126   std::string repository_path;
127 
128   /**
129    * Additional expanding operators (case sensitive).
130    * The operators to be added into the default expanding operator list.
131    */
132   std::vector<std::string> enable_expand_ops;
133 
134   /**
135    * Expanding operators to be enabled (case sensitive).
136    * Unlike the "enable_expand_ops", the default list will be overwritten by this list.
137    * Note that the "enable_expand_ops" and "disable_expand_ops" will be ignored if this flag is set.
138    */
139   std::vector<std::string> enable_expand_ops_only;
140 
141   /**
142    * Expanding operators to be disabled (case sensitive).
143    * The behavior is undefined when this list overlaps with "enable_expand_ops".
144    */
145   std::vector<std::string> disable_expand_ops;
146 
147   /**
148    * Additional clustering operators (case sensitive).
149    * The operators to be added into the default clustering operator list.
150    */
151   std::vector<std::string> enable_cluster_ops;
152 
153   /**
154    * Clustering operators to be enabled (case sensitive).
155    * Unlike the "enable_cluster_ops", the default list will be overwritten by this list.
156    * Note that the "enable_cluster_ops" and "disable_cluster_ops" will be ignored if this flag is set.
157    */
158   std::vector<std::string> enable_cluster_ops_only;
159 
160   /**
161    * Clustering operators to be disabled (case sensitive).
162    * The behavior is undefined when this list overlaps with "enable_cluster_ops".
163    */
164   std::vector<std::string> disable_cluster_ops;
165 
166   /**
167    * Arithmetic simplify expressions to be enabled (case sensitive).
168    * The default list will be overwritten by this list.
169    * Note that "disable_simplify_exprs" will be ignored if this flag is set.
170    */
171   std::vector<std::string> enable_simplify_exprs_only;
172 
173   /**
174    * Arithmetic simplify expressions to be disabled (case sensitive).
175    */
176   std::vector<std::string> disable_simplify_exprs;
177 
178   /**
179    * Passes to be enabled.
180    * By default, the passes is controlled by "opt_level" and target device,
181    * user can manually enable some passes by setting this flag.
182    * The format is "stage_id.pass_id" or "stage_name.pass_name", which corresponds to the ir filename.
183    */
184   std::vector<std::string> enable_pass;
185 
186   /**
187    * Passes to be disabled.
188    * By default, the passes is controlled by "opt_level" and target device,
189    * user can manually disable some passes by setting this flag.
190    * The format is "stage_id.pass_id" or "stage_name.pass_name", which corresponds to the ir filename.
191    */
192   std::vector<std::string> disable_pass;
193 
194  private:
GraphKernelFlags(const std::string & graph_kernel_flags,bool enable_graph_kernel)195   GraphKernelFlags(const std::string &graph_kernel_flags, bool enable_graph_kernel)
196       : flags_cache_(graph_kernel_flags), enable_graph_kernel_(enable_graph_kernel) {}
197 
198   // get the `graph_kernel_flags` and `enable_graph_kernel`
GetGraphKernelContext()199   static std::pair<std::string, bool> GetGraphKernelContext() {
200     auto context = MsContext::GetInstance();
201     MS_EXCEPTION_IF_NULL(context);
202     // Use the environment variable in priority
203     auto env_flags = std::getenv("MS_GRAPH_KERNEL_FLAGS");
204     std::string flags = env_flags ? std::string(env_flags) : context->get_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS);
205     return std::make_pair(flags, context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL));
206   }
207 
208   // parse and refresh the flags
209   void Refresh();
210   // register the flags defined above
211   void RegisterFlags(std::map<std::string, std::string> *flag_map);
212 
213   // cache the flag string to check whether the flags is changed.
214   std::string flags_cache_;
215   // cache the enable_graph_kernel value to check whether the context is changed.
216   bool enable_graph_kernel_;
217 };
218 }  // namespace context
219 }  // namespace mindspore
220 #endif  // MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
221