/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/common/context_util.h"
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "src/common/log_adapter.h"
#include "src/common/utils.h"

namespace mindspore {
namespace lite {
namespace {
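// Copies the provider, provider-device, and allocator fields that every
// device-info type shares from a lite::DeviceContext.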
template <class T>
void PassBasicProperties(std::shared_ptr<T> device_info, const lite::DeviceContext &device_context) {
  MS_ASSERT(device_info != nullptr);
  device_info->SetProvider(device_context.provider_);
  device_info->SetProviderDevice(device_context.provider_device_);
  device_info->SetAllocator(device_context.allocator_);
}

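// Builds a CPUDeviceInfo (FP16 flag) from a CPU-typed lite::DeviceContext;
// returns nullptr if the context is not a CPU context.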
std::shared_ptr<mindspore::CPUDeviceInfo> CPUDeviceInfoFromCPUDeviceContext(const lite::DeviceContext &cpu_context) {
  if (cpu_context.device_type_ != DT_CPU) {
    MS_LOG(ERROR) << "Function input parameter is not a CPU context.";
    return nullptr;
  }
  auto cpu_info = std::make_shared<mindspore::CPUDeviceInfo>();
  MS_CHECK_TRUE_RET(cpu_info != nullptr, nullptr);
  cpu_info->SetEnableFP16(cpu_context.device_info_.cpu_device_info_.enable_float16_);
  PassBasicProperties(cpu_info, cpu_context);
  return cpu_info;
}

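// Builds a GPUDeviceInfo (FP16 flag and device id) from a GPU-typed lite::DeviceContext.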
std::shared_ptr<mindspore::GPUDeviceInfo> GPUDeviceInfoFromGPUDeviceContext(const lite::DeviceContext &gpu_context) {
  if (gpu_context.device_type_ != DT_GPU) {
    MS_LOG(ERROR) << "Function input parameter is not a GPU context.";
    return nullptr;
  }
  auto gpu_info = std::make_shared<mindspore::GPUDeviceInfo>();
  MS_CHECK_TRUE_RET(gpu_info != nullptr, nullptr);
  gpu_info->SetEnableFP16(gpu_context.device_info_.gpu_device_info_.enable_float16_);
  gpu_info->SetDeviceID(gpu_context.device_info_.gpu_device_info_.gpu_device_id_);
  PassBasicProperties(gpu_info, gpu_context);
  return gpu_info;
}

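// Builds a KirinNPUDeviceInfo (FP16 flag and frequency) from an NPU-typed lite::DeviceContext.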
std::shared_ptr<mindspore::KirinNPUDeviceInfo> NPUDeviceInfoFromNPUDeviceContext(
  const lite::DeviceContext &npu_context) {
  if (npu_context.device_type_ != DT_NPU) {
    MS_LOG(ERROR) << "Function input parameter is not an NPU context.";
    return nullptr;
  }
  auto npu_info = std::make_shared<mindspore::KirinNPUDeviceInfo>();
  MS_CHECK_TRUE_RET(npu_info != nullptr, nullptr);
  npu_info->SetEnableFP16(npu_context.device_info_.npu_device_info_.enable_float16_);
  npu_info->SetFrequency(npu_context.device_info_.npu_device_info_.frequency_);
  PassBasicProperties(npu_info, npu_context);
  return npu_info;
}

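// Parses a comma-separated batch-size string (e.g. "1,2,4") into a vector of sizes.
// On a conversion failure, logs an error and returns the values parsed so far.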
std::vector<size_t> GetBatchSize(const std::string &batch_size) {
  std::vector<size_t> res;
  std::vector<std::string> batch_size_vec = StrSplit(batch_size, ",");
  for (const auto &item : batch_size_vec) {
    int32_t val;
    if (ConvertStrToInt(item, &val)) {
      auto tmp_val = static_cast<size_t>(val);
      res.push_back(tmp_val);
    } else {
      MS_LOG(ERROR) << "Convert str to num failed, val = " << item;
      return res;
    }
  }
  MS_LOG(INFO) << "Batch size of context: " << batch_size;
  return res;
}

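// Builds an AscendDeviceInfo (device id, dynamic batch size, dynamic image size)
// from an Ascend-typed lite::DeviceContext.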
std::shared_ptr<mindspore::AscendDeviceInfo> AscendDeviceInfoFromAscendDeviceContext(
  const lite::DeviceContext &ascend_context) {
  if (ascend_context.device_type_ != DT_ASCEND) {
    MS_LOG(ERROR) << "Function input parameter is not an Ascend context.";
    return nullptr;
  }
  auto ascend_info = std::make_shared<mindspore::AscendDeviceInfo>();
  MS_CHECK_TRUE_RET(ascend_info != nullptr, nullptr);
  ascend_info->SetDeviceID(ascend_context.device_info_.ascend_device_info_.device_id_);
  std::string batch_size = ascend_context.device_info_.ascend_device_info_.batch_size_;
  if (!batch_size.empty()) {
    auto val = GetBatchSize(batch_size);
    ascend_info->SetDynamicBatchSize(val);
  }
  ascend_info->SetDynamicImageSize(ascend_context.device_info_.ascend_device_info_.image_size_);
  return ascend_info;
}

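// Returns the user-defined DeviceInfoContext carried inside a custom device context.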
std::shared_ptr<mindspore::DeviceInfoContext> CustomDeviceInfoFromCustomDeviceContext(
  const lite::DeviceContext &inner_context) {
  if (inner_context.device_type_ != DT_CUSTOM) {
    MS_LOG(ERROR) << "Function input parameter is not a custom context.";
    return nullptr;
  }
  auto device_info = inner_context.device_info_.custom_device_info_.user_defined_device_info_;
  MS_CHECK_TRUE_RET(device_info != nullptr, nullptr);
  return device_info;
}

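// Builds a default NNRTDeviceInfo from an NNRt-typed lite::DeviceContext.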
std::shared_ptr<mindspore::NNRTDeviceInfo> NNRtDeviceInfoFromNNRtDeviceContext(
  const lite::DeviceContext &nnrt_context) {
  if (nnrt_context.device_type_ != DT_NNRT) {
    MS_LOG(ERROR) << "Function input parameter is not an NNRt context.";
    return nullptr;
  }
  auto nnrt_info = std::make_shared<mindspore::NNRTDeviceInfo>();
  MS_CHECK_TRUE_RET(nnrt_info != nullptr, nullptr);
  return nnrt_info;
}
}  // namespace

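// Converts an inner lite::InnerContext into a heap-allocated mindspore::Context,
// translating thread settings, affinity, delegate, and each entry of the device
// list into the matching DeviceInfoContext. The caller owns the returned pointer.
//
// A minimal usage sketch (hypothetical caller code, not part of this file):
//   auto inner = std::make_shared<mindspore::lite::InnerContext>();
//   inner->thread_num_ = 2;
//   inner->device_list_.push_back({/* a DT_CPU DeviceContext */});
//   mindspore::Context *ctx = MSContextFromContext(inner);
//   if (ctx != nullptr) { /* use ctx */ delete ctx; }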
mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &context) {
  if (context == nullptr) {
    MS_LOG(ERROR) << "context is nullptr";
    return nullptr;
  }
  auto ms_context = new (std::nothrow) mindspore::Context();
  if (ms_context == nullptr) {
    MS_LOG(ERROR) << "New Context failed";
    return nullptr;
  }
  ms_context->SetThreadNum(context->thread_num_);
  ms_context->SetThreadAffinity(context->affinity_core_list_);
#ifndef ENABLE_CLOUD_FUSION_INFERENCE
  ms_context->SetEnableParallel(context->enable_parallel_);
#endif
  if (context->delegate) {
    ms_context->SetDelegate(context->delegate);
  }
  auto &device_infos = ms_context->MutableDeviceInfo();
  std::map<DeviceType, std::function<std::shared_ptr<mindspore::DeviceInfoContext>(const lite::DeviceContext &)>>
    transfer_funcs = {{DT_CPU, CPUDeviceInfoFromCPUDeviceContext},
                      {DT_GPU, GPUDeviceInfoFromGPUDeviceContext},
                      {DT_NPU, NPUDeviceInfoFromNPUDeviceContext},
                      {DT_ASCEND, AscendDeviceInfoFromAscendDeviceContext},
                      {DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext},
                      {DT_NNRT, NNRtDeviceInfoFromNNRtDeviceContext}};
  for (auto &device_context : context->device_list_) {
    auto device_type = device_context.device_type_;
    auto transfer_func = transfer_funcs.find(device_type);
    if (transfer_func == transfer_funcs.end()) {
      MS_LOG(ERROR) << "Device type is invalid.";
      delete ms_context;
      return nullptr;
    }
    auto device_info = transfer_func->second(device_context);
    if (device_info == nullptr) {
      MS_LOG(ERROR) << "Transfer device context to device info failed.";
      delete ms_context;
      return nullptr;
    }
    if (device_type == DT_CPU) {
      ms_context->SetThreadAffinity(static_cast<int>(device_context.device_info_.cpu_device_info_.cpu_bind_mode_));
    }
    device_infos.push_back(device_info);
  }
  return ms_context;
}

bool DeviceTypePriority(const InnerContext *context, int device_type1, int device_type2) {
  /* Returns true when device_type1 outranks device_type2, i.e. it appears
   * earlier in the context's device list; returns false otherwise. */
  if (context == nullptr) {
    return false;
  }
  for (const DeviceContext &device_info : context->device_list_) {
    if (device_info.device_type_ == device_type1) {
      return true;
    }
    if (device_info.device_type_ == device_type2) {
      return false;
    }
  }
  return false;
}

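// Maps a kernel implementation arch to the corresponding lite DeviceType,
// defaulting to DT_CPU for unknown arches.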
DeviceType KernelArchToDeviceType(kernel::KERNEL_ARCH kernel_arch) {
  switch (kernel_arch) {
    case kernel::KERNEL_ARCH::kCPU:
      return DT_CPU;
    case kernel::KERNEL_ARCH::kGPU:
      return DT_GPU;
    case kernel::KERNEL_ARCH::kNPU:
      return DT_NPU;
    case kernel::KERNEL_ARCH::kACL:
      return DT_ASCEND;
    default:
      return DT_CPU;
  }
}
}  // namespace lite
}  // namespace mindspore