1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "src/common/context_util.h"
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
25
26 namespace mindspore {
27 namespace lite {
28 namespace {
29 template <class T>
PassBasicProperties(std::shared_ptr<T> device_info,const lite::DeviceContext & device_context)30 void PassBasicProperties(std::shared_ptr<T> device_info, const lite::DeviceContext &device_context) {
31 MS_ASSERT(device_info != nullptr);
32 device_info->SetProvider(device_context.provider_);
33 device_info->SetProviderDevice(device_context.provider_device_);
34 device_info->SetAllocator(device_context.allocator_);
35 }
36
CPUDeviceInfoFromCPUDeviceContext(const lite::DeviceContext & cpu_context)37 std::shared_ptr<mindspore::CPUDeviceInfo> CPUDeviceInfoFromCPUDeviceContext(const lite::DeviceContext &cpu_context) {
38 if (cpu_context.device_type_ != DT_CPU) {
39 MS_LOG(ERROR) << "function input parameter is not cpu context.";
40 return nullptr;
41 }
42 auto cpu_info = std::make_shared<mindspore::CPUDeviceInfo>();
43 MS_CHECK_TRUE_RET(cpu_info != nullptr, nullptr);
44 cpu_info->SetEnableFP16(cpu_context.device_info_.cpu_device_info_.enable_float16_);
45 PassBasicProperties(cpu_info, cpu_context);
46 return cpu_info;
47 }
48
GPUDeviceInfoFromGPUDeviceContext(const lite::DeviceContext & gpu_context)49 std::shared_ptr<mindspore::GPUDeviceInfo> GPUDeviceInfoFromGPUDeviceContext(const lite::DeviceContext &gpu_context) {
50 if (gpu_context.device_type_ != DT_GPU) {
51 MS_LOG(ERROR) << "function input parameter is not gpu context.";
52 return nullptr;
53 }
54 auto gpu_info = std::make_shared<mindspore::GPUDeviceInfo>();
55 MS_CHECK_TRUE_RET(gpu_info != nullptr, nullptr);
56 gpu_info->SetEnableFP16(gpu_context.device_info_.gpu_device_info_.enable_float16_);
57 gpu_info->SetDeviceID(gpu_context.device_info_.gpu_device_info_.gpu_device_id_);
58 PassBasicProperties(gpu_info, gpu_context);
59 return gpu_info;
60 }
61
NPUDeviceInfoFromNPUDeviceContext(const lite::DeviceContext & npu_context)62 std::shared_ptr<mindspore::KirinNPUDeviceInfo> NPUDeviceInfoFromNPUDeviceContext(
63 const lite::DeviceContext &npu_context) {
64 if (npu_context.device_type_ != DT_NPU) {
65 MS_LOG(ERROR) << "function input parameter is not npu context.";
66 return nullptr;
67 }
68 auto npu_info = std::make_shared<mindspore::KirinNPUDeviceInfo>();
69 MS_CHECK_TRUE_RET(npu_info != nullptr, nullptr);
70 npu_info->SetEnableFP16(npu_context.device_info_.npu_device_info_.enable_float16_);
71 npu_info->SetFrequency(npu_context.device_info_.npu_device_info_.frequency_);
72 PassBasicProperties(npu_info, npu_context);
73 return npu_info;
74 }
75
GetBatchSize(const std::string & batch_size)76 std::vector<size_t> GetBatchSize(const std::string &batch_size) {
77 std::vector<size_t> res;
78 std::vector<std::string> batch_size_vec = StrSplit(batch_size, ",");
79 for (const auto &item : batch_size_vec) {
80 int32_t val;
81 if (ConvertStrToInt(item, &val)) {
82 auto tmp_val = static_cast<size_t>(val);
83 res.push_back(tmp_val);
84 } else {
85 MS_LOG(ERROR) << "Convert str to num failed, val = " << item;
86 return res;
87 }
88 }
89 MS_LOG(INFO) << "Batch size of context: " << batch_size;
90 return res;
91 }
92
AscendDeviceInfoFromAscendDeviceContext(const lite::DeviceContext & ascend_context)93 std::shared_ptr<mindspore::AscendDeviceInfo> AscendDeviceInfoFromAscendDeviceContext(
94 const lite::DeviceContext &ascend_context) {
95 if (ascend_context.device_type_ != DT_ASCEND) {
96 MS_LOG(ERROR) << "Function input parameter is not ascend context.";
97 return nullptr;
98 }
99 auto ascend_info = std::make_shared<mindspore::AscendDeviceInfo>();
100 MS_CHECK_TRUE_RET(ascend_info != nullptr, nullptr);
101 ascend_info->SetDeviceID(ascend_context.device_info_.ascend_device_info_.device_id_);
102 std::string batch_size = ascend_context.device_info_.ascend_device_info_.batch_size_;
103 if (!batch_size.empty()) {
104 auto val = GetBatchSize(batch_size);
105 ascend_info->SetDynamicBatchSize(val);
106 }
107 ascend_info->SetDynamicImageSize(ascend_context.device_info_.ascend_device_info_.image_size_);
108 return ascend_info;
109 }
110
CustomDeviceInfoFromCustomDeviceContext(const lite::DeviceContext & inner_context)111 std::shared_ptr<mindspore::DeviceInfoContext> CustomDeviceInfoFromCustomDeviceContext(
112 const lite::DeviceContext &inner_context) {
113 if (inner_context.device_type_ != DT_CUSTOM) {
114 MS_LOG(ERROR) << "Function input parameter is not extended context.";
115 return nullptr;
116 }
117 auto device_info = inner_context.device_info_.custom_device_info_.user_defined_device_info_;
118 MS_CHECK_TRUE_RET(device_info != nullptr, nullptr);
119 return device_info;
120 }
121
NNRtDeviceInfoFromNNRtDeviceContext(const lite::DeviceContext & nnrt_context)122 std::shared_ptr<mindspore::NNRTDeviceInfo> NNRtDeviceInfoFromNNRtDeviceContext(
123 const lite::DeviceContext &nnrt_context) {
124 if (nnrt_context.device_type_ != DT_NNRT) {
125 MS_LOG(ERROR) << "Function input parameter is not NNRt context.";
126 return nullptr;
127 }
128 auto nnrt_info = std::make_shared<mindspore::NNRTDeviceInfo>();
129 MS_CHECK_TRUE_RET(nnrt_info != nullptr, nullptr);
130 return nnrt_info;
131 }
132 } // namespace
133
MSContextFromContext(const std::shared_ptr<InnerContext> & context)134 mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &context) {
135 if (context == nullptr) {
136 MS_LOG(ERROR) << "context is nullptr";
137 return nullptr;
138 }
139 auto ms_context = new (std::nothrow) mindspore::Context();
140 if (ms_context == nullptr) {
141 MS_LOG(ERROR) << "New Context failed";
142 return nullptr;
143 }
144 ms_context->SetThreadNum(context->thread_num_);
145 ms_context->SetThreadAffinity(context->affinity_core_list_);
146 #ifndef ENABLE_CLOUD_FUSION_INFERENCE
147 ms_context->SetEnableParallel(context->enable_parallel_);
148 #endif
149 if (context->delegate) {
150 ms_context->SetDelegate(context->delegate);
151 }
152 auto &device_infos = ms_context->MutableDeviceInfo();
153 std::map<DeviceType, std::function<std::shared_ptr<mindspore::DeviceInfoContext>(const lite::DeviceContext &)>>
154 transfer_funcs = {{DT_CPU, CPUDeviceInfoFromCPUDeviceContext},
155 {DT_GPU, GPUDeviceInfoFromGPUDeviceContext},
156 {DT_NPU, NPUDeviceInfoFromNPUDeviceContext},
157 {DT_ASCEND, AscendDeviceInfoFromAscendDeviceContext},
158 {DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext},
159 {DT_NNRT, NNRtDeviceInfoFromNNRtDeviceContext}};
160 for (auto &device_context : context->device_list_) {
161 auto device_type = device_context.device_type_;
162 if (transfer_funcs.find(device_type) == transfer_funcs.end()) {
163 MS_LOG(ERROR) << "device type is invalid.";
164 delete ms_context;
165 return nullptr;
166 }
167 auto device_info = transfer_funcs[device_type](device_context);
168 if (device_info == nullptr) {
169 MS_LOG(ERROR) << "transfer device context to device info failed.";
170 delete ms_context;
171 return nullptr;
172 }
173 if (device_type == DT_CPU) {
174 ms_context->SetThreadAffinity(static_cast<int>(device_context.device_info_.cpu_device_info_.cpu_bind_mode_));
175 }
176 device_infos.push_back(device_info);
177 }
178 return ms_context;
179 }
180
DeviceTypePriority(const InnerContext * context,int device_type1,int device_type2)181 bool DeviceTypePriority(const InnerContext *context, int device_type1, int device_type2) {
182 /* dt1 > dt2 true
183 * dt1 < dt2 false */
184
185 if (context == nullptr) {
186 return false;
187 }
188 for (const DeviceContext& device_info : context->device_list_) {
189 if (device_info.device_type_ == device_type1) {
190 return true;
191 }
192 if (device_info.device_type_ == device_type2) {
193 return false;
194 }
195 }
196 return false;
197 }
198
KernelArchToDeviceType(kernel::KERNEL_ARCH kernel_arch)199 DeviceType KernelArchToDeviceType(kernel::KERNEL_ARCH kernel_arch) {
200 switch (kernel_arch) {
201 case kernel::KERNEL_ARCH::kCPU:
202 return DT_CPU;
203 case kernel::KERNEL_ARCH::kGPU:
204 return DT_GPU;
205 case kernel::KERNEL_ARCH::kNPU:
206 return DT_NPU;
207 case kernel::KERNEL_ARCH::kACL:
208 return DT_ASCEND;
209 default:
210 return DT_CPU;
211 }
212 }
213 } // namespace lite
214 } // namespace mindspore
215