/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "include/converter.h"

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "pybind11/functional.h"

21 namespace mindspore::lite {
22 namespace py = pybind11;
23 
ConverterPyBind(const py::module & m)24 void ConverterPyBind(const py::module &m) {
25   (void)py::enum_<converter::FmkType>(m, "FmkType")
26     .value("kFmkTypeTf", converter::FmkType::kFmkTypeTf)
27     .value("kFmkTypeCaffe", converter::FmkType::kFmkTypeCaffe)
28     .value("kFmkTypeOnnx", converter::FmkType::kFmkTypeOnnx)
29     .value("kFmkTypeMs", converter::FmkType::kFmkTypeMs)
30     .value("kFmkTypeTflite", converter::FmkType::kFmkTypeTflite)
31     .value("kFmkTypePytorch", converter::FmkType::kFmkTypePytorch);
32 
33   (void)py::class_<Converter, std::shared_ptr<Converter>>(m, "ConverterBind")
34     .def(py::init<>())
35     .def("set_config_file", py::overload_cast<const std::string &>(&Converter::SetConfigFile))
36     .def("get_config_file", &Converter::GetConfigFile)
37     .def("set_config_info",
38          py::overload_cast<const std::string &, const std::map<std::string, std::string> &>(&Converter::SetConfigInfo))
39     .def("get_config_info", &Converter::GetConfigInfo)
40     .def("set_weight_fp16", &Converter::SetWeightFp16)
41     .def("get_weight_fp16", &Converter::GetWeightFp16)
42     .def("set_input_shape",
43          py::overload_cast<const std::map<std::string, std::vector<int64_t>> &>(&Converter::SetInputShape))
44     .def("get_input_shape", &Converter::GetInputShape)
45     .def("set_input_format", &Converter::SetInputFormat)
46     .def("get_input_format", &Converter::GetInputFormat)
47     .def("set_input_data_type", &Converter::SetInputDataType)
48     .def("get_input_data_type", &Converter::GetInputDataType)
49     .def("set_output_data_type", &Converter::SetOutputDataType)
50     .def("get_output_data_type", &Converter::GetOutputDataType)
51     .def("set_save_type", &Converter::SetSaveType)
52     .def("get_save_type", &Converter::GetSaveType)
53     .def("set_decrypt_key", py::overload_cast<const std::string &>(&Converter::SetDecryptKey))
54     .def("get_decrypt_key", &Converter::GetDecryptKey)
55     .def("set_decrypt_mode", py::overload_cast<const std::string &>(&Converter::SetDecryptMode))
56     .def("get_decrypt_mode", &Converter::GetDecryptMode)
57     .def("set_enable_encryption", &Converter::SetEnableEncryption)
58     .def("get_enable_encryption", &Converter::GetEnableEncryption)
59     .def("set_encrypt_key", py::overload_cast<const std::string &>(&Converter::SetEncryptKey))
60     .def("get_encrypt_key", &Converter::GetEncryptKey)
61     .def("set_infer", &Converter::SetInfer)
62     .def("get_infer", &Converter::GetInfer)
63 #if !defined(ENABLE_CLOUD_FUSION_INFERENCE) && !defined(ENABLE_CLOUD_INFERENCE)
64     .def("set_train_model", &Converter::SetTrainModel)
65     .def("get_train_model", &Converter::GetTrainModel)
66 #endif
67     .def("set_no_fusion", &Converter::SetNoFusion)
68     .def("get_no_fusion", &Converter::GetNoFusion)
69     .def("set_device", py::overload_cast<const std::string &>(&Converter::SetDevice))
70     .def("get_device", &Converter::GetDevice)
71     .def("set_chip_name", py::overload_cast<const std::string &>(&Converter::SetChipName))
72     .def("get_chip_name", &Converter::GetChipName)
73     .def("set_device_id", &Converter::SetDeviceId)
74     .def("get_device_id", &Converter::GetDeviceId)
75     .def("set_rank_id", &Converter::SetRankId)
76     .def("get_rank_id", &Converter::GetRankId)
77     .def("convert",
78          py::overload_cast<converter::FmkType, const std::string &, const std::string &, const std::string &>(
79            &Converter::Convert),
80          py::call_guard<py::gil_scoped_release>());
81 }
82 }  // namespace mindspore::lite
83