1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ 18 #define MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ 19 #include <thread> 20 #include <memory> 21 #include <map> 22 #include <set> 23 #include <vector> 24 #include <string> 25 #include <utility> 26 #include <functional> 27 #include "utils/log_adapter.h" 28 #include "utils/ms_utils.h" 29 #ifdef ENABLE_TDTQUE 30 #include "pybind11/pybind11.h" 31 #include "mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_handle.h" 32 using mindspore::dataset::TdtHandle; 33 #endif 34 #ifndef NO_DLIB 35 #include "acl/acl_tdt.h" 36 #endif 37 38 namespace mindspore { 39 enum MsBackendPolicy { 40 kMsBackendGeOnly = 0, 41 kMsBackendVmOnly = 1, 42 kMsBackendGePrior = 2, 43 kMsBackendVmPrior = 3, 44 kMsBackendMsPrior = 4, 45 kMsBackendUnknown = 5, 46 }; 47 48 const int kGraphMode = 0; 49 const int kPynativeMode = 1; 50 const char kCPUDevice[] = "CPU"; 51 const char kGPUDevice[] = "GPU"; 52 const char kAscendDevice[] = "Ascend"; 53 const char kDavinciInferenceDevice[] = "AscendInference"; 54 const char kDavinciMultiGraphInferenceDevice[] = "AscendMultiGraphInference"; 55 const char kGpuInferenceDevice[] = "GpuInference"; 56 const char kDavinciDevice[] = "Davinci"; 57 const char KNpuLog[] = "_npu_log"; 58 const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000; 59 60 const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; 61 // The default max available device memory is 1024GB. 62 const float kDefaultMaxDeviceMemory = 1024; 63 64 // enum definition for MindSpore Context Parameter 65 enum MsCtxParam : unsigned { 66 // parameter of type bool 67 MS_CTX_TYPE_BOOL_BEGIN, 68 MS_CTX_CHECK_BPROP_FLAG = MS_CTX_TYPE_BOOL_BEGIN, 69 MS_CTX_ENABLE_DUMP, 70 MS_CTX_ENABLE_DYNAMIC_MEM_POOL, 71 MS_CTX_ENABLE_GPU_SUMMARY, 72 MS_CTX_ENABLE_GRAPH_KERNEL, 73 MS_CTX_ENABLE_HCCL, 74 MS_CTX_ENABLE_LOOP_SINK, 75 MS_CTX_ENABLE_MEM_SCHEDULER, 76 MS_CTX_ENABLE_PYNATIVE_HOOK, 77 MS_CTX_ENABLE_PYNATIVE_INFER, 78 MS_CTX_ENABLE_REDUCE_PRECISION, 79 MS_CTX_ENABLE_SPARSE, 80 MS_CTX_ENABLE_TASK_SINK, 81 MS_CTX_IR_FUSION_FLAG, 82 MS_CTX_IS_MULTI_GRAPH_SINK, 83 MS_CTX_IS_PYNATIVE_GE_INIT, 84 MS_CTX_PRECOMPILE_ONLY, 85 MS_CTX_ENABLE_PROFILING, 86 MS_CTX_SAVE_GRAPHS_FLAG, 87 MS_CTX_ENABLE_PARALLEL_SPLIT, 88 MS_CTX_ENABLE_INFER_OPT, 89 MS_CTX_GRAD_FOR_SCALAR, 90 MS_CTX_SAVE_COMPILE_CACHE, 91 MS_CTX_LOAD_COMPILE_CACHE, 92 MS_CTX_ENABLE_MINDRT, 93 MS_CTX_ALREADY_SET_ENABLE_MINDRT, 94 MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, 95 MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, 96 MS_CTX_TYPE_BOOL_END, 97 98 // parameter of type int 99 MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END, 100 MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN, 101 MS_CTX_TYPE_INT_END, 102 103 // parameter of type uint32 104 MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END, 105 MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN, 106 MS_CTX_GE_REF, 107 MS_CTX_MAX_CALL_DEPTH, 108 MS_CTX_TSD_REF, 109 MS_CTX_TYPE_UINT32_END, 110 111 // parameter of type float 112 MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END, 113 MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN, 114 MS_CTX_TYPE_FLOAT_END, 115 116 // parameter of type string 117 MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END, 118 MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN, 119 MS_CTX_GRAPH_MEMORY_MAX_SIZE, 120 MS_CTX_PRINT_FILE_PATH, 121 MS_CTX_PROFILING_OPTIONS, 122 MS_CTX_SAVE_DUMP_PATH, 123 MS_CTX_SAVE_GRAPHS_PATH, 124 MS_CTX_VARIABLE_MEMORY_MAX_SIZE, 125 MS_CTX_PYTHON_EXE_PATH, 126 MS_CTX_KERNEL_BUILD_SERVER_DIR, 127 MS_CTX_ENV_CONFIG_PATH, 128 MS_CTX_TUNE_MODE, 129 MS_CTX_GRAPH_KERNEL_FLAGS, 130 MS_CTX_INFER_PRECISION_MODE, // GPU inference precision mode configured by Serving or Unify API. 131 MS_CTX_TYPE_STRING_END, 132 133 // parameter numbers of each type 134 NUM_BOOL_PARAMS = MS_CTX_TYPE_BOOL_END - MS_CTX_TYPE_BOOL_BEGIN, 135 NUM_INT_PARAMS = MS_CTX_TYPE_INT_END - MS_CTX_TYPE_INT_BEGIN, 136 NUM_UINT32_PARAMS = MS_CTX_TYPE_UINT32_END - MS_CTX_TYPE_UINT32_BEGIN, 137 NUM_FLOAT_PARAMS = MS_CTX_TYPE_FLOAT_END - MS_CTX_TYPE_FLOAT_BEGIN, 138 NUM_STRING_PARAMS = MS_CTX_TYPE_STRING_END - MS_CTX_TYPE_STRING_BEGIN 139 }; 140 141 class MsContext { 142 public: 143 MsContext(const std::string &backend_policy, const std::string &target); 144 ~MsContext() = default; 145 MsContext(const MsContext &) = delete; 146 MsContext &operator=(const MsContext &) = delete; 147 using DeviceSeter = std::function<void(const std::string &device_target)>; 148 using DeviceTypeSeter = std::function<void(std::shared_ptr<MsContext> &)>; 149 static std::shared_ptr<MsContext> GetInstance(); 150 151 bool enable_dump_ir() const; 152 std::string backend_policy() const; 153 bool set_backend_policy(const std::string &policy); 154 #ifdef ENABLE_TDTQUE 155 using PrintThreadCrt = std::function<std::thread(std::string &, acltdtChannelHandle *)>; 156 void CreateTensorPrintThread(const PrintThreadCrt &ctr); 157 void DestroyTensorPrintThread(); 158 #endif device_seter(DeviceSeter device)159 static void device_seter(DeviceSeter device) { seter_ = device; } device_type_seter(DeviceTypeSeter device_type)160 static void device_type_seter(DeviceTypeSeter device_type) { device_type_seter_ = device_type; } 161 162 template <typename T> set_param(MsCtxParam param,const T & value)163 void set_param(MsCtxParam param, const T &value) { 164 MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; 165 } 166 167 template <typename T> get_param(MsCtxParam param)168 const T &get_param(MsCtxParam param) const { 169 MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; 170 } 171 172 template <typename T> increase_param(MsCtxParam param)173 void increase_param(MsCtxParam param) { 174 MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; 175 } 176 177 template <typename T> decrease_param(MsCtxParam param)178 void decrease_param(MsCtxParam param) { 179 MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; 180 } 181 182 private: 183 inline static DeviceSeter seter_ = nullptr; 184 inline static DeviceTypeSeter device_type_seter_ = nullptr; 185 static std::shared_ptr<MsContext> inst_context_; 186 static std::map<std::string, MsBackendPolicy> policy_map_; 187 188 bool bool_params_[MsCtxParam::NUM_BOOL_PARAMS]; 189 int int_params_[MsCtxParam::NUM_INT_PARAMS]; 190 uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS]; 191 float float_params_[MsCtxParam::NUM_FLOAT_PARAMS]; 192 std::string string_params_[MsCtxParam::NUM_STRING_PARAMS]; 193 MsBackendPolicy backend_policy_; 194 #ifdef ENABLE_TDTQUE 195 acltdtChannelHandle *acl_handle_ = nullptr; 196 std::thread acl_tdt_print_; 197 #endif 198 }; 199 200 // set method implementation for type bool/int/uint32_t/float/std::string 201 template <> 202 inline void MsContext::set_param<bool>(MsCtxParam param, const bool &value) { 203 #ifdef ENABLE_SECURITY 204 if (param == MS_CTX_SAVE_GRAPHS_FLAG) { 205 MS_EXCEPTION(ValueError) << "The save_graphs is not supported, please without '-s on' and recompile source."; 206 } 207 #endif 208 bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN] = value; 209 } 210 211 template <> 212 inline void MsContext::set_param<int>(MsCtxParam param, const int &value) { 213 int_params_[param - MS_CTX_TYPE_INT_BEGIN] = value; 214 } 215 216 template <> 217 inline void MsContext::set_param<uint32_t>(MsCtxParam param, const uint32_t &value) { 218 uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN] = value; 219 } 220 221 template <> 222 inline void MsContext::set_param<float>(MsCtxParam param, const float &value) { 223 float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN] = value; 224 } 225 226 template <> 227 inline void MsContext::set_param<std::string>(MsCtxParam param, const std::string &value) { 228 #ifdef ENABLE_SECURITY 229 if (param == MS_CTX_SAVE_GRAPHS_PATH) { 230 MS_EXCEPTION(ValueError) << "The save_graphs is not supported, please without '-s on' and recompile source."; 231 } 232 #endif 233 if (seter_ != nullptr && param == MS_CTX_DEVICE_TARGET) { 234 MS_LOG(INFO) << "ms set context device target:" << value; 235 seter_(value); 236 } 237 string_params_[param - MS_CTX_TYPE_STRING_BEGIN] = value; 238 } 239 240 // get method implementation for type bool/int/uint32_t/float/std::string 241 template <> 242 inline const bool &MsContext::get_param<bool>(MsCtxParam param) const { 243 return bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN]; 244 } 245 246 template <> 247 inline const int &MsContext::get_param<int>(MsCtxParam param) const { 248 return int_params_[param - MS_CTX_TYPE_INT_BEGIN]; 249 } 250 251 template <> 252 inline const uint32_t &MsContext::get_param<uint32_t>(MsCtxParam param) const { 253 return uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]; 254 } 255 256 template <> 257 inline const float &MsContext::get_param<float>(MsCtxParam param) const { 258 return float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN]; 259 } 260 261 template <> 262 inline const std::string &MsContext::get_param<std::string>(MsCtxParam param) const { 263 return string_params_[param - MS_CTX_TYPE_STRING_BEGIN]; 264 } 265 266 // increate method implementation for type uint32_t 267 template <> 268 inline void MsContext::increase_param<uint32_t>(MsCtxParam param) { 269 uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]++; 270 } 271 272 // decreate method implementation for type uint32_t 273 template <> 274 inline void MsContext::decrease_param<uint32_t>(MsCtxParam param) { 275 uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]--; 276 } 277 } // namespace mindspore 278 279 #endif // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ 280