1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/converter.h"
17 #include "pybind11/pybind11.h"
18 #include "pybind11/stl.h"
19 #include "pybind11/functional.h"
20
21 namespace mindspore::lite {
22 namespace py = pybind11;
23
ConverterPyBind(const py::module & m)24 void ConverterPyBind(const py::module &m) {
25 (void)py::enum_<converter::FmkType>(m, "FmkType")
26 .value("kFmkTypeTf", converter::FmkType::kFmkTypeTf)
27 .value("kFmkTypeCaffe", converter::FmkType::kFmkTypeCaffe)
28 .value("kFmkTypeOnnx", converter::FmkType::kFmkTypeOnnx)
29 .value("kFmkTypeMs", converter::FmkType::kFmkTypeMs)
30 .value("kFmkTypeTflite", converter::FmkType::kFmkTypeTflite)
31 .value("kFmkTypePytorch", converter::FmkType::kFmkTypePytorch);
32
33 (void)py::class_<Converter, std::shared_ptr<Converter>>(m, "ConverterBind")
34 .def(py::init<>())
35 .def("set_config_file", py::overload_cast<const std::string &>(&Converter::SetConfigFile))
36 .def("get_config_file", &Converter::GetConfigFile)
37 .def("set_config_info",
38 py::overload_cast<const std::string &, const std::map<std::string, std::string> &>(&Converter::SetConfigInfo))
39 .def("get_config_info", &Converter::GetConfigInfo)
40 .def("set_weight_fp16", &Converter::SetWeightFp16)
41 .def("get_weight_fp16", &Converter::GetWeightFp16)
42 .def("set_input_shape",
43 py::overload_cast<const std::map<std::string, std::vector<int64_t>> &>(&Converter::SetInputShape))
44 .def("get_input_shape", &Converter::GetInputShape)
45 .def("set_input_format", &Converter::SetInputFormat)
46 .def("get_input_format", &Converter::GetInputFormat)
47 .def("set_input_data_type", &Converter::SetInputDataType)
48 .def("get_input_data_type", &Converter::GetInputDataType)
49 .def("set_output_data_type", &Converter::SetOutputDataType)
50 .def("get_output_data_type", &Converter::GetOutputDataType)
51 .def("set_save_type", &Converter::SetSaveType)
52 .def("get_save_type", &Converter::GetSaveType)
53 .def("set_decrypt_key", py::overload_cast<const std::string &>(&Converter::SetDecryptKey))
54 .def("get_decrypt_key", &Converter::GetDecryptKey)
55 .def("set_decrypt_mode", py::overload_cast<const std::string &>(&Converter::SetDecryptMode))
56 .def("get_decrypt_mode", &Converter::GetDecryptMode)
57 .def("set_enable_encryption", &Converter::SetEnableEncryption)
58 .def("get_enable_encryption", &Converter::GetEnableEncryption)
59 .def("set_encrypt_key", py::overload_cast<const std::string &>(&Converter::SetEncryptKey))
60 .def("get_encrypt_key", &Converter::GetEncryptKey)
61 .def("set_infer", &Converter::SetInfer)
62 .def("get_infer", &Converter::GetInfer)
63 #if !defined(ENABLE_CLOUD_FUSION_INFERENCE) && !defined(ENABLE_CLOUD_INFERENCE)
64 .def("set_train_model", &Converter::SetTrainModel)
65 .def("get_train_model", &Converter::GetTrainModel)
66 #endif
67 .def("set_no_fusion", &Converter::SetNoFusion)
68 .def("get_no_fusion", &Converter::GetNoFusion)
69 .def("set_device", py::overload_cast<const std::string &>(&Converter::SetDevice))
70 .def("get_device", &Converter::GetDevice)
71 .def("set_chip_name", py::overload_cast<const std::string &>(&Converter::SetChipName))
72 .def("get_chip_name", &Converter::GetChipName)
73 .def("set_device_id", &Converter::SetDeviceId)
74 .def("get_device_id", &Converter::GetDeviceId)
75 .def("set_rank_id", &Converter::SetRankId)
76 .def("get_rank_id", &Converter::GetRankId)
77 .def("convert",
78 py::overload_cast<converter::FmkType, const std::string &, const std::string &, const std::string &>(
79 &Converter::Convert),
80 py::call_guard<py::gil_scoped_release>());
81 }
82 } // namespace mindspore::lite
83