1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H 18 #define MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 #include <utility> 25 #include "utils/ms_context.h" 26 27 namespace mindspore { 28 namespace context { 29 constexpr unsigned int OptLevel_0 = 0; // Disabled 30 constexpr unsigned int OptLevel_1 = 1; // Basic functions 31 constexpr unsigned int OptLevel_2 = 2; // Default functions 32 constexpr unsigned int OptLevel_3 = 3; // Experimental functions 33 constexpr unsigned int OptLevel_MAX = 4; 34 35 constexpr unsigned int OpLevel_0 = 0; 36 constexpr unsigned int OpLevel_1 = 1; 37 constexpr unsigned int OpLevel_MAX = 2; 38 39 class GraphKernelFlags { 40 public: GetInstance()41 static const GraphKernelFlags &GetInstance() { 42 static std::unique_ptr<GraphKernelFlags> flags(nullptr); 43 auto contexts = GetGraphKernelContext(); 44 if (flags == nullptr || contexts.first != flags->flags_cache_ || contexts.second != flags->enable_graph_kernel_) { 45 flags.reset(new GraphKernelFlags(contexts.first, contexts.second)); 46 flags->Refresh(); 47 } 48 return *flags; 49 } 50 51 // Dump all flags to json-format string 52 std::string DumpAllFlags() const; 53 54 // Check whether graph_kernel is enabled IsEnableGraphKernel()55 bool IsEnableGraphKernel() const { return opt_level > OptLevel_0; } 56 57 GraphKernelFlags(const GraphKernelFlags &flags) = delete; 58 ~GraphKernelFlags() = default; 59 60 public: 61 /** 62 * Dump info as human-readable text. 63 * A directory "graph_kernel_dump" will be created, and all information will be dumped in this directory. 64 */ 65 bool dump_as_text{false}; 66 67 /** 68 * Enable stitch fusion in graph kernel fusion strategy. 69 * 70 * Experimental feature, enabled by default when opt_level=3 71 */ 72 bool enable_stitch_fusion{false}; 73 74 /** 75 * Enable recompute fusion in graph kernel fusion strategy, enabled when op_level>=2. 76 */ 77 bool enable_recompute_fusion{false}; 78 79 /** 80 * Enable parallel fusion in graph kernel fusion strategy. 81 * 82 * Experimental feature, enabled by default when opt_level=3 83 */ 84 bool enable_parallel_fusion{false}; 85 86 /** 87 * Enable low precision in data transferring between graph kernel and computing in graph kernel 88 * in graph kernel. 89 * Experimental feature, enabled by the enable_low_precision flag 90 */ 91 bool enable_low_precision{false}; 92 93 /** 94 * Expand and cluster AKG's operators by level. 95 */ 96 unsigned int fusion_ops_level{OpLevel_0}; 97 98 /** 99 * Enable optimization for transform operators (Transpose/TransData) 100 * 101 * Experimental feature, enabled by default when opt_level=3. 102 */ 103 bool enable_trans_op_optimize{false}; 104 105 /** 106 * Optimization level, value from 0 to 3. 107 * 0: Disable GraphKernel 108 * 1: Enable GraphKernel with basic features only. 109 * 2: Enable GraphKernel with all stable features. 110 * 3: Enable GraphKernel with all experimental features. 111 * The default value is OptLevel_2 when the context "enable_graph_kernel" is set, 112 * but if it's also changed in "graph_kernel_flags", then the "graph_kernel_flags" will prevail. 113 */ 114 unsigned int opt_level{0}; // defaults 0 or 2 115 116 /** 117 * Online tuning level, value from 0 to 3. 118 * 0: Disable online tuning 119 * 1-3: The higher level, the larger tuning space, and the more time it takes. 120 */ 121 unsigned int online_tuning{0}; 122 123 /** 124 * AKG's operator repository file path. 125 */ 126 std::string repository_path; 127 128 /** 129 * Additional expanding operators (case sensitive). 130 * The operators to be added into the default expanding operator list. 131 */ 132 std::vector<std::string> enable_expand_ops; 133 134 /** 135 * Expanding operators to be enabled (case sensitive). 136 * Unlike the "enable_expand_ops", the default list will be overwritten by this list. 137 * Note that the "enable_expand_ops" and "disable_expand_ops" will be ignored if this flag is set. 138 */ 139 std::vector<std::string> enable_expand_ops_only; 140 141 /** 142 * Expanding operators to be disabled (case sensitive). 143 * The behavior is undefined when this list overlaps with "enable_expand_ops". 144 */ 145 std::vector<std::string> disable_expand_ops; 146 147 /** 148 * Additional clustering operators (case sensitive). 149 * The operators to be added into the default clustering operator list. 150 */ 151 std::vector<std::string> enable_cluster_ops; 152 153 /** 154 * Clustering operators to be enabled (case sensitive). 155 * Unlike the "enable_cluster_ops", the default list will be overwritten by this list. 156 * Note that the "enable_cluster_ops" and "disable_cluster_ops" will be ignored if this flag is set. 157 */ 158 std::vector<std::string> enable_cluster_ops_only; 159 160 /** 161 * Clustering operators to be disabled (case sensitive). 162 * The behavior is undefined when this list overlaps with "enable_cluster_ops". 163 */ 164 std::vector<std::string> disable_cluster_ops; 165 166 /** 167 * Arithmetic simplify expressions to be enabled (case sensitive). 168 * The default list will be overwritten by this list. 169 * Note that "disable_simplify_exprs" will be ignored if this flag is set. 170 */ 171 std::vector<std::string> enable_simplify_exprs_only; 172 173 /** 174 * Arithmetic simplify expressions to be disabled (case sensitive). 175 */ 176 std::vector<std::string> disable_simplify_exprs; 177 178 /** 179 * Passes to be enabled. 180 * By default, the passes is controlled by "opt_level" and target device, 181 * user can manually enable some passes by setting this flag. 182 * The format is "stage_id.pass_id" or "stage_name.pass_name", which corresponds to the ir filename. 183 */ 184 std::vector<std::string> enable_pass; 185 186 /** 187 * Passes to be disabled. 188 * By default, the passes is controlled by "opt_level" and target device, 189 * user can manually disable some passes by setting this flag. 190 * The format is "stage_id.pass_id" or "stage_name.pass_name", which corresponds to the ir filename. 191 */ 192 std::vector<std::string> disable_pass; 193 194 private: GraphKernelFlags(const std::string & graph_kernel_flags,bool enable_graph_kernel)195 GraphKernelFlags(const std::string &graph_kernel_flags, bool enable_graph_kernel) 196 : flags_cache_(graph_kernel_flags), enable_graph_kernel_(enable_graph_kernel) {} 197 198 // get the `graph_kernel_flags` and `enable_graph_kernel` GetGraphKernelContext()199 static std::pair<std::string, bool> GetGraphKernelContext() { 200 auto context = MsContext::GetInstance(); 201 MS_EXCEPTION_IF_NULL(context); 202 // Use the environment variable in priority 203 auto env_flags = std::getenv("MS_GRAPH_KERNEL_FLAGS"); 204 std::string flags = env_flags ? std::string(env_flags) : context->get_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS); 205 return std::make_pair(flags, context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL)); 206 } 207 208 // parse and refresh the flags 209 void Refresh(); 210 // register the flags defined above 211 void RegisterFlags(std::map<std::string, std::string> *flag_map); 212 213 // cache the flag string to check whether the flags is changed. 214 std::string flags_cache_; 215 // cache the enable_graph_kernel value to check whether the context is changed. 216 bool enable_graph_kernel_; 217 }; 218 } // namespace context 219 } // namespace mindspore 220 #endif // MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H 221