1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/api/context.h"
17 #include "pybind11/pybind11.h"
18 #include "pybind11/stl.h"
19
20 namespace mindspore::lite {
21 namespace py = pybind11;
22
ContextPyBind(const py::module & m)23 void ContextPyBind(const py::module &m) {
24 (void)py::enum_<DeviceType>(m, "DeviceType", py::arithmetic())
25 .value("kCPU", DeviceType::kCPU)
26 .value("kGPU", DeviceType::kGPU)
27 .value("kKirinNPU", DeviceType::kKirinNPU)
28 .value("kAscend", DeviceType::kAscend);
29
30 (void)py::class_<DeviceInfoContext, std::shared_ptr<DeviceInfoContext>>(m, "DeviceInfoContextBind")
31 .def("set_provider", py::overload_cast<const std::string &>(&DeviceInfoContext::SetProvider))
32 .def("get_provider", &DeviceInfoContext::GetProvider);
33
34 (void)py::class_<CPUDeviceInfo, DeviceInfoContext, std::shared_ptr<CPUDeviceInfo>>(m, "CPUDeviceInfoBind")
35 .def(py::init<>())
36 .def("get_device_type", &CPUDeviceInfo::GetDeviceType)
37 .def("set_enable_fp16", &CPUDeviceInfo::SetEnableFP16)
38 .def("get_enable_fp16", &CPUDeviceInfo::GetEnableFP16);
39
40 (void)py::class_<GPUDeviceInfo, DeviceInfoContext, std::shared_ptr<GPUDeviceInfo>>(m, "GPUDeviceInfoBind")
41 .def(py::init<>())
42 .def("get_device_type", &GPUDeviceInfo::GetDeviceType)
43 .def("set_device_id", &GPUDeviceInfo::SetDeviceID)
44 .def("get_device_id", &GPUDeviceInfo::GetDeviceID)
45 .def("set_enable_fp16", &GPUDeviceInfo::SetEnableFP16)
46 .def("get_enable_fp16", &GPUDeviceInfo::GetEnableFP16)
47 .def("get_rank_id", &GPUDeviceInfo::GetRankID)
48 .def("get_group_size", &GPUDeviceInfo::GetGroupSize);
49
50 (void)py::class_<AscendDeviceInfo, DeviceInfoContext, std::shared_ptr<AscendDeviceInfo>>(m, "AscendDeviceInfoBind")
51 .def(py::init<>())
52 .def("get_device_type", &AscendDeviceInfo::GetDeviceType)
53 .def("set_device_id", &AscendDeviceInfo::SetDeviceID)
54 .def("get_device_id", &AscendDeviceInfo::GetDeviceID)
55 .def("set_rank_id", &AscendDeviceInfo::SetRankID)
56 .def("get_rank_id", &AscendDeviceInfo::GetRankID)
57 .def("set_input_format",
58 [](AscendDeviceInfo &device_info, const std::string &format) { device_info.SetInputFormat(format); })
59 .def("get_input_format", &AscendDeviceInfo::GetInputFormat)
60 .def("set_input_shape", &AscendDeviceInfo::SetInputShapeMap)
61 .def("get_input_shape", &AscendDeviceInfo::GetInputShapeMap)
62 .def("set_precision_mode", [](AscendDeviceInfo &device_info,
63 const std::string &precision_mode) { device_info.SetPrecisionMode(precision_mode); })
64 .def("get_precision_mode", &AscendDeviceInfo::GetPrecisionMode)
65 .def("set_op_select_impl_mode",
66 [](AscendDeviceInfo &device_info, const std::string &op_select_impl_mode) {
67 device_info.SetOpSelectImplMode(op_select_impl_mode);
68 })
69 .def("get_op_select_impl_mode", &AscendDeviceInfo::GetOpSelectImplMode)
70 .def("set_dynamic_batch_size", &AscendDeviceInfo::SetDynamicBatchSize)
71 .def("get_dynamic_batch_size", &AscendDeviceInfo::GetDynamicBatchSize)
72 .def("set_dynamic_image_size",
73 [](AscendDeviceInfo &device_info, const std::string &dynamic_image_size) {
74 device_info.SetDynamicImageSize(dynamic_image_size);
75 })
76 .def("get_dynamic_image_size", &AscendDeviceInfo::GetDynamicImageSize)
77 .def("set_fusion_switch_config_path",
78 [](AscendDeviceInfo &device_info, const std::string &cfg_path) {
79 device_info.SetFusionSwitchConfigPath(cfg_path);
80 })
81 .def("get_fusion_switch_config_path", &AscendDeviceInfo::GetFusionSwitchConfigPath)
82 .def("set_insert_op_cfg_path", [](AscendDeviceInfo &device_info,
83 const std::string &cfg_path) { device_info.SetInsertOpConfigPath(cfg_path); })
84 .def("get_insert_op_cfg_path", &AscendDeviceInfo::GetInsertOpConfigPath);
85
86 (void)py::class_<Context, std::shared_ptr<Context>>(m, "ContextBind")
87 .def(py::init<>())
88 .def("append_device_info",
89 [](Context &context, const std::shared_ptr<DeviceInfoContext> &device_info) {
90 context.MutableDeviceInfo().push_back(device_info);
91 })
92 .def("clear_device_info", [](Context &context) { context.MutableDeviceInfo().clear(); })
93 .def("set_thread_num", &Context::SetThreadNum)
94 .def("get_thread_num", &Context::GetThreadNum)
95 .def("set_inter_op_parallel_num", &Context::SetInterOpParallelNum)
96 .def("get_inter_op_parallel_num", &Context::GetInterOpParallelNum)
97 .def("set_group_info_file", &Context::SetGroupInfoFile)
98 .def("get_group_info_file", &Context::GetGroupInfoFile)
99 .def("set_thread_affinity_mode", py::overload_cast<int>(&Context::SetThreadAffinity))
100 .def("get_thread_affinity_mode", &Context::GetThreadAffinityMode)
101 .def("set_thread_affinity_core_list", py::overload_cast<const std::vector<int> &>(&Context::SetThreadAffinity))
102 .def("get_thread_affinity_core_list", &Context::GetThreadAffinityCoreList)
103 .def("set_enable_parallel", &Context::SetEnableParallel)
104 .def("get_enable_parallel", &Context::GetEnableParallel)
105 .def("get_device_list", [](Context &context) {
106 std::string result;
107 auto &device_list = context.MutableDeviceInfo();
108 for (auto &device : device_list) {
109 result += std::to_string(device->GetDeviceType());
110 result += ", ";
111 }
112 return result;
113 });
114 }
115 } // namespace mindspore::lite
116