/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "utils/ms_context.h"
#include <atomic>
#include <fstream>
#include <stdexcept>
#include <string>
#include <thread>
#include "ir/tensor.h"
#include "utils/ms_utils.h"
23 
24 namespace mindspore {
25 std::atomic<bool> thread_1_must_end(false);
26 
27 std::shared_ptr<MsContext> MsContext::inst_context_ = nullptr;
28 std::map<std::string, MsBackendPolicy> MsContext::policy_map_ = {{"ge", kMsBackendGePrior},
29                                                                  {"vm", kMsBackendVmOnly},
30                                                                  {"ms", kMsBackendMsPrior},
31                                                                  {"ge_only", kMsBackendGeOnly},
32                                                                  {"vm_prior", kMsBackendVmPrior}};
33 
MsContext(const std::string & policy,const std::string & target)34 MsContext::MsContext(const std::string &policy, const std::string &target) {
35 #ifndef ENABLE_SECURITY
36   set_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG, false);
37   set_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH, ".");
38 #else
39   // Need set a default value for arrays even if running in the security mode.
40   bool_params_[MS_CTX_SAVE_GRAPHS_FLAG - MS_CTX_TYPE_BOOL_BEGIN] = false;
41   string_params_[MS_CTX_SAVE_GRAPHS_PATH - MS_CTX_TYPE_STRING_BEGIN] = ".";
42 #endif
43   set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, "python");
44   set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, "");
45   set_param<bool>(MS_CTX_ENABLE_DUMP, false);
46   set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
47   set_param<std::string>(MS_CTX_ENV_CONFIG_PATH, "");
48   set_param<std::string>(MS_CTX_TUNE_MODE, "NO_TUNE");
49   set_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS, "");
50   set_param<uint32_t>(MS_CTX_TSD_REF, 0);
51   set_param<uint32_t>(MS_CTX_GE_REF, 0);
52 
53   set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
54   set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
55   set_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION, true);
56   auto env_device = common::GetEnv("DEVICE_ID");
57   if (!env_device.empty()) {
58     try {
59       uint32_t device_id = UlongToUint(std::stoul(env_device));
60       set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
61     } catch (std::invalid_argument &e) {
62       MS_LOG(WARNING) << "Invalid DEVICE_ID env:" << env_device << ". Please set DEVICE_ID to 0-7";
63       set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
64     }
65   } else {
66     set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
67   }
68 
69   set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT);
70   set_param<std::string>(MS_CTX_DEVICE_TARGET, target);
71   set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
72   set_param<bool>(MS_CTX_ENABLE_TASK_SINK, true);
73   set_param<bool>(MS_CTX_IR_FUSION_FLAG, true);
74   set_param<bool>(MS_CTX_ENABLE_HCCL, false);
75   set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true);
76   set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false);
77   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
78   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false);
79   set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true);
80   set_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE, "0");
81   set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0");
82   set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice);
83   set_param<bool>(MS_CTX_ENABLE_PROFILING, false);
84   set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, false);
85   set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
86   set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
87   set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);
88   set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
89   set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
90   set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
91   set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
92   set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
93   set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
94   set_param<bool>(MS_CTX_SAVE_COMPILE_CACHE, false);
95   set_param<bool>(MS_CTX_LOAD_COMPILE_CACHE, false);
96   set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
97   set_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT, false);
98   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
99   set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
100 
101   backend_policy_ = policy_map_[policy];
102 }
103 
GetInstance()104 std::shared_ptr<MsContext> MsContext::GetInstance() {
105   if (inst_context_ == nullptr) {
106     MS_LOG(DEBUG) << "Create new mindspore context";
107     if (device_type_seter_) {
108       device_type_seter_(inst_context_);
109     }
110   }
111   return inst_context_;
112 }
113 
set_backend_policy(const std::string & policy)114 bool MsContext::set_backend_policy(const std::string &policy) {
115   if (policy_map_.find(policy) == policy_map_.end()) {
116     MS_LOG(ERROR) << "invalid backend policy name: " << policy;
117     return false;
118   }
119   backend_policy_ = policy_map_[policy];
120   MS_LOG(INFO) << "ms set context backend policy:" << policy;
121   return true;
122 }
123 
124 #ifdef ENABLE_TDTQUE
CreateTensorPrintThread(const PrintThreadCrt & ctr)125 void MsContext::CreateTensorPrintThread(const PrintThreadCrt &ctr) {
126   uint32_t device_id = get_param<uint32_t>(MS_CTX_DEVICE_ID);
127   std::string kReceivePrefix = "TF_RECEIVE_";
128   std::string channel_name = "_npu_log";
129   acl_handle_ = acltdtCreateChannel(device_id, (kReceivePrefix + channel_name).c_str());
130   if (acl_handle_ == nullptr) {
131     MS_LOG(EXCEPTION) << "Get acltdt handle failed";
132   }
133   MS_LOG(INFO) << "Success to create acltdt handle, tsd reference = " << get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
134   std::string print_file_path = get_param<std::string>(MS_CTX_PRINT_FILE_PATH);
135   acl_tdt_print_ = ctr(print_file_path, acl_handle_);
136   TdtHandle::AddHandle(&acl_handle_, &acl_tdt_print_);
137 }
138 
JoinAclPrintThread(std::thread * thread)139 static void JoinAclPrintThread(std::thread *thread) {
140   try {
141     if (thread->joinable()) {
142       MS_LOG(INFO) << "join acl tdt host receive process";
143       thread->join();
144     }
145   } catch (const std::exception &e) {
146     MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
147   }
148 }
149 
DestroyTensorPrintThread()150 void MsContext::DestroyTensorPrintThread() {
151   // if TdtHandle::DestroyHandle called at taskmanger, all acl_handle_ will be set to nullptr;
152   // but not joined the print thread, so add a protect to join the thread.
153   if (acl_handle_ == nullptr) {
154     MS_LOG(INFO) << "The acl handle has been destroyed and the point is nullptr";
155     JoinAclPrintThread(&acl_tdt_print_);
156     return;
157   }
158   aclError stopStatus = acltdtStopChannel(acl_handle_);
159   if (stopStatus != ACL_SUCCESS) {
160     MS_LOG(ERROR) << "Failed stop acl data channel and the stopStatus is " << stopStatus << std::endl;
161     return;
162   }
163   MS_LOG(INFO) << "Succeed stop acl data channel for host queue ";
164   JoinAclPrintThread(&acl_tdt_print_);
165   aclError destroydStatus = acltdtDestroyChannel(acl_handle_);
166   if (destroydStatus != ACL_SUCCESS) {
167     MS_LOG(ERROR) << "Failed destroy acl channel and the destroyStatus is " << destroydStatus << std::endl;
168     return;
169   }
170   TdtHandle::DelHandle(&acl_handle_);
171   MS_LOG(INFO) << "Succeed destroy acl channel";
172 }
173 
174 #endif
175 
backend_policy() const176 std::string MsContext::backend_policy() const {
177   auto res = std::find_if(
178     policy_map_.begin(), policy_map_.end(),
179     [&, this](const std::pair<std::string, MsBackendPolicy> &item) { return item.second == backend_policy_; });
180   if (res != policy_map_.end()) {
181     return res->first;
182   }
183   return "unknown";
184 }
185 
enable_dump_ir() const186 bool MsContext::enable_dump_ir() const {
187 #ifdef ENABLE_DUMP_IR
188   return true;
189 #else
190   return false;
191 #endif
192 }
193 
194 }  // namespace mindspore
195