From c2f6e9f2c5cc9ec0260437a1dfd70b5d535ee740 Mon Sep 17 00:00:00 2001 From: Zhu Guodong Date: Thu, 1 Jun 2023 22:02:23 +0800 Subject: [PATCH] auto-apply 0006-Support-converting-THIRDPARTY-model-in-MSLite.patch --- .../lite/include/registry/converter_context.h | 2 + mindspore/lite/test/CMakeLists.txt | 4 +- mindspore/lite/test/runtest.sh | 6 +- .../test/ut/test_data/third_party_model.cfg | 8 + .../tools/converter/api/converter_api_test.cc | 28 ++ .../third_party_param_parser_test.cc | 176 +++++++++++ mindspore/lite/tools/converter/CMakeLists.txt | 4 + .../config_parser/config_file_parser.cc | 27 ++ .../config_parser/config_file_parser.h | 17 + .../config_parser/third_party_param_parser.cc | 299 ++++++++++++++++++ .../config_parser/third_party_param_parser.h | 44 +++ mindspore/lite/tools/converter/converter.cc | 58 ++-- .../converter_lite/converter_flags.cc | 8 +- .../tools/converter/cxx_api/converter_para.h | 14 + .../tools/converter/graphdef_transform.cc | 44 +++ .../parser/third_party/CMakeLists.txt | 4 + .../third_party/third_party_model_parser.cc | 277 ++++++++++++++++ .../third_party/third_party_model_parser.h | 50 +++ .../registry/model_parser_registry.cc | 4 +- 19 files changed, 1045 insertions(+), 29 deletions(-) create mode 100644 mindspore/lite/test/ut/test_data/third_party_model.cfg create mode 100644 mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc create mode 100644 mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.h create mode 100644 mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h diff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h index a92a3a34..dd6e6d08 100644 --- a/mindspore/lite/include/registry/converter_context.h +++ b/mindspore/lite/include/registry/converter_context.h @@ -33,6 +33,8 @@ enum MS_API FmkType : int { kFmkTypeMs = 3, kFmkTypeTflite = 4, kFmkTypePytorch = 5, + kFmkTypeThirdParty = 6, + kFmkTypeEnd = 7, // For range check purpose, valid range: [0, kFmkTypeEnd) }; /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser. 
diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index c0ba8e39..5fa7bea0 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -120,6 +120,8 @@ if(MSLITE_ENABLE_CONVERTER) file(GLOB_RECURSE TEST_CONVERTER_UT_SRC ${TEST_DIR}/ut/tools/converter/registry/*.cc ${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc + ${TEST_DIR}/ut/tools/converter/api/*.cc + ${TEST_DIR}/ut/tools/converter/config_parser/*.cc ${TEST_DIR}/st/converter_test.cc ${TEST_DIR}/st/delegate_test.cc ${TEST_DIR}/st/mindrt_parallel_test.cc @@ -234,7 +236,7 @@ endif() if(MSLITE_ENABLE_CONVERTER) target_link_libraries(lite-test-converter tflite_parser_mid caffe_parser_mid - onnx_parser_mid tf_parser_mid) + onnx_parser_mid tf_parser_mid third_party_parser_mid) endif() if(ENABLE_MODEL_OBF) diff --git a/mindspore/lite/test/runtest.sh b/mindspore/lite/test/runtest.sh index 57c9c0aa..921d3fb3 100644 --- a/mindspore/lite/test/runtest.sh +++ b/mindspore/lite/test/runtest.sh @@ -61,10 +61,12 @@ echo 'run common ut tests' # test cases of INT8 OP ## ./lite-test --gtest_filter=TestPadInt8.* ./lite-test --gtest_filter=TestDeconvInt8.* -if [ "$ENABLE_CONVERTER_TEST" = true ];then +if [ "$ENABLE_CONVERTER_TEST" = true ]; then ./lite-test-converter --gtest_filter="ModelParserRegistryTest.TestRegistry" ./lite-test-converter --gtest_filter="NodeParserRegistryTest.TestRegistry" ./lite-test-converter --gtest_filter="PassRegistryTest.TestRegistry" + ./lite-test-converter --gtest_filter="TestConverterAPI.*" + ./lite-test-converter --gtest_filter="TestThirdPartyParamParser.*" fi ./lite-test --gtest_filter="TestRegistry.TestAdd" ./lite-test --gtest_filter="TestRegistryCustomOp.TestCustomAdd" @@ -87,7 +89,7 @@ echo 'run inference ut tests' ./lite-test --gtest_filter="ControlFlowTest.TestMergeWhileModel" echo 'run mindrt parallel ut test' -if [ "$ENABLE_CONVERTER_TEST" = true ];then +if [ "$ENABLE_CONVERTER_TEST" = true ]; then ./lite-test-converter --gtest_filter="MindrtParallelTest.*" echo 'user set output tensors st test' ./lite-test --gtest_filter="GraphTest.UserSetGraphOutput*" diff --git a/mindspore/lite/test/ut/test_data/third_party_model.cfg b/mindspore/lite/test/ut/test_data/third_party_model.cfg new file mode 100644 index 00000000..b5fcba75 --- /dev/null +++ b/mindspore/lite/test/ut/test_data/third_party_model.cfg @@ -0,0 +1,8 @@ +[third_party_model] +input_names=demo_in_0;demo_in_1;demo_in_2 +input_dtypes=float32;float16;float64 +input_shapes=1;2,3;4,5,6 +output_names=demo_out_0;demo_out_1;demo_out_2;demo_out_4 +output_dtypes=int32;int16;int8;uint8 +output_shapes=10;20,30;40;50,60,70 +extended_parameters=foo:foo_value;bar:bar_value diff --git a/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc new file mode 100644 index 00000000..0d434575 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include "include/converter.h" + +TEST(TestConverterAPI, ConvertThirdParty) { + std::string third_party_model = "./relu.mindir"; + std::string config_model = "./third_party_model.cfg"; + std::string output_model = "./demo_third_party.ms"; + + mindspore::Converter converter(mindspore::converter::FmkType::kFmkTypeThirdParty, third_party_model, output_model); + converter.SetConfigFile(config_model); + ASSERT_TRUE(converter.Convert().IsOk()); +} diff --git a/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc new file mode 100644 index 00000000..c8eb5536 --- /dev/null +++ b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc @@ -0,0 +1,176 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include "tools/converter/config_parser/third_party_param_parser.h" + +using mindspore::ThirdPartyModelParam; +using mindspore::TypeId; +using mindspore::lite::RET_OK; +using mindspore::lite::ThirdPartyModelString; +using mindspore::lite::ThirdPartyParamParser; + +const ThirdPartyModelString kDemoSISOParam = { + // SISO is short for single-input-single-output. + .input_dtypes = "float32", + .input_shapes = "1,2,3,4", + .input_names = "siso_input", + .output_dtypes = "int32", + .output_shapes = "2", + .output_names = "siso_output", + .extended_parameters = "siso_foo:siso_foo_value;siso_bar:siso_bar_value", +}; + +const ThirdPartyModelString kDemoMIMOParam = { + // MIMO is short for multiple-input-multiple-output. 
+ .input_dtypes = "float32;int8;float16", + .input_shapes = "1,2,3,4;5,6;7,8,9", + .input_names = "mimo_in_0;mimo_in_1;mimo_in_2", + .output_dtypes = "int32;float32", + .output_shapes = "2,4;10,20,30", + .output_names = "mimo_out_0;mimo_out_1", + .extended_parameters = "mimo_foo:mimo_foo_value;mimo_bar:mimo_bar_value", +}; + +TEST(TestThirdPartyParamParser, ParseSISOParam) { + ThirdPartyModelString param_string = kDemoSISOParam; + ThirdPartyModelParam result; + ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + + ASSERT_EQ(result.input_names, std::vector{"siso_input"}); + ASSERT_EQ(result.input_shapes.size(), 1U); + std::vector expect_in_shape = {1, 2, 3, 4}; + ASSERT_EQ(result.input_shapes[0], expect_in_shape); + ASSERT_EQ(result.input_dtypes, std::vector{TypeId::kNumberTypeFloat32}); + + ASSERT_EQ(result.output_names, std::vector{"siso_output"}); + ASSERT_EQ(result.output_shapes.size(), 1U); + std::vector expect_out_shape = {2}; + ASSERT_EQ(result.output_shapes[0], expect_out_shape); + ASSERT_EQ(result.output_dtypes, std::vector{TypeId::kNumberTypeInt32}); + + const auto &ext_param = result.extended_parameters; + ASSERT_EQ(ext_param.size(), 2U); + ASSERT_TRUE(ext_param.find("siso_foo") != ext_param.end()); + auto expect_foo_value = ext_param.at("siso_foo"); + ASSERT_EQ(std::string(expect_foo_value.begin(), expect_foo_value.end()), "siso_foo_value"); + ASSERT_TRUE(ext_param.find("siso_bar") != ext_param.end()); + auto expect_bar_value = ext_param.at("siso_bar"); + ASSERT_EQ(std::string(expect_bar_value.begin(), expect_bar_value.end()), "siso_bar_value"); +} + +TEST(TestThirdPartyParamParser, ParseValidDtype) { + ThirdPartyModelString param_string = kDemoSISOParam; + const std::vector kValidDtypeStrings = { + "float64", "float32", "float16", "int64", "int32", "int16", "int8", "uint8", "bool", + }; + + const std::vector kExpects = { + TypeId::kNumberTypeFloat64, TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat16, + TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt16, + TypeId::kNumberTypeInt8, TypeId::kNumberTypeUInt8, TypeId::kNumberTypeBool}; + + for (size_t i = 0; i < kValidDtypeStrings.size(); i++) { + param_string.input_dtypes = kValidDtypeStrings[i]; + ThirdPartyModelParam result; + ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + ASSERT_EQ(result.input_dtypes[0], kExpects[i]); + } +} + +TEST(TestThirdPartyParamParser, ParseInvalidDtype) { + ThirdPartyModelParam result; + ThirdPartyModelString param_string = kDemoSISOParam; + ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + param_string.input_dtypes = "bad_dtype"; + ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); +} + +TEST(TestThirdPartyParamParser, ParseValidShape) { + ThirdPartyModelString param_string = kDemoSISOParam; + param_string.input_shapes = "256,256,1024,96"; // Only support fixed shape. 
+ ThirdPartyModelParam result; + ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + std::vector expect = {256, 256, 1024, 96}; + ASSERT_EQ(result.input_shapes[0], expect); +} + +TEST(TestThirdPartyParamParser, ParseInvalidShape) { + ThirdPartyModelParam result; + ThirdPartyModelString param_string = kDemoSISOParam; + ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + + param_string.input_shapes = "256,256,1024,-1"; + ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + + param_string.input_shapes = "256,256,0,96"; + ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + + param_string.input_shapes = "256,-256,1024,96"; + ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + + param_string.input_shapes = "256,foo,1024,96"; + ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); +} + +TEST(TestThirdPartyParamParser, ParseDefaultName) { + ThirdPartyModelParam result; + ThirdPartyModelString param_string = kDemoSISOParam; + param_string.input_names = ""; + param_string.output_names = ""; + ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + ASSERT_EQ(result.input_names[0], "in_0"); + ASSERT_EQ(result.output_names[0], "out_0"); +} + +TEST(TestThirdPartyParamParser, ParseMIMOParam) { + ThirdPartyModelString param_string = kDemoMIMOParam; + ThirdPartyModelParam result; + ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + + std::vector expect_input_names = {"mimo_in_0", "mimo_in_1", "mimo_in_2"}; + ASSERT_EQ(result.input_names, expect_input_names); + std::vector> expect_input_shapes = {{1, 2, 3, 4}, {5, 6}, {7, 8, 9}}; + ASSERT_EQ(result.input_shapes, expect_input_shapes); + std::vector expect_input_dtypes = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt8, + TypeId::kNumberTypeFloat16}; + ASSERT_EQ(result.input_dtypes, expect_input_dtypes); + + std::vector expect_output_names = {"mimo_out_0", "mimo_out_1"}; + ASSERT_EQ(result.output_names, expect_output_names); + std::vector> expect_output_shapes = {{2, 4}, {10, 20, 30}}; + ASSERT_EQ(result.output_shapes, expect_output_shapes); + std::vector expect_output_dtypes = {TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32}; + ASSERT_EQ(result.output_dtypes, expect_output_dtypes); +} + +TEST(TestThirdPartyParamParser, ParseMismatchedShapeAndDtypeSize) { + ThirdPartyModelString param_string = kDemoMIMOParam; + ThirdPartyModelParam result; + ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + + param_string.input_shapes = "1,2,3,4;5,6"; // shape size is 2 while dtype size is 3. + ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); +} + +TEST(TestThirdPartyParamParser, ParseMismatchedNameAndDtypeSize) { + ThirdPartyModelString param_string = kDemoMIMOParam; + ThirdPartyModelParam result; + ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); + + param_string.input_names = "mimo_in_0;mimo_in_1"; // name size is 2 while dtype size is 3. + ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); +} diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 215d2e17..8ce0304e 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -8,6 +8,8 @@ endif() set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) set(TOOLS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) 
+include_directories(${CMAKE_SOURCE_DIR}/mindspore/lite/) + set(CCSRC_SRC ${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc ${CCSRC_DIR}/backend/common/optimizer/visit.cc @@ -81,6 +83,7 @@ add_subdirectory(parser/caffe) add_subdirectory(parser/tflite) add_subdirectory(parser/onnx) add_subdirectory(parser/tf) +add_subdirectory(parser/third_party) if(ENABLE_CONVERT_PYTORCH_MODEL) add_subdirectory(parser/pytorch) endif() @@ -343,6 +346,7 @@ target_link_libraries(mindspore_converter tf_parser_mid caffe_parser_mid onnx_parser_mid + third_party_parser_mid lite_exporter_mid graph_pass_mid fusion_mid diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc index bb3e7bcf..e2e15230 100644 --- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc +++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc @@ -30,6 +30,7 @@ constexpr auto kDataPreprocessParam = "data_preprocess_param"; constexpr auto kRegistry = "registry"; constexpr auto kAclOptionParam = "acl_option_cfg_param"; constexpr auto kMicroParam = "micro_param"; +constexpr auto kThirdPartyModelParam = "third_party_model"; } // namespace int ConfigFileParser::ParseConfigFile(const std::string &config_file_path) { std::map> maps; @@ -93,6 +94,12 @@ int ConfigFileParser::ParseConfigParam(std::maperase(kThirdPartyModelParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ParseTransformQuantString failed."; + return ret; + } for (const auto &config_info : *maps) { ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second); @@ -223,5 +230,25 @@ int ConfigFileParser::ParseMicroParamString(const std::map> §ions) { + if (sections.find(kThirdPartyModelParam) == sections.end()) { + return RET_OK; + } + const auto &input_args = sections.at(kThirdPartyModelParam); + const std::map kValidArgs = { + {"input_shapes", third_party_model_string_.input_shapes}, + {"input_dtypes", third_party_model_string_.input_dtypes}, + {"input_names", third_party_model_string_.input_names}, + {"input_formats", third_party_model_string_.input_formats}, + {"output_shapes", third_party_model_string_.output_shapes}, + {"output_dtypes", third_party_model_string_.output_dtypes}, + {"output_names", third_party_model_string_.output_names}, + {"output_formats", third_party_model_string_.output_formats}, + {"extended_parameters", third_party_model_string_.extended_parameters}, + }; + return SetMapData(input_args, kValidArgs, kThirdPartyModelParam); +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h index 876bb307..6bad9e85 100644 --- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h +++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h @@ -19,6 +19,8 @@ #include #include #include +#include +#include "tools/converter/cxx_api/converter_para.h" namespace mindspore { namespace lite { @@ -86,6 +88,18 @@ struct MicroParamString { std::string enable_micro; }; +struct ThirdPartyModelString { + std::string input_dtypes; + std::string input_shapes; + std::string input_names; // optional, default: "" + std::string input_formats; // optional, default: NHWC + std::string output_dtypes; + std::string output_shapes; + std::string output_names; // optional, default: "" + std::string output_formats; // optional, default: NHWC + std::string 
extended_parameters; // format: {key1:value1;ker2:value2} +}; + class ConfigFileParser { public: int ParseConfigFile(const std::string &config_file_path); @@ -98,6 +112,7 @@ class ConfigFileParser { RegistryInfoString GetRegistryInfoString() const { return this->registry_info_string_; } AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; } MicroParamString GetMicroParamString() { return this->micro_param_string_; } + lite::ThirdPartyModelString GetThirdPartyModelString() const { return this->third_party_model_string_; } private: int ParseDataPreProcessString(const std::map> &maps); @@ -109,6 +124,7 @@ class ConfigFileParser { int SetMapData(const std::map &input_map, const std::map &parse_map, const std::string §ion); int ParseMicroParamString(const std::map> &maps); + int ParseThirdPartyParamString(const std::map> §ions); private: DataPreProcessString data_pre_process_string_; @@ -118,6 +134,7 @@ class ConfigFileParser { RegistryInfoString registry_info_string_; AclOptionCfgString acl_option_cfg_string_; MicroParamString micro_param_string_; + lite::ThirdPartyModelString third_party_model_string_; }; } // namespace lite diff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc new file mode 100644 index 00000000..aee6a29c --- /dev/null +++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc @@ -0,0 +1,299 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/config_parser/third_party_param_parser.h" +#include +#include +#include +#include "include/errorcode.h" +#include "src/common/log_adapter.h" +#include "nnacl/op_base.h" +#include "tools/common/string_util.h" + +namespace mindspore { +namespace lite { +namespace { +const std::map kDataTypeMap = { + {"float64", TypeId::kNumberTypeFloat64}, {"float32", TypeId::kNumberTypeFloat32}, + {"float16", TypeId::kNumberTypeFloat16}, {"int64", TypeId::kNumberTypeInt64}, + {"int32", TypeId::kNumberTypeInt32}, {"int16", TypeId::kNumberTypeInt16}, + {"int8", TypeId::kNumberTypeInt8}, {"uint8", TypeId::kNumberTypeUInt8}, + {"bool", TypeId::kNumberTypeBool}, +}; + +TypeId ConvertDataType(const std::string &type) { + auto iter = kDataTypeMap.find(type); + if (iter == kDataTypeMap.end()) { + return TypeId::kTypeUnknown; + } + return iter->second; +} +} // namespace + +/** + * Parse shapes like "1,256,256,3;3,96;96,96", and return like [[1,256,256,3], [3,96], [96,96]]. 
+ */
+int ThirdPartyParamParser::DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes) {
+  MS_CHECK_TRUE_RET(dst_shapes != nullptr, RET_ERROR);
+  dst_shapes->clear();
+
+  auto tmp_shapes = SplitStringToVector(src, ";");
+  for (auto tmp_shape : tmp_shapes) {
+    auto tmp = SplitStringToVector(tmp_shape, ",");
+    std::vector<int64_t> shape = {};
+    for (auto t : tmp) {
+      int value = 0;
+      if (!ConvertIntNum(t, &value)) {
+        MS_LOG(ERROR) << "Found error when converting shape string to integer";
+        return RET_ERROR;
+      }
+      if (value <= 0) {  // Valid shape value should be greater than 0.
+        MS_LOG(ERROR) << "Only support fixed shapes in third party param";
+        return RET_ERROR;
+      }
+      shape.push_back(value);
+    }
+    dst_shapes->push_back(shape);
+  }
+  return RET_OK;
+}
+
+/**
+ * Parse extended parameter like "key_1:value_1;key_2:value_2" and get {{"key_1", "value_1"}, {"key_2", "value_2"}}.
+ */
+int ThirdPartyParamParser::DoParseExtendedParameters(const std::string &src,
+                                                     std::map<std::string, std::vector<uint8_t>> *dst_ext_param) {
+  MS_CHECK_TRUE_RET(dst_ext_param != nullptr, RET_ERROR);
+  constexpr size_t kKeyIndex = 0U;
+  constexpr size_t kValueIndex = 1U;
+  constexpr size_t kKeyValueSize = 2U;
+
+  if (src == "") {  // Just return if 'extended_parameters' is not configured.
+    return RET_OK;
+  }
+
+  auto tmp_list = SplitStringToVector(src, ";");
+  std::map<std::string, std::vector<uint8_t>> tmp_map = {};
+  for (auto tmp : tmp_list) {
+    auto key_and_value = SplitStringToVector(tmp, ":");
+    if (key_and_value.size() != kKeyValueSize) {
+      MS_LOG(ERROR) << "Parse extended parameters failed, should keep key:value format";
+      return RET_ERROR;
+    }
+    auto key = key_and_value[kKeyIndex];
+    auto value = key_and_value[kValueIndex];
+    if (tmp_map.find(key) != tmp_map.end()) {
+      MS_LOG(ERROR) << "Parse extended parameters failed, key should not be duplicated";
+      return RET_ERROR;
+    }
+    tmp_map.emplace(key, std::vector<uint8_t>(value.begin(), value.end()));
+  }
+
+  *dst_ext_param = tmp_map;
+  return RET_OK;
+}
+
+/**
+ * Parse dtypes like "float32;float32;int32" and return [kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32]
+ */
+int ThirdPartyParamParser::DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes) {
+  MS_CHECK_TRUE_RET(dst_dtypes != nullptr, RET_ERROR);
+  dst_dtypes->clear();
+  auto tmp_dtypes = SplitStringToVector(src, ";");
+  for (auto tmp_dtype : tmp_dtypes) {
+    TypeId type = ConvertDataType(tmp_dtype);
+    if (type == kTypeUnknown) {
+      MS_LOG(ERROR) << "Parse dtypes in third party model config failed";
+      return RET_ERROR;
+    }
+    dst_dtypes->push_back(type);
+  }
+  return RET_OK;
+}
+
+/**
+ * Parse names like "foo;bar;boo" and get ["foo", "bar", "boo"]
+ * If input names are not provided in config, use the default prefix to generate like: "in_0;in_1;..;in_n"
+ */
+int ThirdPartyParamParser::DoParseNames(const std::string &src, size_t num, const std::string &default_prefix,
+                                        std::vector<std::string> *dst_names) {
+  MS_CHECK_TRUE_RET(dst_names != nullptr, RET_ERROR);
+  std::string tmp_names = src;
+  if (tmp_names.empty()) {
+    std::string tmp = "";
+    for (size_t i = 0; i < num; i++) {
+      tmp += default_prefix + "_" + std::to_string(i);
+      if (i + 1 < num) {
+        tmp += ";";
+      }
+    }
+    tmp_names = tmp;
+  }
+
+  *dst_names = SplitStringToVector(tmp_names, ";");
+  if (dst_names->size() != num) {
+    MS_LOG(ERROR) << "Name number " << dst_names->size() << " and input number: " << num << " are not equal";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+/**
+ * Parse formats like "NCHW;NHWC" and get [NCHW, NHWC]
+ */
+namespace {
+  int StringToFormat(const std::string &format_string, schema::Format *format) {
+    static const std::unordered_map<std::string, schema::Format> kFormatTable = {
+      {"NCHW", schema::Format::Format_NCHW},
+      {"NHWC", schema::Format::Format_NHWC},
+      {"NHWC4", schema::Format::Format_NHWC4},
+      {"HWKC", schema::Format::Format_HWKC},
+      {"HWCK", schema::Format::Format_HWCK},
+      {"KCHW", schema::Format::Format_KCHW},
+      {"CKHW", schema::Format::Format_CKHW},
+      {"KHWC", schema::Format::Format_KHWC},
+      {"CHWK", schema::Format::Format_CHWK},
+      {"HW", schema::Format::Format_HW},
+      {"HW4", schema::Format::Format_HW4},
+      {"NC", schema::Format::Format_NC},
+      {"NC4", schema::Format::Format_NC4},
+      {"NC4HW4", schema::Format::Format_NC4HW4},
+      {"NUM_OF_FORMAT", schema::Format::Format_NUM_OF_FORMAT},
+      {"NCDHW", schema::Format::Format_NCDHW},
+      {"NWC", schema::Format::Format_NWC},
+      {"NCW", schema::Format::Format_NCW},
+    };
+
+    if (format == nullptr) {
+      return RET_NULL_PTR;
+    }
+
+    auto iter = kFormatTable.find(format_string);
+    if (iter == kFormatTable.end()) {
+      return RET_PARAM_INVALID;
+    }
+
+    *format = iter->second;
+    return RET_OK;
+  }
+}
+
+int ThirdPartyParamParser::DoParseFormats(const std::string &src, size_t num,
+                                          std::vector<schema::Format> *result_formats) {
+  MS_CHECK_TRUE_RET(result_formats != nullptr, RET_ERROR);
+  std::string tmp_names = src;
+  if (tmp_names.empty()) {
+    std::vector<schema::Format> default_formats(num, schema::Format::Format_NHWC);
+    *result_formats = default_formats;
+    return RET_OK;
+  }
+
+  auto format_strings = SplitStringToVector(tmp_names, ";");
+  if (format_strings.size() != num) {
+    MS_LOG(ERROR) << "Number of format: " << format_strings.size() << " and number of tensor: " << num
+                  << " are not equal";
+    return RET_ERROR;
+  }
+
+  std::vector<schema::Format> result(num);
+  for (size_t i = 0; i < num; i++) {
+    if (StringToFormat(format_strings[i], &result[i]) != RET_OK) {
+      MS_LOG(ERROR) << "Tensor format:" << format_strings[i] << " is invalid";
+      return RET_PARAM_INVALID;
+    }
+  }
+  *result_formats = result;
+  return RET_OK;
+}
+
+int ThirdPartyParamParser::Parse(const ThirdPartyModelString &param_string, ThirdPartyModelParam *param) {
+  MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR);
+
+  auto ret = DoParseShape(param_string.input_shapes, &(param->input_shapes));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse input shapes of third party param failed";
+    return RET_ERROR;
+  }
+
+  ret = DoParseDtypes(param_string.input_dtypes, &(param->input_dtypes));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse input dtypes of third party param failed";
+    return RET_ERROR;
+  }
+
+  auto input_shape_num = param->input_shapes.size();
+  auto input_dtype_num = param->input_dtypes.size();
+  if (input_shape_num != input_dtype_num) {
+    MS_LOG(ERROR) << "Input shape number: " << input_shape_num << " and dtype number: " << input_dtype_num
+                  << " are not equal";
+    return RET_ERROR;
+  }
+
+  ret = DoParseFormats(param_string.input_formats, input_shape_num, &(param->input_formats));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse input formats of third party param failed";
+    return RET_ERROR;
+  }
+
+  const std::string kInputNamePrefix = "in";
+  ret = DoParseNames(param_string.input_names, input_shape_num, kInputNamePrefix, &(param->input_names));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse input names of third party param failed";
+    return RET_ERROR;
+  }
+
+  ret = DoParseShape(param_string.output_shapes, &(param->output_shapes));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse output shapes of third party param failed";
+    return RET_ERROR;
+  }
+
+  ret = DoParseDtypes(param_string.output_dtypes, &(param->output_dtypes));
+  if (ret != RET_OK)
{ + MS_LOG(ERROR) << "Parse output dtypes of third party param failed"; + return RET_ERROR; + } + + auto output_shape_num = param->output_shapes.size(); + auto output_dtype_num = param->output_dtypes.size(); + if (output_shape_num != output_dtype_num) { + MS_LOG(ERROR) << "Output shape number: " << output_shape_num << " and dtype number: " << output_dtype_num + << " are not equal"; + return RET_ERROR; + } + + ret = DoParseFormats(param_string.output_formats, output_shape_num, &(param->output_formats)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse output formats of third party param failed"; + return RET_ERROR; + } + + const std::string kOutputNamePrefix = "out"; + ret = DoParseNames(param_string.output_names, output_shape_num, kOutputNamePrefix, &(param->output_names)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse output names of third party param failed"; + return RET_ERROR; + } + + ret = DoParseExtendedParameters(param_string.extended_parameters, &(param->extended_parameters)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse extended parameter of third party param failed"; + return RET_ERROR; + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h new file mode 100644 index 00000000..5cf6e8fb --- /dev/null +++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h @@ -0,0 +1,44 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ +#include +#include +#include +#include "include/errorcode.h" +#include "tools/converter/cxx_api/converter_para.h" +#include "tools/converter/config_parser/config_file_parser.h" + +namespace mindspore { +namespace lite { +class ThirdPartyParamParser { + public: + static int Parse(const lite::ThirdPartyModelString ¶m_string, ThirdPartyModelParam *param); + + private: + static int DoParseShape(const std::string &src, std::vector> *dst_shapes); + static int DoParseExtendedParameters(const std::string &src, + std::map> *dst_ext_param); + static int DoParseDtypes(const std::string &src, std::vector *dst_dtypes); + static int DoParseNames(const std::string &src, size_t num, const std::string &default_prefix, + std::vector *dst_names); + static int DoParseFormats(const std::string &src, size_t num, std::vector *result_formats); +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index f3d4d658..449c6ef9 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -44,6 +44,7 @@ #include "tools/converter/config_parser/micro_param_parser.h" #include "tools/converter/config_parser/preprocess_parser.h" #include "tools/converter/config_parser/quant_param_parser.h" +#include "tools/converter/config_parser/third_party_param_parser.h" #include "tools/common/string_util.h" #include "src/common/file_utils.h" @@ -89,6 +90,7 @@ FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr converter_parameters.fmk = param->fmk_type; converter_parameters.model_file = param->model_file; converter_parameters.weight_file = param->weight_file; + converter_parameters.attrs.emplace("config_file", param->config_file); func_graph_base = model_parser_->Parse(converter_parameters); } if (func_graph_base == nullptr) { @@ -96,6 +98,7 @@ FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NOT_SUPPORT); return nullptr; } + auto func_graph = ConvertGraph(func_graph_base); if (func_graph == nullptr) { MS_LOG(ERROR) << "func graph is invalid."; @@ -137,9 +140,13 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr return nullptr; } MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed."); - // funcgraph_transform - graph = funcgraph_transform_->Transform(graph, param); - MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Transform anf graph return nullptr."); + + if (param->fmk_type != converter::FmkType::kFmkTypeThirdParty) { + // funcgraph_transform + graph = funcgraph_transform_->Transform(graph, param); + MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Transform anf graph return nullptr."); + } + // export protobuf auto status = MindIRSerialize(param, graph); if (status != RET_OK) { @@ -186,11 +193,14 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr } MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed"); - // funcgraph transform - graph = funcgraph_transform_->Transform(graph, param); - if (graph == nullptr) { - MS_LOG(ERROR) << "Transform anf graph return nullptr"; - return nullptr; + + if (param->fmk_type != 
converter::FmkType::kFmkTypeThirdParty) { + // funcgraph transform + graph = funcgraph_transform_->Transform(graph, param); + if (graph == nullptr) { + MS_LOG(ERROR) << "Transform anf graph return nullptr"; + return nullptr; + } } // export protobuf @@ -354,6 +364,12 @@ int ConverterImpl::InitConfigParam(const std::shared_ptr ¶m) MS_LOG(ERROR) << "Parse mixed bit weight quant param failed."; return ret; } + ret = lite::ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), + ¶m->thirdPartyModelParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse third party param failed."; + return ret; + } ret = InitExtendedIntegrationInfo(param, config_parser); if (ret != RET_OK) { MS_LOG(ERROR) << "Parse extended integration info failed."; @@ -535,17 +551,19 @@ std::string ConverterImpl::GetStrFromConfigFile(const std::string &file, const s int CheckFmkType(const std::shared_ptr ¶m) { if (param != nullptr) { - std::set valid_values = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, - FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch}; - if (std::find(valid_values.begin(), valid_values.end(), param->fmk_type) == valid_values.end()) { - MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be kFmkTypeTf|kFmkTypeCaffe|kFmkTypeOnnx|kFmkTypeMs|kFmkTypeTflite" - << ", but got " << param->fmk_type; - return RET_INPUT_PARAM_INVALID; - } - if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) { - MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag"; - return RET_INPUT_PARAM_INVALID; - } + return RET_OK; + } + const std::set kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, + FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch, + FmkType::kFmkTypeThirdParty}; + if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) { + MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|THIRDPARTY" + << ", but got " << param->fmk_type; + return RET_INPUT_PARAM_INVALID; + } + if ((param->fmk_type != converter::kFmkTypeCaffe) && (!param->weight_file.empty())) { + MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag"; + return RET_INPUT_PARAM_INVALID; } return RET_OK; } @@ -594,7 +612,7 @@ int CheckInputShape(const std::shared_ptr ¶m) { bool has_negative_dim = std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim < 0; }); if (has_negative_dim) { MS_LOG(ERROR) << "INPUT ILLEGAL: Unsupported dim < 0."; - return lite::RET_ERROR; + return lite::RET_INPUT_PARAM_INVALID; } } } diff --git a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc index 033db968..595b59ed 100644 --- a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc @@ -118,13 +118,13 @@ int Flags::InitInputOutputDataType() { int Flags::InitFmk() { // value check not here, it is in converter c++ API's CheckValueParam method. 
- std::map StrToEnumFmkTypeMap = {{"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs}, - {"TFLITE", kFmkTypeTflite}, {"ONNX", kFmkTypeOnnx}, - {"TF", kFmkTypeTf}, {"PYTORCH", kFmkTypePytorch}}; + std::map StrToEnumFmkTypeMap = { + {"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs}, {"TFLITE", kFmkTypeTflite}, {"ONNX", kFmkTypeOnnx}, + {"TF", kFmkTypeTf}, {"PYTORCH", kFmkTypePytorch}, {"THIRDPARTY", kFmkTypeThirdParty}}; if (StrToEnumFmkTypeMap.find(this->fmkIn) != StrToEnumFmkTypeMap.end()) { this->fmk = StrToEnumFmkTypeMap.at(this->fmkIn); } else { - std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl; + std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX|PYTORCH|THIRDPARTY" << std::endl; return RET_INPUT_PARAM_INVALID; } diff --git a/mindspore/lite/tools/converter/cxx_api/converter_para.h b/mindspore/lite/tools/converter/cxx_api/converter_para.h index 58bc4c7c..00b7fa3c 100644 --- a/mindspore/lite/tools/converter/cxx_api/converter_para.h +++ b/mindspore/lite/tools/converter/cxx_api/converter_para.h @@ -21,6 +21,7 @@ #include #include #include "include/converter.h" +#include "mindapi/base/type_id.h" #include "tools/converter/quantizer/quant_params.h" #include "tools/converter/preprocess/preprocess_param.h" #include "tools/converter/adapter/acl/common/acl_types.h" @@ -35,6 +36,18 @@ struct ParallelSplitConfig { std::vector parallel_devices_; }; +struct ThirdPartyModelParam { + std::vector input_dtypes; + std::vector> input_shapes; + std::vector input_names; + std::vector input_formats; + std::vector output_dtypes; + std::vector> output_shapes; + std::vector output_names; + std::vector output_formats; + std::map> extended_parameters; +}; + struct ConverterPara { converter::FmkType fmk_type; std::string model_file; @@ -68,6 +81,7 @@ struct ConverterPara { lite::acl::AclModelOptionCfg aclModelOptionCfgParam; lite::micro::MicroParam microParam; ParallelSplitConfig parallel_split_config; + ThirdPartyModelParam thirdPartyModelParam; }; } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_ diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 538b1ab1..7361204d 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -92,10 +92,54 @@ int QuantTransform(const std::shared_ptr ¶m, schema::MetaGrap } return RET_OK; } + +int FillGraphOutputShape(MetaGraphT *meta_graph, const std::vector> output_shapes) { + const auto &out_indices = meta_graph->outputIndex; + for (size_t i = 0; i < out_indices.size(); i++) { + auto &out_tensor = meta_graph->allTensors[out_indices[i]]; + out_tensor->dims = {}; + for (size_t k = 0; k < output_shapes[i].size(); k++) { + out_tensor->dims.push_back(static_cast(output_shapes[i][k])); + } + } + return RET_OK; +} + +void FillGraphInputAndOutputFormats(MetaGraphT *meta_graph, const ConverterPara ¶) { + const auto &in_indices = meta_graph->inputIndex; + for (size_t i = 0; i < in_indices.size(); i++) { + auto &in_tensor = meta_graph->allTensors[in_indices[i]]; + in_tensor->format = para.thirdPartyModelParam.input_formats[i]; + MS_LOG_DEBUG << "input " << i << " format: " << EnumNameFormat(in_tensor->format); + } + + const auto &out_indices = meta_graph->outputIndex; + for (size_t i = 0; i < out_indices.size(); i++) { + auto &out_tensor = meta_graph->allTensors[out_indices[i]]; + out_tensor->format = para.thirdPartyModelParam.output_formats[i]; + MS_LOG_DEBUG 
<< "output " << i << " format: " << EnumNameFormat(out_tensor->format); + } +} } // namespace int GraphDefTransform::Transform(const std::shared_ptr ¶m) { STATUS status; + + if (param->fmk_type == converter::kFmkTypeThirdParty) { + + // Legacy optimizer infer shape, but op Custom which wraps third party model has no infer-shape function. + // So we don't perform legacy optimization for kFmkTypeThirdParty case. + auto ret = FillGraphOutputShape(graph_defT_, param->thirdPartyModelParam.output_shapes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Fill output shape of third party model failed, ret:" << ret; + return ret; + } + + // Tensor of FuncGraph has no attribute of format, so set format in MetaGraph. + FillGraphInputAndOutputFormats(graph_defT_, *param); + return RET_OK; + } + { auto old_nodes = GetGraphNodes(*graph_defT_); Optimizer unused_op_remove_optimizer; diff --git a/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt new file mode 100644 index 00000000..b55e0194 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt @@ -0,0 +1,4 @@ +add_library(third_party_parser_mid OBJECT third_party_model_parser.cc) +add_dependencies(third_party_parser_mid proto_mid) +add_dependencies(third_party_parser_mid fbs_src) +add_dependencies(third_party_parser_mid fbs_inner_src) \ No newline at end of file diff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc new file mode 100644 index 00000000..652db4af --- /dev/null +++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc @@ -0,0 +1,277 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/converter/parser/third_party/third_party_model_parser.h" +#include +#include +#include +#include "ir/value.h" +#include "mindapi/base/type_id.h" +#include "src/common/log_util.h" +#include "src/common/file_utils.h" +#include "nnacl/op_base.h" +#include "ops/primitive_c.h" +#include "ops/custom.h" +#include "ops/tuple_get_item.h" +#include "ops/make_tuple.h" +#include "ops/return.h" +#include "tools/converter/config_parser/config_file_parser.h" +#include "include/registry/model_parser_registry.h" +#include "tools/common/graph_util.h" +#include "tools/common/tensor_util.h" +#include "tools/converter/converter_context.h" +#include "tools/converter/parser/lite_model_parser_creator.h" + +using mindspore::converter::kFmkTypeThirdParty; + +namespace mindspore { +namespace lite { +api::FuncGraphPtr ThirdPartyModelParser::Parse(const converter::ConverterParameters &flag) { + model_file_ = flag.model_file; + auto &attrs = flag.attrs; + auto iter = attrs.find("config_file"); + if (iter == attrs.end()) { + return nullptr; + } + auto config_file = iter->second; + + auto ret = InitConfig(config_file); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init config for third party model parsing failed"; + return nullptr; + } + + return CreateFuncGraph(); +} + +STATUS ThirdPartyModelParser::InitConfig(const std::string &config_file) { + lite::ConfigFileParser config_parser; + if (config_file.empty()) { + MS_LOG(ERROR) << "Missing config file in converting third party model"; + return RET_ERROR; + } + auto ret = config_parser.ParseConfigFile(config_file); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Get third party model section from config file failed"; + return RET_ERROR; + } + + ret = ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), ¶m_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse third party model param failed."; + return ret; + } + return RET_OK; +} + +api::FuncGraphPtr ThirdPartyModelParser::CreateFuncGraph() { + auto func_graph = std::make_shared(); + MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); + auto type_value = MakeValue(static_cast(converter::kFmkTypeThirdParty)); + MS_CHECK_TRUE_RET(type_value != nullptr, nullptr); + func_graph->set_attr("fmk", type_value); + auto attr_value = MakeValue("third_party"); + MS_CHECK_TRUE_RET(attr_value != nullptr, nullptr); + func_graph->set_attr("graph_name", attr_value); + + std::vector input_nodes = {}; + auto ret = BuildGraphInputs(func_graph, &input_nodes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Create func graph input nodes failed"; + return nullptr; + } + + CNodePtr custom_node = nullptr; + ret = BuildCustomOp(func_graph, input_nodes, &custom_node); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Create func graph custom op node failed"; + return nullptr; + } + + ret = BuildGraphOutputs(func_graph, custom_node); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Create func graph output nodes failed"; + return nullptr; + } + + static auto manager = Manage(func_graph); + func_graph->set_manager(manager); + + auto result_graph = api::MakeShared(func_graph); + return result_graph; +} + +STATUS ThirdPartyModelParser::BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector *op_inputs) { + MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr); + auto &dtypes = param_.input_dtypes; + auto &shapes = param_.input_shapes; + auto &names = param_.input_names; + + auto input_size = dtypes.size(); + + // Create parameter nodes for graph inputs + for (size_t i = 0; i < input_size; i++) { + auto parameter = 
func_graph->add_parameter(); + MSLITE_CHECK_PTR(parameter); + auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstract failed"; + return RET_ERROR; + } + parameter->set_abstract(abstract_tensor); + parameter->set_name(names[i]); + op_inputs->push_back(parameter); + } + + // Create parameter nodes for const tensor which wrapped third model buffer. + size_t model_size = 0U; + auto model_data = ReadFile(model_file_.c_str(), &model_size); + std::vector model_shape = {static_cast(model_size)}; + auto tensor_info = CreateTensorInfo(nullptr, 0, model_shape, kNumberTypeUInt8); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "init tensor info failed"; + delete model_data; + return RET_NULL_PTR; + } + auto tensor_data = reinterpret_cast(tensor_info->data_c()); + if (memcpy_s(tensor_data, tensor_info->Size(), model_data, model_size) != EOK) { + MS_LOG(ERROR) << "memcpy failed."; + delete model_data; + return RET_ERROR; + } + delete model_data; + auto parameter = func_graph->add_parameter(); + MSLITE_CHECK_PTR(parameter); + auto status = InitParameterFromTensorInfo(parameter, tensor_info); + if (status != RET_OK) { + MS_LOG(ERROR) << "init parameter from tensor info failed."; + return RET_ERROR; + } + parameter->set_name("ThirdPartyModel"); + op_inputs->push_back(parameter); + return RET_OK; +} + +STATUS ThirdPartyModelParser::BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector &op_inputs, + CNodePtr *operator_node) { + MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr); + NotSupportOp::GetInstance()->set_fmk_type("THIRDPARTY"); + STATUS status = RET_OK; + + // create primitive and build CNode of CUSTOM operator + ops::PrimitiveCPtr primitive_c; + auto prim = std::make_unique(); + MS_CHECK_TRUE_RET(prim != nullptr, RET_ERROR); + prim->set_type("ThirdPartyModel"); + + const auto &attr = param_.extended_parameters; + prim->set_attr(attr); + primitive_c = prim->GetPrim(); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "failed to create primitive: custom"; + return RET_ERROR; + } + + auto operator_cnode = func_graph->NewCNode(primitive_c, op_inputs); + MSLITE_CHECK_PTR(operator_cnode); + operator_cnode->set_fullname_with_scope("Custom"); + *operator_node = operator_cnode; + return status; +} + +STATUS ThirdPartyModelParser::BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node) { + MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr); + + auto dtypes = param_.output_dtypes; + auto shapes = param_.output_shapes; + auto names = param_.output_names; + + auto output_size = dtypes.size(); + std::vector output_nodes = {}; + + // Use TupleGetItem to wrap op outputs. 
+  AbstractBasePtrList abstract_list;
+  for (size_t i = 0; i < output_size; i++) {
+    auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]);
+    if (abstract_tensor == nullptr) {
+      MS_LOG(ERROR) << "Create tensor abstract failed";
+      return RET_ERROR;
+    }
+    abstract_list.emplace_back(abstract_tensor);
+    auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
+    if (tuple_get_item_prim_ptr == nullptr) {
+      MS_LOG(ERROR) << "new TupleGetItem failed";
+      return RET_NULL_PTR;
+    }
+    auto tuple_get_item_prim_c = tuple_get_item_prim_ptr->GetPrim();
+    MSLITE_CHECK_PTR(tuple_get_item_prim_c);
+    auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_c);
+    MSLITE_CHECK_PTR(tuple_get_item_prim);
+    auto get_item_value = NewValueNode(MakeValue(i));
+    MSLITE_CHECK_PTR(get_item_value);
+    std::vector<AnfNodePtr> inputs = {tuple_get_item_prim, operator_node, get_item_value};
+    CNodePtr get_item_cnode = func_graph->NewCNode(inputs);
+    MSLITE_CHECK_PTR(get_item_cnode);
+    std::string output_item_name = operator_node->fullname_with_scope() + "_getitem_" + std::to_string(i);
+    auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
+    if (get_item_abstract == nullptr) {
+      MS_LOG(ERROR) << "Create tensor abstract failed";
+      return RET_ERROR;
+    }
+    get_item_cnode->set_fullname_with_scope(output_item_name);
+    get_item_cnode->set_abstract(get_item_abstract);
+    output_nodes.push_back(get_item_cnode);
+  }
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  MSLITE_CHECK_PTR(abstract_tuple);
+  operator_node->set_abstract(abstract_tuple);
+
+  // Use MakeTuple node to wrap all outputs as single input of Return node.
+  auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
+  if (make_tuple_prim_ptr == nullptr) {
+    MS_LOG(ERROR) << "new MakeTuple failed";
+    return RET_NULL_PTR;
+  }
+  auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim();
+  MSLITE_CHECK_PTR(make_tuple_prim_c);
+  auto make_tuple_prim = NewValueNode(make_tuple_prim_c);
+  MSLITE_CHECK_PTR(make_tuple_prim);
+  std::vector<AnfNodePtr> make_tuple_inputs = output_nodes;
+  make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim);
+  auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs);
+  MSLITE_CHECK_PTR(make_tuple_cnode);
+  make_tuple_cnode->set_fullname_with_scope("return_tuple");
+
+  auto return_prim_ptr = std::make_shared<ops::Return>();
+  if (return_prim_ptr == nullptr) {
+    MS_LOG(ERROR) << "new Return failed";
+    return RET_NULL_PTR;
+  }
+  auto return_prim_c = return_prim_ptr->GetPrim();
+  MSLITE_CHECK_PTR(return_prim_c);
+  std::vector<AnfNodePtr> op_inputs{make_tuple_cnode};
+  auto cnode = func_graph->NewCNode(return_prim_c, op_inputs);
+  MSLITE_CHECK_PTR(cnode);
+  cnode->set_fullname_with_scope("Return");
+  func_graph->set_return(cnode);
+
+  // Save original output tensor names.
+  ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(names);
+  return RET_OK;
+}
+
+REG_MODEL_PARSER(kFmkTypeThirdParty, LiteModelParserCreator<ThirdPartyModelParser>)
+} // namespace lite
+} // namespace mindspore
diff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h
new file mode 100644
index 00000000..c4b197b8
--- /dev/null
+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ + +#include +#include +#include "schema/inner/model_generated.h" +#include "base/base.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "include/errorcode.h" +#include "include/registry/model_parser.h" +#include "tools/converter/config_parser/third_party_param_parser.h" + +namespace mindspore { +namespace lite { +class ThirdPartyModelParser : public converter::ModelParser { + public: + api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override; + + private: + STATUS InitConfig(const std::string &config_file); + api::FuncGraphPtr CreateFuncGraph(); + STATUS BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector *op_inputs); + STATUS BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector &op_inputs, + CNodePtr *operator_node); + STATUS BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node); + + std::string model_file_ = ""; + ThirdPartyModelParam param_; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ diff --git a/mindspore/lite/tools/converter/registry/model_parser_registry.cc b/mindspore/lite/tools/converter/registry/model_parser_registry.cc index bbdafd96..c6337ea4 100644 --- a/mindspore/lite/tools/converter/registry/model_parser_registry.cc +++ b/mindspore/lite/tools/converter/registry/model_parser_registry.cc @@ -26,7 +26,7 @@ std::map model_parser_room; } // namespace ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) { - if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) { + if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) { MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; return; } @@ -34,7 +34,7 @@ ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator } converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) { - if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) { + if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) { MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; return nullptr; } -- 2.34.1
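Usage sketch (not part of the patch): the change adds a THIRDPARTY fmk type that wraps an opaque vendor model in a single Custom op, driven by the [third_party_model] section of the converter config. A minimal end-to-end run could look like the following; the file names and the compiler_flag key are illustrative, and it assumes the existing converter_lite flags (--fmk, --modelFile, --outputFile, --configFile).

    # third_party.cfg, consumed by ConfigFileParser/ThirdPartyParamParser
    [third_party_model]
    input_names=in_0
    input_dtypes=float32
    input_shapes=1,224,224,3
    output_names=out_0
    output_dtypes=float32
    output_shapes=1,1000
    # optional; both default to NHWC when omitted
    input_formats=NHWC
    output_formats=NHWC
    extended_parameters=compiler_flag:O2

    # wrap vendor_model.bin into a .ms file; the legacy graph optimizer is skipped for THIRDPARTY
    ./converter_lite --fmk=THIRDPARTY --modelFile=vendor_model.bin --outputFile=demo --configFile=third_party.cfg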