/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/ms_context.h"
#include <thread>
#include <atomic>
#include <fstream>
#include <algorithm>  // std::find_if in backend_policy()
#include "ir/tensor.h"
#include "utils/ms_utils.h"

namespace mindspore {
std::atomic<bool> thread_1_must_end(false);

std::shared_ptr<MsContext> MsContext::inst_context_ = nullptr;
std::map<std::string, MsBackendPolicy> MsContext::policy_map_ = {{"ge", kMsBackendGePrior},
                                                                 {"vm", kMsBackendVmOnly},
                                                                 {"ms", kMsBackendMsPrior},
                                                                 {"ge_only", kMsBackendGeOnly},
                                                                 {"vm_prior", kMsBackendVmPrior}};

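// Constructor: fills every context parameter with its default value and, when the DEVICE_ID
// environment variable is set, uses it as the default device id.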
MsContext::MsContext(const std::string &policy, const std::string &target) {
#ifndef ENABLE_SECURITY
  set_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG, false);
  set_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH, ".");
#else
  // A default value still needs to be set for these arrays when running in security mode.
  bool_params_[MS_CTX_SAVE_GRAPHS_FLAG - MS_CTX_TYPE_BOOL_BEGIN] = false;
  string_params_[MS_CTX_SAVE_GRAPHS_PATH - MS_CTX_TYPE_STRING_BEGIN] = ".";
#endif
  set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, "python");
  set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, "");
  set_param<bool>(MS_CTX_ENABLE_DUMP, false);
  set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
  set_param<std::string>(MS_CTX_ENV_CONFIG_PATH, "");
  set_param<std::string>(MS_CTX_TUNE_MODE, "NO_TUNE");
  set_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS, "");
  set_param<uint32_t>(MS_CTX_TSD_REF, 0);
  set_param<uint32_t>(MS_CTX_GE_REF, 0);

  set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
  set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
  set_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION, true);
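  // Resolve the default device id from the DEVICE_ID environment variable; fall back to 0
  // when the variable is unset or cannot be parsed as an unsigned integer.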
  auto env_device = common::GetEnv("DEVICE_ID");
  if (!env_device.empty()) {
    try {
      uint32_t device_id = UlongToUint(std::stoul(env_device));
      set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
    } catch (std::invalid_argument &e) {
      MS_LOG(WARNING) << "Invalid DEVICE_ID env:" << env_device << ". Please set DEVICE_ID to 0-7";
      set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
    }
  } else {
    set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
  }

  set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT);
  set_param<std::string>(MS_CTX_DEVICE_TARGET, target);
  set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
  set_param<bool>(MS_CTX_ENABLE_TASK_SINK, true);
  set_param<bool>(MS_CTX_IR_FUSION_FLAG, true);
  set_param<bool>(MS_CTX_ENABLE_HCCL, false);
  set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true);
  set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false);
  set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
  set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false);
  set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true);
  set_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE, "0");
  set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0");
  set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice);
  set_param<bool>(MS_CTX_ENABLE_PROFILING, false);
  set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
  set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
  set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
  set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);
  set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
  set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
  set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
  set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
  set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
  set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
  set_param<bool>(MS_CTX_SAVE_COMPILE_CACHE, false);
  set_param<bool>(MS_CTX_LOAD_COMPILE_CACHE, false);
  set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
  set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, false);
  set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
  set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);

  backend_policy_ = policy_map_[policy];
}

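// Returns the process-wide MsContext singleton. The instance is created lazily through the
// device_type_seter_ callback (if one has been registered), so callers should check the
// returned pointer. Typical usage (a sketch, not taken from this file):
//   auto ctx = MsContext::GetInstance();
//   if (ctx != nullptr) {
//     auto device_id = ctx->get_param<uint32_t>(MS_CTX_DEVICE_ID);
//   }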
std::shared_ptr<MsContext> MsContext::GetInstance() {
  if (inst_context_ == nullptr) {
    MS_LOG(DEBUG) << "Create new mindspore context";
    if (device_type_seter_) {
      device_type_seter_(inst_context_);
    }
  }
  return inst_context_;
}

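// Switches the backend policy by name; returns false (and logs an error) when the name is not
// present in policy_map_.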
bool MsContext::set_backend_policy(const std::string &policy) {
  if (policy_map_.find(policy) == policy_map_.end()) {
    MS_LOG(ERROR) << "Invalid backend policy name: " << policy;
    return false;
  }
  backend_policy_ = policy_map_[policy];
  MS_LOG(INFO) << "Set context backend policy to: " << policy;
  return true;
}

#ifdef ENABLE_TDTQUE
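// Creates an acltdt channel on the current device and launches the tensor-print receiving
// thread through the supplied creator callback.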
void MsContext::CreateTensorPrintThread(const PrintThreadCrt &ctr) {
  uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
  std::string kReceivePrefix = "TF_RECEIVE_";
  std::string channel_name = "_npu_log";
  acl_handle_ = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str());
  if (acl_handle_ == nullptr) {
    MS_LOG(EXCEPTION) << "Get acltdt handle failed";
  }
  MS_LOG(INFO) << "Succeed to create acltdt handle, tsd reference = " << get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
  std::string print_file_path = get_param<std::string>(MS_CTX_PRINT_FILE_PATH);
  acl_tdt_print_ = ctr(print_file_path, acl_handle_);
  TdtHandle::AddHandle(&acl_handle_, &acl_tdt_print_);
}

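// Joins the acl tdt print thread if it is still joinable, swallowing any exception so that
// shutdown can continue.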
static void JoinAclPrintThread(std::thread *thread) {
  try {
    if (thread->joinable()) {
      MS_LOG(INFO) << "join acl tdt host receive process";
      thread->join();
    }
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
  }
}

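// Stops and destroys the acltdt channel and joins the print thread. Also handles the case where
// the handle has already been released elsewhere and only the join is still needed.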
void MsContext::DestroyTensorPrintThread() {
  // If TdtHandle::DestroyHandle is called by the task manager, every acl_handle_ is set to nullptr
  // but the print thread is not joined, so guard that case by joining the thread here.
  if (acl_handle_ == nullptr) {
    MS_LOG(INFO) << "The acl handle has been destroyed and the pointer is nullptr";
    JoinAclPrintThread(&acl_tdt_print_);
    return;
  }
  aclError stopStatus = acltdtStopChannel(acl_handle_);
  if (stopStatus != ACL_SUCCESS) {
    MS_LOG(ERROR) << "Failed to stop acl data channel, the stop status is " << stopStatus << std::endl;
    return;
  }
  MS_LOG(INFO) << "Succeed to stop acl data channel for host queue";
  JoinAclPrintThread(&acl_tdt_print_);
  aclError destroyStatus = acltdtDestroyChannel(acl_handle_);
  if (destroyStatus != ACL_SUCCESS) {
    MS_LOG(ERROR) << "Failed to destroy acl channel, the destroy status is " << destroyStatus << std::endl;
    return;
  }
  TdtHandle::DelHandle(&acl_handle_);
  MS_LOG(INFO) << "Succeed to destroy acl channel";
}

#endif

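// Returns the textual name of the current backend policy, or "unknown" if it does not map back
// to an entry in policy_map_.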
std::string MsContext::backend_policy() const {
  auto res = std::find_if(
    policy_map_.begin(), policy_map_.end(),
    [&, this](const std::pair<std::string, MsBackendPolicy> &item) { return item.second == backend_policy_; });
  if (res != policy_map_.end()) {
    return res->first;
  }
  return "unknown";
}

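// Reports whether IR dumping support was compiled in (ENABLE_DUMP_IR).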
bool MsContext::enable_dump_ir() const {
#ifdef ENABLE_DUMP_IR
  return true;
#else
  return false;
#endif
}

}  // namespace mindspore