• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "include/api/context.h"
17 #include "pybind11/pybind11.h"
18 #include "pybind11/stl.h"
19 
20 namespace mindspore::lite {
21 namespace py = pybind11;
22 
ContextPyBind(const py::module & m)23 void ContextPyBind(const py::module &m) {
24   (void)py::enum_<DeviceType>(m, "DeviceType", py::arithmetic())
25     .value("kCPU", DeviceType::kCPU)
26     .value("kGPU", DeviceType::kGPU)
27     .value("kKirinNPU", DeviceType::kKirinNPU)
28     .value("kAscend", DeviceType::kAscend);
29 
30   (void)py::class_<DeviceInfoContext, std::shared_ptr<DeviceInfoContext>>(m, "DeviceInfoContextBind")
31     .def("set_provider", py::overload_cast<const std::string &>(&DeviceInfoContext::SetProvider))
32     .def("get_provider", &DeviceInfoContext::GetProvider);
33 
34   (void)py::class_<CPUDeviceInfo, DeviceInfoContext, std::shared_ptr<CPUDeviceInfo>>(m, "CPUDeviceInfoBind")
35     .def(py::init<>())
36     .def("get_device_type", &CPUDeviceInfo::GetDeviceType)
37     .def("set_enable_fp16", &CPUDeviceInfo::SetEnableFP16)
38     .def("get_enable_fp16", &CPUDeviceInfo::GetEnableFP16);
39 
40   (void)py::class_<GPUDeviceInfo, DeviceInfoContext, std::shared_ptr<GPUDeviceInfo>>(m, "GPUDeviceInfoBind")
41     .def(py::init<>())
42     .def("get_device_type", &GPUDeviceInfo::GetDeviceType)
43     .def("set_device_id", &GPUDeviceInfo::SetDeviceID)
44     .def("get_device_id", &GPUDeviceInfo::GetDeviceID)
45     .def("set_enable_fp16", &GPUDeviceInfo::SetEnableFP16)
46     .def("get_enable_fp16", &GPUDeviceInfo::GetEnableFP16)
47     .def("get_rank_id", &GPUDeviceInfo::GetRankID)
48     .def("get_group_size", &GPUDeviceInfo::GetGroupSize);
49 
50   (void)py::class_<AscendDeviceInfo, DeviceInfoContext, std::shared_ptr<AscendDeviceInfo>>(m, "AscendDeviceInfoBind")
51     .def(py::init<>())
52     .def("get_device_type", &AscendDeviceInfo::GetDeviceType)
53     .def("set_device_id", &AscendDeviceInfo::SetDeviceID)
54     .def("get_device_id", &AscendDeviceInfo::GetDeviceID)
55     .def("set_rank_id", &AscendDeviceInfo::SetRankID)
56     .def("get_rank_id", &AscendDeviceInfo::GetRankID)
57     .def("set_input_format",
58          [](AscendDeviceInfo &device_info, const std::string &format) { device_info.SetInputFormat(format); })
59     .def("get_input_format", &AscendDeviceInfo::GetInputFormat)
60     .def("set_input_shape", &AscendDeviceInfo::SetInputShapeMap)
61     .def("get_input_shape", &AscendDeviceInfo::GetInputShapeMap)
62     .def("set_precision_mode", [](AscendDeviceInfo &device_info,
63                                   const std::string &precision_mode) { device_info.SetPrecisionMode(precision_mode); })
64     .def("get_precision_mode", &AscendDeviceInfo::GetPrecisionMode)
65     .def("set_op_select_impl_mode",
66          [](AscendDeviceInfo &device_info, const std::string &op_select_impl_mode) {
67            device_info.SetOpSelectImplMode(op_select_impl_mode);
68          })
69     .def("get_op_select_impl_mode", &AscendDeviceInfo::GetOpSelectImplMode)
70     .def("set_dynamic_batch_size", &AscendDeviceInfo::SetDynamicBatchSize)
71     .def("get_dynamic_batch_size", &AscendDeviceInfo::GetDynamicBatchSize)
72     .def("set_dynamic_image_size",
73          [](AscendDeviceInfo &device_info, const std::string &dynamic_image_size) {
74            device_info.SetDynamicImageSize(dynamic_image_size);
75          })
76     .def("get_dynamic_image_size", &AscendDeviceInfo::GetDynamicImageSize)
77     .def("set_fusion_switch_config_path",
78          [](AscendDeviceInfo &device_info, const std::string &cfg_path) {
79            device_info.SetFusionSwitchConfigPath(cfg_path);
80          })
81     .def("get_fusion_switch_config_path", &AscendDeviceInfo::GetFusionSwitchConfigPath)
82     .def("set_insert_op_cfg_path", [](AscendDeviceInfo &device_info,
83                                       const std::string &cfg_path) { device_info.SetInsertOpConfigPath(cfg_path); })
84     .def("get_insert_op_cfg_path", &AscendDeviceInfo::GetInsertOpConfigPath);
85 
86   (void)py::class_<Context, std::shared_ptr<Context>>(m, "ContextBind")
87     .def(py::init<>())
88     .def("append_device_info",
89          [](Context &context, const std::shared_ptr<DeviceInfoContext> &device_info) {
90            context.MutableDeviceInfo().push_back(device_info);
91          })
92     .def("clear_device_info", [](Context &context) { context.MutableDeviceInfo().clear(); })
93     .def("set_thread_num", &Context::SetThreadNum)
94     .def("get_thread_num", &Context::GetThreadNum)
95     .def("set_inter_op_parallel_num", &Context::SetInterOpParallelNum)
96     .def("get_inter_op_parallel_num", &Context::GetInterOpParallelNum)
97     .def("set_group_info_file", &Context::SetGroupInfoFile)
98     .def("get_group_info_file", &Context::GetGroupInfoFile)
99     .def("set_thread_affinity_mode", py::overload_cast<int>(&Context::SetThreadAffinity))
100     .def("get_thread_affinity_mode", &Context::GetThreadAffinityMode)
101     .def("set_thread_affinity_core_list", py::overload_cast<const std::vector<int> &>(&Context::SetThreadAffinity))
102     .def("get_thread_affinity_core_list", &Context::GetThreadAffinityCoreList)
103     .def("set_enable_parallel", &Context::SetEnableParallel)
104     .def("get_enable_parallel", &Context::GetEnableParallel)
105     .def("get_device_list", [](Context &context) {
106       std::string result;
107       auto &device_list = context.MutableDeviceInfo();
108       for (auto &device : device_list) {
109         result += std::to_string(device->GetDeviceType());
110         result += ", ";
111       }
112       return result;
113     });
114 }
115 }  // namespace mindspore::lite
116