1From c2f6e9f2c5cc9ec0260437a1dfd70b5d535ee740 Mon Sep 17 00:00:00 2001 2From: Zhu Guodong <zhuguodong0001@163.com> 3Date: Thu, 1 Jun 2023 22:02:23 +0800 4Subject: [PATCH] auto-apply 5 0006-Support-converting-THIRDPARTY-model-in-MSLite.patch 6 7--- 8 .../lite/include/registry/converter_context.h | 2 + 9 mindspore/lite/test/CMakeLists.txt | 4 +- 10 mindspore/lite/test/runtest.sh | 6 +- 11 .../test/ut/test_data/third_party_model.cfg | 8 + 12 .../tools/converter/api/converter_api_test.cc | 28 ++ 13 .../third_party_param_parser_test.cc | 176 +++++++++++ 14 mindspore/lite/tools/converter/CMakeLists.txt | 4 + 15 .../config_parser/config_file_parser.cc | 27 ++ 16 .../config_parser/config_file_parser.h | 17 + 17 .../config_parser/third_party_param_parser.cc | 299 ++++++++++++++++++ 18 .../config_parser/third_party_param_parser.h | 44 +++ 19 mindspore/lite/tools/converter/converter.cc | 58 ++-- 20 .../converter_lite/converter_flags.cc | 8 +- 21 .../tools/converter/cxx_api/converter_para.h | 14 + 22 .../tools/converter/graphdef_transform.cc | 44 +++ 23 .../parser/third_party/CMakeLists.txt | 4 + 24 .../third_party/third_party_model_parser.cc | 277 ++++++++++++++++ 25 .../third_party/third_party_model_parser.h | 50 +++ 26 .../registry/model_parser_registry.cc | 4 +- 27 19 files changed, 1045 insertions(+), 29 deletions(-) 28 create mode 100644 mindspore/lite/test/ut/test_data/third_party_model.cfg 29 create mode 100644 mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc 30 create mode 100644 mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc 31 create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc 32 create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.h 33 create mode 100644 mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt 34 create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 35 create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h 36 37diff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h 38index a92a3a34..dd6e6d08 100644 39--- a/mindspore/lite/include/registry/converter_context.h 40+++ b/mindspore/lite/include/registry/converter_context.h 41@@ -33,6 +33,8 @@ enum MS_API FmkType : int { 42 kFmkTypeMs = 3, 43 kFmkTypeTflite = 4, 44 kFmkTypePytorch = 5, 45+ kFmkTypeThirdParty = 6, 46+ kFmkTypeEnd = 7, // For range check purpose, valid range: [0, kFmkTypeEnd) 47 }; 48 49 /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser. 50diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt 51index c0ba8e39..5fa7bea0 100644 52--- a/mindspore/lite/test/CMakeLists.txt 53+++ b/mindspore/lite/test/CMakeLists.txt 54@@ -120,6 +120,8 @@ if(MSLITE_ENABLE_CONVERTER) 55 file(GLOB_RECURSE TEST_CONVERTER_UT_SRC 56 ${TEST_DIR}/ut/tools/converter/registry/*.cc 57 ${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc 58+ ${TEST_DIR}/ut/tools/converter/api/*.cc 59+ ${TEST_DIR}/ut/tools/converter/config_parser/*.cc 60 ${TEST_DIR}/st/converter_test.cc 61 ${TEST_DIR}/st/delegate_test.cc 62 ${TEST_DIR}/st/mindrt_parallel_test.cc 63@@ -234,7 +236,7 @@ endif() 64 65 if(MSLITE_ENABLE_CONVERTER) 66 target_link_libraries(lite-test-converter tflite_parser_mid caffe_parser_mid 67- onnx_parser_mid tf_parser_mid) 68+ onnx_parser_mid tf_parser_mid third_party_parser_mid) 69 endif() 70 71 if(ENABLE_MODEL_OBF) 72diff --git a/mindspore/lite/test/runtest.sh b/mindspore/lite/test/runtest.sh 73index 57c9c0aa..921d3fb3 100644 74--- a/mindspore/lite/test/runtest.sh 75+++ b/mindspore/lite/test/runtest.sh 76@@ -61,10 +61,12 @@ echo 'run common ut tests' 77 # test cases of INT8 OP 78 ## ./lite-test --gtest_filter=TestPadInt8.* 79 ./lite-test --gtest_filter=TestDeconvInt8.* 80-if [ "$ENABLE_CONVERTER_TEST" = true ];then 81+if [ "$ENABLE_CONVERTER_TEST" = true ]; then 82 ./lite-test-converter --gtest_filter="ModelParserRegistryTest.TestRegistry" 83 ./lite-test-converter --gtest_filter="NodeParserRegistryTest.TestRegistry" 84 ./lite-test-converter --gtest_filter="PassRegistryTest.TestRegistry" 85+ ./lite-test-converter --gtest_filter="TestConverterAPI.*" 86+ ./lite-test-converter --gtest_filter="TestThirdPartyParamParser.*" 87 fi 88 ./lite-test --gtest_filter="TestRegistry.TestAdd" 89 ./lite-test --gtest_filter="TestRegistryCustomOp.TestCustomAdd" 90@@ -87,7 +89,7 @@ echo 'run inference ut tests' 91 ./lite-test --gtest_filter="ControlFlowTest.TestMergeWhileModel" 92 93 echo 'run mindrt parallel ut test' 94-if [ "$ENABLE_CONVERTER_TEST" = true ];then 95+if [ "$ENABLE_CONVERTER_TEST" = true ]; then 96 ./lite-test-converter --gtest_filter="MindrtParallelTest.*" 97 echo 'user set output tensors st test' 98 ./lite-test --gtest_filter="GraphTest.UserSetGraphOutput*" 99diff --git a/mindspore/lite/test/ut/test_data/third_party_model.cfg b/mindspore/lite/test/ut/test_data/third_party_model.cfg 100new file mode 100644 101index 00000000..b5fcba75 102--- /dev/null 103+++ b/mindspore/lite/test/ut/test_data/third_party_model.cfg 104@@ -0,0 +1,8 @@ 105+[third_party_model] 106+input_names=demo_in_0;demo_in_1;demo_in_2 107+input_dtypes=float32;float16;float64 108+input_shapes=1;2,3;4,5,6 109+output_names=demo_out_0;demo_out_1;demo_out_2;demo_out_4 110+output_dtypes=int32;int16;int8;uint8 111+output_shapes=10;20,30;40;50,60,70 112+extended_parameters=foo:foo_value;bar:bar_value 113diff --git a/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc 114new file mode 100644 115index 00000000..0d434575 116--- /dev/null 117+++ b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc 118@@ -0,0 +1,28 @@ 119+/** 120+ * Copyright 2023 Huawei Technologies Co., Ltd 121+ * 122+ * Licensed under the Apache License, Version 2.0 (the "License"); 123+ * you may not use this file except in compliance with the License. 124+ * You may obtain a copy of the License at 125+ * 126+ * http://www.apache.org/licenses/LICENSE-2.0 127+ * 128+ * Unless required by applicable law or agreed to in writing, software 129+ * distributed under the License is distributed on an "AS IS" BASIS, 130+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 131+ * See the License for the specific language governing permissions and 132+ * limitations under the License. 133+ */ 134+ 135+#include "gtest/gtest.h" 136+#include "include/converter.h" 137+ 138+TEST(TestConverterAPI, ConvertThirdParty) { 139+ std::string third_party_model = "./relu.mindir"; 140+ std::string config_model = "./third_party_model.cfg"; 141+ std::string output_model = "./demo_third_party.ms"; 142+ 143+ mindspore::Converter converter(mindspore::converter::FmkType::kFmkTypeThirdParty, third_party_model, output_model); 144+ converter.SetConfigFile(config_model); 145+ ASSERT_TRUE(converter.Convert().IsOk()); 146+} 147diff --git a/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc 148new file mode 100644 149index 00000000..c8eb5536 150--- /dev/null 151+++ b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc 152@@ -0,0 +1,176 @@ 153+/** 154+ * Copyright 2023 Huawei Technologies Co., Ltd 155+ * 156+ * Licensed under the Apache License, Version 2.0 (the "License"); 157+ * you may not use this file except in compliance with the License. 158+ * You may obtain a copy of the License at 159+ * 160+ * http://www.apache.org/licenses/LICENSE-2.0 161+ * 162+ * Unless required by applicable law or agreed to in writing, software 163+ * distributed under the License is distributed on an "AS IS" BASIS, 164+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 165+ * See the License for the specific language governing permissions and 166+ * limitations under the License. 167+ */ 168+ 169+#include "gtest/gtest.h" 170+#include "tools/converter/config_parser/third_party_param_parser.h" 171+ 172+using mindspore::ThirdPartyModelParam; 173+using mindspore::TypeId; 174+using mindspore::lite::RET_OK; 175+using mindspore::lite::ThirdPartyModelString; 176+using mindspore::lite::ThirdPartyParamParser; 177+ 178+const ThirdPartyModelString kDemoSISOParam = { 179+ // SISO is short for single-input-single-output. 180+ .input_dtypes = "float32", 181+ .input_shapes = "1,2,3,4", 182+ .input_names = "siso_input", 183+ .output_dtypes = "int32", 184+ .output_shapes = "2", 185+ .output_names = "siso_output", 186+ .extended_parameters = "siso_foo:siso_foo_value;siso_bar:siso_bar_value", 187+}; 188+ 189+const ThirdPartyModelString kDemoMIMOParam = { 190+ // MIMO is short for multiple-input-multiple-output. 191+ .input_dtypes = "float32;int8;float16", 192+ .input_shapes = "1,2,3,4;5,6;7,8,9", 193+ .input_names = "mimo_in_0;mimo_in_1;mimo_in_2", 194+ .output_dtypes = "int32;float32", 195+ .output_shapes = "2,4;10,20,30", 196+ .output_names = "mimo_out_0;mimo_out_1", 197+ .extended_parameters = "mimo_foo:mimo_foo_value;mimo_bar:mimo_bar_value", 198+}; 199+ 200+TEST(TestThirdPartyParamParser, ParseSISOParam) { 201+ ThirdPartyModelString param_string = kDemoSISOParam; 202+ ThirdPartyModelParam result; 203+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 204+ 205+ ASSERT_EQ(result.input_names, std::vector<std::string>{"siso_input"}); 206+ ASSERT_EQ(result.input_shapes.size(), 1U); 207+ std::vector<int64_t> expect_in_shape = {1, 2, 3, 4}; 208+ ASSERT_EQ(result.input_shapes[0], expect_in_shape); 209+ ASSERT_EQ(result.input_dtypes, std::vector<TypeId>{TypeId::kNumberTypeFloat32}); 210+ 211+ ASSERT_EQ(result.output_names, std::vector<std::string>{"siso_output"}); 212+ ASSERT_EQ(result.output_shapes.size(), 1U); 213+ std::vector<int64_t> expect_out_shape = {2}; 214+ ASSERT_EQ(result.output_shapes[0], expect_out_shape); 215+ ASSERT_EQ(result.output_dtypes, std::vector<TypeId>{TypeId::kNumberTypeInt32}); 216+ 217+ const auto &ext_param = result.extended_parameters; 218+ ASSERT_EQ(ext_param.size(), 2U); 219+ ASSERT_TRUE(ext_param.find("siso_foo") != ext_param.end()); 220+ auto expect_foo_value = ext_param.at("siso_foo"); 221+ ASSERT_EQ(std::string(expect_foo_value.begin(), expect_foo_value.end()), "siso_foo_value"); 222+ ASSERT_TRUE(ext_param.find("siso_bar") != ext_param.end()); 223+ auto expect_bar_value = ext_param.at("siso_bar"); 224+ ASSERT_EQ(std::string(expect_bar_value.begin(), expect_bar_value.end()), "siso_bar_value"); 225+} 226+ 227+TEST(TestThirdPartyParamParser, ParseValidDtype) { 228+ ThirdPartyModelString param_string = kDemoSISOParam; 229+ const std::vector<std::string> kValidDtypeStrings = { 230+ "float64", "float32", "float16", "int64", "int32", "int16", "int8", "uint8", "bool", 231+ }; 232+ 233+ const std::vector<TypeId> kExpects = { 234+ TypeId::kNumberTypeFloat64, TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat16, 235+ TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt16, 236+ TypeId::kNumberTypeInt8, TypeId::kNumberTypeUInt8, TypeId::kNumberTypeBool}; 237+ 238+ for (size_t i = 0; i < kValidDtypeStrings.size(); i++) { 239+ param_string.input_dtypes = kValidDtypeStrings[i]; 240+ ThirdPartyModelParam result; 241+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 242+ ASSERT_EQ(result.input_dtypes[0], kExpects[i]); 243+ } 244+} 245+ 246+TEST(TestThirdPartyParamParser, ParseInvalidDtype) { 247+ ThirdPartyModelParam result; 248+ ThirdPartyModelString param_string = kDemoSISOParam; 249+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 250+ param_string.input_dtypes = "bad_dtype"; 251+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 252+} 253+ 254+TEST(TestThirdPartyParamParser, ParseValidShape) { 255+ ThirdPartyModelString param_string = kDemoSISOParam; 256+ param_string.input_shapes = "256,256,1024,96"; // Only support fixed shape. 257+ ThirdPartyModelParam result; 258+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 259+ std::vector<int64_t> expect = {256, 256, 1024, 96}; 260+ ASSERT_EQ(result.input_shapes[0], expect); 261+} 262+ 263+TEST(TestThirdPartyParamParser, ParseInvalidShape) { 264+ ThirdPartyModelParam result; 265+ ThirdPartyModelString param_string = kDemoSISOParam; 266+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 267+ 268+ param_string.input_shapes = "256,256,1024,-1"; 269+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 270+ 271+ param_string.input_shapes = "256,256,0,96"; 272+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 273+ 274+ param_string.input_shapes = "256,-256,1024,96"; 275+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 276+ 277+ param_string.input_shapes = "256,foo,1024,96"; 278+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 279+} 280+ 281+TEST(TestThirdPartyParamParser, ParseDefaultName) { 282+ ThirdPartyModelParam result; 283+ ThirdPartyModelString param_string = kDemoSISOParam; 284+ param_string.input_names = ""; 285+ param_string.output_names = ""; 286+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 287+ ASSERT_EQ(result.input_names[0], "in_0"); 288+ ASSERT_EQ(result.output_names[0], "out_0"); 289+} 290+ 291+TEST(TestThirdPartyParamParser, ParseMIMOParam) { 292+ ThirdPartyModelString param_string = kDemoMIMOParam; 293+ ThirdPartyModelParam result; 294+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 295+ 296+ std::vector<std::string> expect_input_names = {"mimo_in_0", "mimo_in_1", "mimo_in_2"}; 297+ ASSERT_EQ(result.input_names, expect_input_names); 298+ std::vector<std::vector<int64_t>> expect_input_shapes = {{1, 2, 3, 4}, {5, 6}, {7, 8, 9}}; 299+ ASSERT_EQ(result.input_shapes, expect_input_shapes); 300+ std::vector<TypeId> expect_input_dtypes = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt8, 301+ TypeId::kNumberTypeFloat16}; 302+ ASSERT_EQ(result.input_dtypes, expect_input_dtypes); 303+ 304+ std::vector<std::string> expect_output_names = {"mimo_out_0", "mimo_out_1"}; 305+ ASSERT_EQ(result.output_names, expect_output_names); 306+ std::vector<std::vector<int64_t>> expect_output_shapes = {{2, 4}, {10, 20, 30}}; 307+ ASSERT_EQ(result.output_shapes, expect_output_shapes); 308+ std::vector<TypeId> expect_output_dtypes = {TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32}; 309+ ASSERT_EQ(result.output_dtypes, expect_output_dtypes); 310+} 311+ 312+TEST(TestThirdPartyParamParser, ParseMismatchedShapeAndDtypeSize) { 313+ ThirdPartyModelString param_string = kDemoMIMOParam; 314+ ThirdPartyModelParam result; 315+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 316+ 317+ param_string.input_shapes = "1,2,3,4;5,6"; // shape size is 2 while dtype size is 3. 318+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 319+} 320+ 321+TEST(TestThirdPartyParamParser, ParseMismatchedNameAndDtypeSize) { 322+ ThirdPartyModelString param_string = kDemoMIMOParam; 323+ ThirdPartyModelParam result; 324+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 325+ 326+ param_string.input_names = "mimo_in_0;mimo_in_1"; // name size is 2 while dtype size is 3. 327+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 328+} 329diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt 330index 215d2e17..8ce0304e 100644 331--- a/mindspore/lite/tools/converter/CMakeLists.txt 332+++ b/mindspore/lite/tools/converter/CMakeLists.txt 333@@ -8,6 +8,8 @@ endif() 334 set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) 335 set(TOOLS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) 336 337+include_directories(${CMAKE_SOURCE_DIR}/mindspore/lite/) 338+ 339 set(CCSRC_SRC 340 ${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc 341 ${CCSRC_DIR}/backend/common/optimizer/visit.cc 342@@ -81,6 +83,7 @@ add_subdirectory(parser/caffe) 343 add_subdirectory(parser/tflite) 344 add_subdirectory(parser/onnx) 345 add_subdirectory(parser/tf) 346+add_subdirectory(parser/third_party) 347 if(ENABLE_CONVERT_PYTORCH_MODEL) 348 add_subdirectory(parser/pytorch) 349 endif() 350@@ -343,6 +346,7 @@ target_link_libraries(mindspore_converter 351 tf_parser_mid 352 caffe_parser_mid 353 onnx_parser_mid 354+ third_party_parser_mid 355 lite_exporter_mid 356 graph_pass_mid 357 fusion_mid 358diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 359index bb3e7bcf..e2e15230 100644 360--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 361+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 362@@ -30,6 +30,7 @@ constexpr auto kDataPreprocessParam = "data_preprocess_param"; 363 constexpr auto kRegistry = "registry"; 364 constexpr auto kAclOptionParam = "acl_option_cfg_param"; 365 constexpr auto kMicroParam = "micro_param"; 366+constexpr auto kThirdPartyModelParam = "third_party_model"; 367 } // namespace 368 int ConfigFileParser::ParseConfigFile(const std::string &config_file_path) { 369 std::map<std::string, std::map<std::string, std::string>> maps; 370@@ -93,6 +94,12 @@ int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::strin 371 MS_LOG(ERROR) << "ParseMicroParamString failed."; 372 return ret; 373 } 374+ ret = ParseThirdPartyParamString(*maps); 375+ (void)maps->erase(kThirdPartyModelParam); 376+ if (ret != RET_OK) { 377+ MS_LOG(ERROR) << "ParseTransformQuantString failed."; 378+ return ret; 379+ } 380 381 for (const auto &config_info : *maps) { 382 ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second); 383@@ -223,5 +230,25 @@ int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map 384 } 385 return RET_OK; 386 } 387+ 388+int ConfigFileParser::ParseThirdPartyParamString( 389+ const std::map<std::string, std::map<std::string, std::string>> §ions) { 390+ if (sections.find(kThirdPartyModelParam) == sections.end()) { 391+ return RET_OK; 392+ } 393+ const auto &input_args = sections.at(kThirdPartyModelParam); 394+ const std::map<std::string, std::string &> kValidArgs = { 395+ {"input_shapes", third_party_model_string_.input_shapes}, 396+ {"input_dtypes", third_party_model_string_.input_dtypes}, 397+ {"input_names", third_party_model_string_.input_names}, 398+ {"input_formats", third_party_model_string_.input_formats}, 399+ {"output_shapes", third_party_model_string_.output_shapes}, 400+ {"output_dtypes", third_party_model_string_.output_dtypes}, 401+ {"output_names", third_party_model_string_.output_names}, 402+ {"output_formats", third_party_model_string_.output_formats}, 403+ {"extended_parameters", third_party_model_string_.extended_parameters}, 404+ }; 405+ return SetMapData(input_args, kValidArgs, kThirdPartyModelParam); 406+} 407 } // namespace lite 408 } // namespace mindspore 409diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h 410index 876bb307..6bad9e85 100644 411--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h 412+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h 413@@ -19,6 +19,8 @@ 414 #include <string> 415 #include <map> 416 #include <vector> 417+#include <memory> 418+#include "tools/converter/cxx_api/converter_para.h" 419 420 namespace mindspore { 421 namespace lite { 422@@ -86,6 +88,18 @@ struct MicroParamString { 423 std::string enable_micro; 424 }; 425 426+struct ThirdPartyModelString { 427+ std::string input_dtypes; 428+ std::string input_shapes; 429+ std::string input_names; // optional, default: "" 430+ std::string input_formats; // optional, default: NHWC 431+ std::string output_dtypes; 432+ std::string output_shapes; 433+ std::string output_names; // optional, default: "" 434+ std::string output_formats; // optional, default: NHWC 435+ std::string extended_parameters; // format: {key1:value1;ker2:value2} 436+}; 437+ 438 class ConfigFileParser { 439 public: 440 int ParseConfigFile(const std::string &config_file_path); 441@@ -98,6 +112,7 @@ class ConfigFileParser { 442 RegistryInfoString GetRegistryInfoString() const { return this->registry_info_string_; } 443 AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; } 444 MicroParamString GetMicroParamString() { return this->micro_param_string_; } 445+ lite::ThirdPartyModelString GetThirdPartyModelString() const { return this->third_party_model_string_; } 446 447 private: 448 int ParseDataPreProcessString(const std::map<std::string, std::map<std::string, std::string>> &maps); 449@@ -109,6 +124,7 @@ class ConfigFileParser { 450 int SetMapData(const std::map<std::string, std::string> &input_map, 451 const std::map<std::string, std::string &> &parse_map, const std::string §ion); 452 int ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps); 453+ int ParseThirdPartyParamString(const std::map<std::string, std::map<std::string, std::string>> §ions); 454 455 private: 456 DataPreProcessString data_pre_process_string_; 457@@ -118,6 +134,7 @@ class ConfigFileParser { 458 RegistryInfoString registry_info_string_; 459 AclOptionCfgString acl_option_cfg_string_; 460 MicroParamString micro_param_string_; 461+ lite::ThirdPartyModelString third_party_model_string_; 462 }; 463 464 } // namespace lite 465diff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc 466new file mode 100644 467index 00000000..aee6a29c 468--- /dev/null 469+++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc 470@@ -0,0 +1,299 @@ 471+/** 472+ * Copyright 2023 Huawei Technologies Co., Ltd 473+ * 474+ * Licensed under the Apache License, Version 2.0 (the "License"); 475+ * you may not use this file except in compliance with the License. 476+ * You may obtain a copy of the License at 477+ * 478+ * http://www.apache.org/licenses/LICENSE-2.0 479+ * 480+ * Unless required by applicable law or agreed to in writing, software 481+ * distributed under the License is distributed on an "AS IS" BASIS, 482+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 483+ * See the License for the specific language governing permissions and 484+ * limitations under the License. 485+ */ 486+ 487+#include "tools/converter/config_parser/third_party_param_parser.h" 488+#include <vector> 489+#include <string> 490+#include <map> 491+#include "include/errorcode.h" 492+#include "src/common/log_adapter.h" 493+#include "nnacl/op_base.h" 494+#include "tools/common/string_util.h" 495+ 496+namespace mindspore { 497+namespace lite { 498+namespace { 499+const std::map<std::string, TypeId> kDataTypeMap = { 500+ {"float64", TypeId::kNumberTypeFloat64}, {"float32", TypeId::kNumberTypeFloat32}, 501+ {"float16", TypeId::kNumberTypeFloat16}, {"int64", TypeId::kNumberTypeInt64}, 502+ {"int32", TypeId::kNumberTypeInt32}, {"int16", TypeId::kNumberTypeInt16}, 503+ {"int8", TypeId::kNumberTypeInt8}, {"uint8", TypeId::kNumberTypeUInt8}, 504+ {"bool", TypeId::kNumberTypeBool}, 505+}; 506+ 507+TypeId ConvertDataType(const std::string &type) { 508+ auto iter = kDataTypeMap.find(type); 509+ if (iter == kDataTypeMap.end()) { 510+ return TypeId::kTypeUnknown; 511+ } 512+ return iter->second; 513+} 514+} // namespace 515+ 516+/** 517+ * Parse shapes like "1,256,256,3;3,96;96,96", and return like [[1,256,256,3], [3,96], [96,96]]. 518+ */ 519+int ThirdPartyParamParser::DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes) { 520+ MS_CHECK_TRUE_RET(dst_shapes != nullptr, RET_ERROR); 521+ dst_shapes->clear(); 522+ 523+ auto tmp_shapes = SplitStringToVector(src, ";"); 524+ for (auto tmp_shape : tmp_shapes) { 525+ auto tmp = SplitStringToVector(tmp_shape, ","); 526+ std::vector<int64_t> shape = {}; 527+ for (auto t : tmp) { 528+ int value = 0; 529+ if (!ConvertIntNum(t, &value)) { 530+ MS_LOG(ERROR) << "Found error when convert shape string to integer"; 531+ return RET_ERROR; 532+ } 533+ if (value <= 0) { // Valid shape value should be greater than 0. 534+ MS_LOG(ERROR) << "Only support fixed shapes in third party param"; 535+ return RET_ERROR; 536+ } 537+ shape.push_back(value); 538+ } 539+ dst_shapes->push_back(shape); 540+ } 541+ return RET_OK; 542+} 543+ 544+/** 545+ * Parse extended parameter like "key_1:value_1;key_2:value_2" and get {{"key_1", "value_1"}, {"key_2", "value_2"}}. 546+ */ 547+int ThirdPartyParamParser::DoParseExtendedParameters(const std::string &src, 548+ std::map<std::string, std::vector<uint8_t>> *dst_ext_param) { 549+ MS_CHECK_TRUE_RET(dst_ext_param != nullptr, RET_ERROR); 550+ constexpr size_t kKeyIndex = 0U; 551+ constexpr size_t kValueIndex = 1U; 552+ constexpr size_t kKeyValueSize = 2U; 553+ 554+ if (src == "") { // Just return if 'extended_parameters' is configured. 555+ return RET_OK; 556+ } 557+ 558+ auto tmp_list = SplitStringToVector(src, ";"); 559+ std::map<std::string, std::vector<uint8_t>> tmp_map = {}; 560+ for (auto tmp : tmp_list) { 561+ auto key_and_value = SplitStringToVector(tmp, ":"); 562+ if (key_and_value.size() != kKeyValueSize) { 563+ MS_LOG(ERROR) << "Parse extended parameters failed, should keep key:value format"; 564+ return RET_ERROR; 565+ } 566+ auto key = key_and_value[kKeyIndex]; 567+ auto value = key_and_value[kValueIndex]; 568+ if (tmp_map.find(key) != tmp_map.end()) { 569+ MS_LOG(ERROR) << "Parse extended parameters failed, key should not be duplicated"; 570+ return RET_ERROR; 571+ } 572+ tmp_map.emplace(key, std::vector<uint8_t>(value.begin(), value.end())); 573+ } 574+ 575+ *dst_ext_param = tmp_map; 576+ return RET_OK; 577+} 578+ 579+/** 580+ * Parse dtypes like "float32;float32;int32" and return [kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32] 581+ */ 582+int ThirdPartyParamParser::DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes) { 583+ MS_CHECK_TRUE_RET(dst_dtypes != nullptr, RET_ERROR); 584+ dst_dtypes->clear(); 585+ auto tmp_dtypes = SplitStringToVector(src, ";"); 586+ for (auto tmp_dtype : tmp_dtypes) { 587+ TypeId type = ConvertDataType(tmp_dtype); 588+ if (type == kTypeUnknown) { 589+ MS_LOG(ERROR) << "Parse dtypes in third party model config failed"; 590+ return RET_ERROR; 591+ } 592+ dst_dtypes->push_back(type); 593+ } 594+ return RET_OK; 595+} 596+ 597+/** 598+ * Parse names like "foo;bar;boo" and get ["foo", "bar", "boo"] 599+ * If input names are not provided in config, use the default prefix to generate like: "in_0;in_1;..;in_n" 600+ */ 601+int ThirdPartyParamParser::DoParseNames(const std::string &src, size_t num, const std::string &default_prefix, 602+ std::vector<std::string> *dst_names) { 603+ MS_CHECK_TRUE_RET(dst_names != nullptr, RET_ERROR); 604+ std::string tmp_names = src; 605+ if (tmp_names.empty()) { 606+ std::string tmp = ""; 607+ for (size_t i = 0; i < num; i++) { 608+ tmp += default_prefix + "_" + std::to_string(i); 609+ if (i + 1 < num) { 610+ tmp += ";"; 611+ } 612+ } 613+ tmp_names = tmp; 614+ } 615+ 616+ *dst_names = SplitStringToVector(tmp_names, ";"); 617+ if (dst_names->size() != num) { 618+ MS_LOG(ERROR) << "Name number " << dst_names->size() << " and input number: " << num << " are not equal"; 619+ return RET_ERROR; 620+ } 621+ return RET_OK; 622+} 623+ 624+/** 625+ * Parse formats like "NCHW;NHWC" and get [NCHW, NHWC] 626+ */ 627+namespace { 628+ int StringToFormat(const std::string &format_string, schema::Format *format) { 629+ static const std::unordered_map<std::string, schema::Format> kFormatTable = { 630+ {"NCHW", schema::Format::Format_NCHW}, 631+ {"NHWC", schema::Format::Format_NHWC}, 632+ {"NHWC4", schema::Format::Format_NHWC4}, 633+ {"HWKC", schema::Format::Format_HWKC}, 634+ {"HWCK", schema::Format::Format_HWCK}, 635+ {"KCHW", schema::Format::Format_KCHW}, 636+ {"CKHW", schema::Format::Format_CKHW}, 637+ {"KHWC", schema::Format::Format_KHWC}, 638+ {"CHWK", schema::Format::Format_CHWK}, 639+ {"HW", schema::Format::Format_HW}, 640+ {"HW4", schema::Format::Format_HW4}, 641+ {"NC", schema::Format::Format_NC}, 642+ {"NC4", schema::Format::Format_NC4}, 643+ {"NC4HW4", schema::Format::Format_NC4HW4}, 644+ {"NUM_OF_FORMAT", schema::Format::Format_NUM_OF_FORMAT}, 645+ {"NCDHW", schema::Format::Format_NCDHW}, 646+ {"NWC", schema::Format::Format_NWC}, 647+ {"NCW", schema::Format::Format_NCW}, 648+ }; 649+ 650+ if (format == nullptr) { 651+ return RET_NULL_PTR; 652+ } 653+ 654+ auto iter = kFormatTable.find(format_string); 655+ if (iter == kFormatTable.end()) { 656+ return RET_PARAM_INVALID; 657+ } 658+ 659+ *format = iter->second; 660+ return RET_OK; 661+ } 662+} 663+ 664+int ThirdPartyParamParser::DoParseFormats(const std::string &src, size_t num, 665+ std::vector<schema::Format> *result_formats) { 666+ MS_CHECK_TRUE_RET(result_formats != nullptr, RET_ERROR); 667+ std::string tmp_names = src; 668+ if (tmp_names.empty()) { 669+ std::vector<schema::Format> default_formats(num, schema::Format::Format_NHWC); 670+ *result_formats = default_formats; 671+ return RET_OK; 672+ } 673+ 674+ auto format_strings = SplitStringToVector(tmp_names, ";"); 675+ if (format_strings.size() != num) { 676+ MS_LOG(ERROR) << "Number of format: " << format_strings.size() << " and number of tensor: " << num << " are not equal"; 677+ return RET_ERROR; 678+ } 679+ 680+ std::vector<schema::Format> result(num); 681+ for (size_t i = 0; i < num; i++) { 682+ if (StringToFormat(format_strings[i], &result[i]) != RET_OK) { 683+ MS_LOG(ERROR) << "Tensor format:" << format_strings[i] << " is invalid"; 684+ return RET_PARAM_INVALID; 685+ } 686+ } 687+ *result_formats = result; 688+ return RET_OK; 689+} 690+ 691+int ThirdPartyParamParser::Parse(const ThirdPartyModelString ¶m_string, ThirdPartyModelParam *param) { 692+ MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR); 693+ 694+ auto ret = DoParseShape(param_string.input_shapes, &(param->input_shapes)); 695+ if (ret != RET_OK) { 696+ MS_LOG(ERROR) << "Parse input shapes of third party param failed"; 697+ return RET_ERROR; 698+ } 699+ 700+ ret = DoParseDtypes(param_string.input_dtypes, &(param->input_dtypes)); 701+ if (ret != RET_OK) { 702+ MS_LOG(ERROR) << "Parse input dtypes of third party param failed"; 703+ return RET_ERROR; 704+ } 705+ 706+ auto input_shape_num = param->input_shapes.size(); 707+ auto input_dtype_num = param->input_dtypes.size(); 708+ if (input_shape_num != input_dtype_num) { 709+ MS_LOG(ERROR) << "Input shape number: " << input_shape_num << " and dtype number: " << input_dtype_num 710+ << " are not equal"; 711+ return RET_ERROR; 712+ } 713+ 714+ ret = DoParseFormats(param_string.input_formats, input_shape_num, &(param->input_formats)); 715+ if (ret != RET_OK) { 716+ MS_LOG(ERROR) << "Parse input formats of third party param failed"; 717+ return RET_ERROR; 718+ } 719+ 720+ const std::string kInputNamePrefix = "in"; 721+ ret = DoParseNames(param_string.input_names, input_shape_num, kInputNamePrefix, &(param->input_names)); 722+ if (ret != RET_OK) { 723+ MS_LOG(ERROR) << "Parse input names of third party param failed"; 724+ return RET_ERROR; 725+ } 726+ 727+ ret = DoParseShape(param_string.output_shapes, &(param->output_shapes)); 728+ if (ret != RET_OK) { 729+ MS_LOG(ERROR) << "Parse output shaped of third party param failed"; 730+ return RET_ERROR; 731+ } 732+ 733+ ret = DoParseDtypes(param_string.output_dtypes, &(param->output_dtypes)); 734+ if (ret != RET_OK) { 735+ MS_LOG(ERROR) << "Parse output dtypes of third party param failed"; 736+ return RET_ERROR; 737+ } 738+ 739+ auto output_shape_num = param->output_shapes.size(); 740+ auto output_dtype_num = param->output_dtypes.size(); 741+ if (output_shape_num != output_dtype_num) { 742+ MS_LOG(ERROR) << "Output shape number: " << output_shape_num << " and dtype number: " << output_dtype_num 743+ << " are not equal"; 744+ return RET_ERROR; 745+ } 746+ 747+ ret = DoParseFormats(param_string.output_formats, output_shape_num, &(param->output_formats)); 748+ if (ret != RET_OK) { 749+ MS_LOG(ERROR) << "Parse output formats of third party param failed"; 750+ return RET_ERROR; 751+ } 752+ 753+ const std::string kOutputNamePrefix = "out"; 754+ ret = DoParseNames(param_string.output_names, output_shape_num, kOutputNamePrefix, &(param->output_names)); 755+ if (ret != RET_OK) { 756+ MS_LOG(ERROR) << "Parse output names of third party param failed"; 757+ return RET_ERROR; 758+ } 759+ 760+ ret = DoParseExtendedParameters(param_string.extended_parameters, &(param->extended_parameters)); 761+ if (ret != RET_OK) { 762+ MS_LOG(ERROR) << "Parse extended parameter of third party param failed"; 763+ return RET_ERROR; 764+ } 765+ 766+ return RET_OK; 767+} 768+} // namespace lite 769+} // namespace mindspore 770diff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h 771new file mode 100644 772index 00000000..5cf6e8fb 773--- /dev/null 774+++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h 775@@ -0,0 +1,44 @@ 776+/** 777+ * Copyright 2023 Huawei Technologies Co., Ltd 778+ * 779+ * Licensed under the Apache License, Version 2.0 (the "License"); 780+ * you may not use this file except in compliance with the License. 781+ * You may obtain a copy of the License at 782+ * 783+ * http://www.apache.org/licenses/LICENSE-2.0 784+ * 785+ * Unless required by applicable law or agreed to in writing, software 786+ * distributed under the License is distributed on an "AS IS" BASIS, 787+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 788+ * See the License for the specific language governing permissions and 789+ * limitations under the License. 790+ */ 791+ 792+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ 793+#define MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ 794+#include <string> 795+#include <vector> 796+#include <map> 797+#include "include/errorcode.h" 798+#include "tools/converter/cxx_api/converter_para.h" 799+#include "tools/converter/config_parser/config_file_parser.h" 800+ 801+namespace mindspore { 802+namespace lite { 803+class ThirdPartyParamParser { 804+ public: 805+ static int Parse(const lite::ThirdPartyModelString ¶m_string, ThirdPartyModelParam *param); 806+ 807+ private: 808+ static int DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes); 809+ static int DoParseExtendedParameters(const std::string &src, 810+ std::map<std::string, std::vector<uint8_t>> *dst_ext_param); 811+ static int DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes); 812+ static int DoParseNames(const std::string &src, size_t num, const std::string &default_prefix, 813+ std::vector<std::string> *dst_names); 814+ static int DoParseFormats(const std::string &src, size_t num, std::vector<schema::Format> *result_formats); 815+}; 816+} // namespace lite 817+} // namespace mindspore 818+ 819+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ 820diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc 821index f3d4d658..449c6ef9 100644 822--- a/mindspore/lite/tools/converter/converter.cc 823+++ b/mindspore/lite/tools/converter/converter.cc 824@@ -44,6 +44,7 @@ 825 #include "tools/converter/config_parser/micro_param_parser.h" 826 #include "tools/converter/config_parser/preprocess_parser.h" 827 #include "tools/converter/config_parser/quant_param_parser.h" 828+#include "tools/converter/config_parser/third_party_param_parser.h" 829 #include "tools/common/string_util.h" 830 #include "src/common/file_utils.h" 831 832@@ -89,6 +90,7 @@ FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr<ConverterPara> 833 converter_parameters.fmk = param->fmk_type; 834 converter_parameters.model_file = param->model_file; 835 converter_parameters.weight_file = param->weight_file; 836+ converter_parameters.attrs.emplace("config_file", param->config_file); 837 func_graph_base = model_parser_->Parse(converter_parameters); 838 } 839 if (func_graph_base == nullptr) { 840@@ -96,6 +98,7 @@ FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr<ConverterPara> 841 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NOT_SUPPORT); 842 return nullptr; 843 } 844+ 845 auto func_graph = ConvertGraph(func_graph_base); 846 if (func_graph == nullptr) { 847 MS_LOG(ERROR) << "func graph is invalid."; 848@@ -137,9 +140,13 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara> 849 return nullptr; 850 } 851 MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed."); 852- // funcgraph_transform 853- graph = funcgraph_transform_->Transform(graph, param); 854- MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Transform anf graph return nullptr."); 855+ 856+ if (param->fmk_type != converter::FmkType::kFmkTypeThirdParty) { 857+ // funcgraph_transform 858+ graph = funcgraph_transform_->Transform(graph, param); 859+ MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Transform anf graph return nullptr."); 860+ } 861+ 862 // export protobuf 863 auto status = MindIRSerialize(param, graph); 864 if (status != RET_OK) { 865@@ -186,11 +193,14 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara> 866 } 867 868 MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed"); 869- // funcgraph transform 870- graph = funcgraph_transform_->Transform(graph, param); 871- if (graph == nullptr) { 872- MS_LOG(ERROR) << "Transform anf graph return nullptr"; 873- return nullptr; 874+ 875+ if (param->fmk_type != converter::FmkType::kFmkTypeThirdParty) { 876+ // funcgraph transform 877+ graph = funcgraph_transform_->Transform(graph, param); 878+ if (graph == nullptr) { 879+ MS_LOG(ERROR) << "Transform anf graph return nullptr"; 880+ return nullptr; 881+ } 882 } 883 884 // export protobuf 885@@ -354,6 +364,12 @@ int ConverterImpl::InitConfigParam(const std::shared_ptr<ConverterPara> ¶m) 886 MS_LOG(ERROR) << "Parse mixed bit weight quant param failed."; 887 return ret; 888 } 889+ ret = lite::ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), 890+ ¶m->thirdPartyModelParam); 891+ if (ret != RET_OK) { 892+ MS_LOG(ERROR) << "Parse third party param failed."; 893+ return ret; 894+ } 895 ret = InitExtendedIntegrationInfo(param, config_parser); 896 if (ret != RET_OK) { 897 MS_LOG(ERROR) << "Parse extended integration info failed."; 898@@ -535,17 +551,19 @@ std::string ConverterImpl::GetStrFromConfigFile(const std::string &file, const s 899 900 int CheckFmkType(const std::shared_ptr<ConverterPara> ¶m) { 901 if (param != nullptr) { 902- std::set valid_values = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, 903- FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch}; 904- if (std::find(valid_values.begin(), valid_values.end(), param->fmk_type) == valid_values.end()) { 905- MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be kFmkTypeTf|kFmkTypeCaffe|kFmkTypeOnnx|kFmkTypeMs|kFmkTypeTflite" 906- << ", but got " << param->fmk_type; 907- return RET_INPUT_PARAM_INVALID; 908- } 909- if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) { 910- MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag"; 911- return RET_INPUT_PARAM_INVALID; 912- } 913+ return RET_OK; 914+ } 915+ const std::set kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, 916+ FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch, 917+ FmkType::kFmkTypeThirdParty}; 918+ if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) { 919+ MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|THIRDPARTY" 920+ << ", but got " << param->fmk_type; 921+ return RET_INPUT_PARAM_INVALID; 922+ } 923+ if ((param->fmk_type != converter::kFmkTypeCaffe) && (!param->weight_file.empty())) { 924+ MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag"; 925+ return RET_INPUT_PARAM_INVALID; 926 } 927 return RET_OK; 928 } 929@@ -594,7 +612,7 @@ int CheckInputShape(const std::shared_ptr<ConverterPara> ¶m) { 930 bool has_negative_dim = std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim < 0; }); 931 if (has_negative_dim) { 932 MS_LOG(ERROR) << "INPUT ILLEGAL: Unsupported dim < 0."; 933- return lite::RET_ERROR; 934+ return lite::RET_INPUT_PARAM_INVALID; 935 } 936 } 937 } 938diff --git a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc 939index 033db968..595b59ed 100644 940--- a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc 941+++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc 942@@ -118,13 +118,13 @@ int Flags::InitInputOutputDataType() { 943 944 int Flags::InitFmk() { 945 // value check not here, it is in converter c++ API's CheckValueParam method. 946- std::map<std::string, FmkType> StrToEnumFmkTypeMap = {{"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs}, 947- {"TFLITE", kFmkTypeTflite}, {"ONNX", kFmkTypeOnnx}, 948- {"TF", kFmkTypeTf}, {"PYTORCH", kFmkTypePytorch}}; 949+ std::map<std::string, FmkType> StrToEnumFmkTypeMap = { 950+ {"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs}, {"TFLITE", kFmkTypeTflite}, {"ONNX", kFmkTypeOnnx}, 951+ {"TF", kFmkTypeTf}, {"PYTORCH", kFmkTypePytorch}, {"THIRDPARTY", kFmkTypeThirdParty}}; 952 if (StrToEnumFmkTypeMap.find(this->fmkIn) != StrToEnumFmkTypeMap.end()) { 953 this->fmk = StrToEnumFmkTypeMap.at(this->fmkIn); 954 } else { 955- std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl; 956+ std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX|PYTORCH|THIRDPARTY" << std::endl; 957 return RET_INPUT_PARAM_INVALID; 958 } 959 960diff --git a/mindspore/lite/tools/converter/cxx_api/converter_para.h b/mindspore/lite/tools/converter/cxx_api/converter_para.h 961index 58bc4c7c..00b7fa3c 100644 962--- a/mindspore/lite/tools/converter/cxx_api/converter_para.h 963+++ b/mindspore/lite/tools/converter/cxx_api/converter_para.h 964@@ -21,6 +21,7 @@ 965 #include <vector> 966 #include <set> 967 #include "include/converter.h" 968+#include "mindapi/base/type_id.h" 969 #include "tools/converter/quantizer/quant_params.h" 970 #include "tools/converter/preprocess/preprocess_param.h" 971 #include "tools/converter/adapter/acl/common/acl_types.h" 972@@ -35,6 +36,18 @@ struct ParallelSplitConfig { 973 std::vector<std::string> parallel_devices_; 974 }; 975 976+struct ThirdPartyModelParam { 977+ std::vector<TypeId> input_dtypes; 978+ std::vector<std::vector<int64_t>> input_shapes; 979+ std::vector<std::string> input_names; 980+ std::vector<schema::Format> input_formats; 981+ std::vector<TypeId> output_dtypes; 982+ std::vector<std::vector<int64_t>> output_shapes; 983+ std::vector<std::string> output_names; 984+ std::vector<schema::Format> output_formats; 985+ std::map<std::string, std::vector<uint8_t>> extended_parameters; 986+}; 987+ 988 struct ConverterPara { 989 converter::FmkType fmk_type; 990 std::string model_file; 991@@ -68,6 +81,7 @@ struct ConverterPara { 992 lite::acl::AclModelOptionCfg aclModelOptionCfgParam; 993 lite::micro::MicroParam microParam; 994 ParallelSplitConfig parallel_split_config; 995+ ThirdPartyModelParam thirdPartyModelParam; 996 }; 997 } // namespace mindspore 998 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_ 999diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc 1000index 538b1ab1..7361204d 100644 1001--- a/mindspore/lite/tools/converter/graphdef_transform.cc 1002+++ b/mindspore/lite/tools/converter/graphdef_transform.cc 1003@@ -92,10 +92,54 @@ int QuantTransform(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGrap 1004 } 1005 return RET_OK; 1006 } 1007+ 1008+int FillGraphOutputShape(MetaGraphT *meta_graph, const std::vector<std::vector<int64_t>> output_shapes) { 1009+ const auto &out_indices = meta_graph->outputIndex; 1010+ for (size_t i = 0; i < out_indices.size(); i++) { 1011+ auto &out_tensor = meta_graph->allTensors[out_indices[i]]; 1012+ out_tensor->dims = {}; 1013+ for (size_t k = 0; k < output_shapes[i].size(); k++) { 1014+ out_tensor->dims.push_back(static_cast<int32_t>(output_shapes[i][k])); 1015+ } 1016+ } 1017+ return RET_OK; 1018+} 1019+ 1020+void FillGraphInputAndOutputFormats(MetaGraphT *meta_graph, const ConverterPara ¶) { 1021+ const auto &in_indices = meta_graph->inputIndex; 1022+ for (size_t i = 0; i < in_indices.size(); i++) { 1023+ auto &in_tensor = meta_graph->allTensors[in_indices[i]]; 1024+ in_tensor->format = para.thirdPartyModelParam.input_formats[i]; 1025+ MS_LOG_DEBUG << "input " << i << " format: " << EnumNameFormat(in_tensor->format); 1026+ } 1027+ 1028+ const auto &out_indices = meta_graph->outputIndex; 1029+ for (size_t i = 0; i < out_indices.size(); i++) { 1030+ auto &out_tensor = meta_graph->allTensors[out_indices[i]]; 1031+ out_tensor->format = para.thirdPartyModelParam.output_formats[i]; 1032+ MS_LOG_DEBUG << "output " << i << " format: " << EnumNameFormat(out_tensor->format); 1033+ } 1034+} 1035 } // namespace 1036 1037 int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> ¶m) { 1038 STATUS status; 1039+ 1040+ if (param->fmk_type == converter::kFmkTypeThirdParty) { 1041+ 1042+ // Legacy optimizer infer shape, but op Custom which wraps third party model has no infer-shape function. 1043+ // So we don't perform legacy optimization for kFmkTypeThirdParty case. 1044+ auto ret = FillGraphOutputShape(graph_defT_, param->thirdPartyModelParam.output_shapes); 1045+ if (ret != RET_OK) { 1046+ MS_LOG(ERROR) << "Fill output shape of third party model failed, ret:" << ret; 1047+ return ret; 1048+ } 1049+ 1050+ // Tensor of FuncGraph has no attribute of format, so set format in MetaGraph. 1051+ FillGraphInputAndOutputFormats(graph_defT_, *param); 1052+ return RET_OK; 1053+ } 1054+ 1055 { 1056 auto old_nodes = GetGraphNodes(*graph_defT_); 1057 Optimizer unused_op_remove_optimizer; 1058diff --git a/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt 1059new file mode 100644 1060index 00000000..b55e0194 1061--- /dev/null 1062+++ b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt 1063@@ -0,0 +1,4 @@ 1064+add_library(third_party_parser_mid OBJECT third_party_model_parser.cc) 1065+add_dependencies(third_party_parser_mid proto_mid) 1066+add_dependencies(third_party_parser_mid fbs_src) 1067+add_dependencies(third_party_parser_mid fbs_inner_src) 1068\ No newline at end of file 1069diff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 1070new file mode 100644 1071index 00000000..652db4af 1072--- /dev/null 1073+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 1074@@ -0,0 +1,277 @@ 1075+/** 1076+ * Copyright 2023 Huawei Technologies Co., Ltd 1077+ * 1078+ * Licensed under the Apache License, Version 2.0 (the "License"); 1079+ * you may not use this file except in compliance with the License. 1080+ * You may obtain a copy of the License at 1081+ * 1082+ * http://www.apache.org/licenses/LICENSE-2.0 1083+ * 1084+ * Unless required by applicable law or agreed to in writing, software 1085+ * distributed under the License is distributed on an "AS IS" BASIS, 1086+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1087+ * See the License for the specific language governing permissions and 1088+ * limitations under the License. 1089+ */ 1090+#include "tools/converter/parser/third_party/third_party_model_parser.h" 1091+#include <string> 1092+#include <vector> 1093+#include <memory> 1094+#include "ir/value.h" 1095+#include "mindapi/base/type_id.h" 1096+#include "src/common/log_util.h" 1097+#include "src/common/file_utils.h" 1098+#include "nnacl/op_base.h" 1099+#include "ops/primitive_c.h" 1100+#include "ops/custom.h" 1101+#include "ops/tuple_get_item.h" 1102+#include "ops/make_tuple.h" 1103+#include "ops/return.h" 1104+#include "tools/converter/config_parser/config_file_parser.h" 1105+#include "include/registry/model_parser_registry.h" 1106+#include "tools/common/graph_util.h" 1107+#include "tools/common/tensor_util.h" 1108+#include "tools/converter/converter_context.h" 1109+#include "tools/converter/parser/lite_model_parser_creator.h" 1110+ 1111+using mindspore::converter::kFmkTypeThirdParty; 1112+ 1113+namespace mindspore { 1114+namespace lite { 1115+api::FuncGraphPtr ThirdPartyModelParser::Parse(const converter::ConverterParameters &flag) { 1116+ model_file_ = flag.model_file; 1117+ auto &attrs = flag.attrs; 1118+ auto iter = attrs.find("config_file"); 1119+ if (iter == attrs.end()) { 1120+ return nullptr; 1121+ } 1122+ auto config_file = iter->second; 1123+ 1124+ auto ret = InitConfig(config_file); 1125+ if (ret != RET_OK) { 1126+ MS_LOG(ERROR) << "Init config for third party model parsing failed"; 1127+ return nullptr; 1128+ } 1129+ 1130+ return CreateFuncGraph(); 1131+} 1132+ 1133+STATUS ThirdPartyModelParser::InitConfig(const std::string &config_file) { 1134+ lite::ConfigFileParser config_parser; 1135+ if (config_file.empty()) { 1136+ MS_LOG(ERROR) << "Missing config file in converting third party model"; 1137+ return RET_ERROR; 1138+ } 1139+ auto ret = config_parser.ParseConfigFile(config_file); 1140+ if (ret != RET_OK) { 1141+ MS_LOG(ERROR) << "Get third party model section from config file failed"; 1142+ return RET_ERROR; 1143+ } 1144+ 1145+ ret = ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), ¶m_); 1146+ if (ret != RET_OK) { 1147+ MS_LOG(ERROR) << "Parse third party model param failed."; 1148+ return ret; 1149+ } 1150+ return RET_OK; 1151+} 1152+ 1153+api::FuncGraphPtr ThirdPartyModelParser::CreateFuncGraph() { 1154+ auto func_graph = std::make_shared<FuncGraph>(); 1155+ MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); 1156+ auto type_value = MakeValue(static_cast<int>(converter::kFmkTypeThirdParty)); 1157+ MS_CHECK_TRUE_RET(type_value != nullptr, nullptr); 1158+ func_graph->set_attr("fmk", type_value); 1159+ auto attr_value = MakeValue("third_party"); 1160+ MS_CHECK_TRUE_RET(attr_value != nullptr, nullptr); 1161+ func_graph->set_attr("graph_name", attr_value); 1162+ 1163+ std::vector<AnfNodePtr> input_nodes = {}; 1164+ auto ret = BuildGraphInputs(func_graph, &input_nodes); 1165+ if (ret != RET_OK) { 1166+ MS_LOG(ERROR) << "Create func graph input nodes failed"; 1167+ return nullptr; 1168+ } 1169+ 1170+ CNodePtr custom_node = nullptr; 1171+ ret = BuildCustomOp(func_graph, input_nodes, &custom_node); 1172+ if (ret != RET_OK) { 1173+ MS_LOG(ERROR) << "Create func graph custom op node failed"; 1174+ return nullptr; 1175+ } 1176+ 1177+ ret = BuildGraphOutputs(func_graph, custom_node); 1178+ if (ret != RET_OK) { 1179+ MS_LOG(ERROR) << "Create func graph output nodes failed"; 1180+ return nullptr; 1181+ } 1182+ 1183+ static auto manager = Manage(func_graph); 1184+ func_graph->set_manager(manager); 1185+ 1186+ auto result_graph = api::MakeShared<api::FuncGraph>(func_graph); 1187+ return result_graph; 1188+} 1189+ 1190+STATUS ThirdPartyModelParser::BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *op_inputs) { 1191+ MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr); 1192+ auto &dtypes = param_.input_dtypes; 1193+ auto &shapes = param_.input_shapes; 1194+ auto &names = param_.input_names; 1195+ 1196+ auto input_size = dtypes.size(); 1197+ 1198+ // Create parameter nodes for graph inputs 1199+ for (size_t i = 0; i < input_size; i++) { 1200+ auto parameter = func_graph->add_parameter(); 1201+ MSLITE_CHECK_PTR(parameter); 1202+ auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]); 1203+ if (abstract_tensor == nullptr) { 1204+ MS_LOG(ERROR) << "Create tensor abstract failed"; 1205+ return RET_ERROR; 1206+ } 1207+ parameter->set_abstract(abstract_tensor); 1208+ parameter->set_name(names[i]); 1209+ op_inputs->push_back(parameter); 1210+ } 1211+ 1212+ // Create parameter nodes for const tensor which wrapped third model buffer. 1213+ size_t model_size = 0U; 1214+ auto model_data = ReadFile(model_file_.c_str(), &model_size); 1215+ std::vector<int64_t> model_shape = {static_cast<int64_t>(model_size)}; 1216+ auto tensor_info = CreateTensorInfo(nullptr, 0, model_shape, kNumberTypeUInt8); 1217+ if (tensor_info == nullptr) { 1218+ MS_LOG(ERROR) << "init tensor info failed"; 1219+ delete model_data; 1220+ return RET_NULL_PTR; 1221+ } 1222+ auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c()); 1223+ if (memcpy_s(tensor_data, tensor_info->Size(), model_data, model_size) != EOK) { 1224+ MS_LOG(ERROR) << "memcpy failed."; 1225+ delete model_data; 1226+ return RET_ERROR; 1227+ } 1228+ delete model_data; 1229+ auto parameter = func_graph->add_parameter(); 1230+ MSLITE_CHECK_PTR(parameter); 1231+ auto status = InitParameterFromTensorInfo(parameter, tensor_info); 1232+ if (status != RET_OK) { 1233+ MS_LOG(ERROR) << "init parameter from tensor info failed."; 1234+ return RET_ERROR; 1235+ } 1236+ parameter->set_name("ThirdPartyModel"); 1237+ op_inputs->push_back(parameter); 1238+ return RET_OK; 1239+} 1240+ 1241+STATUS ThirdPartyModelParser::BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &op_inputs, 1242+ CNodePtr *operator_node) { 1243+ MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr); 1244+ NotSupportOp::GetInstance()->set_fmk_type("THIRDPARTY"); 1245+ STATUS status = RET_OK; 1246+ 1247+ // create primitive and build CNode of CUSTOM operator 1248+ ops::PrimitiveCPtr primitive_c; 1249+ auto prim = std::make_unique<ops::Custom>(); 1250+ MS_CHECK_TRUE_RET(prim != nullptr, RET_ERROR); 1251+ prim->set_type("ThirdPartyModel"); 1252+ 1253+ const auto &attr = param_.extended_parameters; 1254+ prim->set_attr(attr); 1255+ primitive_c = prim->GetPrim(); 1256+ if (primitive_c == nullptr) { 1257+ MS_LOG(ERROR) << "failed to create primitive: custom"; 1258+ return RET_ERROR; 1259+ } 1260+ 1261+ auto operator_cnode = func_graph->NewCNode(primitive_c, op_inputs); 1262+ MSLITE_CHECK_PTR(operator_cnode); 1263+ operator_cnode->set_fullname_with_scope("Custom"); 1264+ *operator_node = operator_cnode; 1265+ return status; 1266+} 1267+ 1268+STATUS ThirdPartyModelParser::BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node) { 1269+ MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr); 1270+ 1271+ auto dtypes = param_.output_dtypes; 1272+ auto shapes = param_.output_shapes; 1273+ auto names = param_.output_names; 1274+ 1275+ auto output_size = dtypes.size(); 1276+ std::vector<AnfNodePtr> output_nodes = {}; 1277+ 1278+ // Use TupleGetItem to wrap op outputs. 1279+ AbstractBasePtrList abstract_list; 1280+ for (size_t i = 0; i < output_size; i++) { 1281+ auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]); 1282+ if (abstract_tensor == nullptr) { 1283+ MS_LOG(ERROR) << "Create tensor abstract failed"; 1284+ return RET_ERROR; 1285+ } 1286+ abstract_list.emplace_back(abstract_tensor); 1287+ auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); 1288+ if (tuple_get_item_prim_ptr == nullptr) { 1289+ MS_LOG(ERROR) << "new TupleGetItem failed"; 1290+ return RET_NULL_PTR; 1291+ } 1292+ auto tuple_get_item_prim_c = tuple_get_item_prim_ptr->GetPrim(); 1293+ MSLITE_CHECK_PTR(tuple_get_item_prim_c); 1294+ auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_c); 1295+ MSLITE_CHECK_PTR(tuple_get_item_prim); 1296+ auto get_item_value = NewValueNode(MakeValue<int>(i)); 1297+ MSLITE_CHECK_PTR(get_item_value); 1298+ std::vector<AnfNodePtr> inputs = {tuple_get_item_prim, operator_node, get_item_value}; 1299+ CNodePtr get_item_cnode = func_graph->NewCNode(inputs); 1300+ MSLITE_CHECK_PTR(get_item_cnode); 1301+ std::string output_item_name = operator_node->fullname_with_scope() + "_getitem_" + std::to_string(i); 1302+ auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32); 1303+ if (get_item_abstract == nullptr) { 1304+ MS_LOG(ERROR) << "Create tensor abstarct failed"; 1305+ return RET_ERROR; 1306+ } 1307+ get_item_cnode->set_fullname_with_scope(output_item_name); 1308+ get_item_cnode->set_abstract(get_item_abstract); 1309+ output_nodes.push_back(get_item_cnode); 1310+ } 1311+ auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); 1312+ MSLITE_CHECK_PTR(abstract_tuple); 1313+ operator_node->set_abstract(abstract_tuple); 1314+ 1315+ // Use MakeTuple node to wrap all outputs as single input of Return node. 1316+ auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>(); 1317+ if (make_tuple_prim_ptr == nullptr) { 1318+ MS_LOG(ERROR) << "new MakeTuple failed"; 1319+ return RET_NULL_PTR; 1320+ } 1321+ auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim(); 1322+ MSLITE_CHECK_PTR(make_tuple_prim_c); 1323+ auto make_tuple_prim = NewValueNode(make_tuple_prim_c); 1324+ MSLITE_CHECK_PTR(make_tuple_prim); 1325+ std::vector<AnfNodePtr> make_tuple_inputs = output_nodes; 1326+ make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim); 1327+ auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs); 1328+ MSLITE_CHECK_PTR(make_tuple_cnode); 1329+ make_tuple_cnode->set_fullname_with_scope("return_tuple"); 1330+ 1331+ auto return_prim_ptr = std::make_shared<ops::Return>(); 1332+ if (return_prim_ptr == nullptr) { 1333+ MS_LOG(ERROR) << "new Return failed"; 1334+ return RET_NULL_PTR; 1335+ } 1336+ auto return_prim_c = return_prim_ptr->GetPrim(); 1337+ MSLITE_CHECK_PTR(return_prim_c); 1338+ std::vector<AnfNodePtr> op_inputs{make_tuple_cnode}; 1339+ auto cnode = func_graph->NewCNode(return_prim_c, op_inputs); 1340+ MSLITE_CHECK_PTR(cnode); 1341+ cnode->set_fullname_with_scope("Return"); 1342+ func_graph->set_return(cnode); 1343+ 1344+ // Save original output tensor names. 1345+ ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(names); 1346+ return RET_OK; 1347+} 1348+ 1349+REG_MODEL_PARSER(kFmkTypeThirdParty, LiteModelParserCreator<ThirdPartyModelParser>) 1350+} // namespace lite 1351+} // namespace mindspore 1352diff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h 1353new file mode 100644 1354index 00000000..c4b197b8 1355--- /dev/null 1356+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h 1357@@ -0,0 +1,50 @@ 1358+/** 1359+ * Copyright 2023 Huawei Technologies Co., Ltd 1360+ * 1361+ * Licensed under the Apache License, Version 2.0 (the "License"); 1362+ * you may not use this file except in compliance with the License. 1363+ * You may obtain a copy of the License at 1364+ * 1365+ * http://www.apache.org/licenses/LICENSE-2.0 1366+ * 1367+ * Unless required by applicable law or agreed to in writing, software 1368+ * distributed under the License is distributed on an "AS IS" BASIS, 1369+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1370+ * See the License for the specific language governing permissions and 1371+ * limitations under the License. 1372+ */ 1373+ 1374+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ 1375+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ 1376+ 1377+#include <string> 1378+#include <vector> 1379+#include "schema/inner/model_generated.h" 1380+#include "base/base.h" 1381+#include "ir/anf.h" 1382+#include "ir/func_graph.h" 1383+#include "include/errorcode.h" 1384+#include "include/registry/model_parser.h" 1385+#include "tools/converter/config_parser/third_party_param_parser.h" 1386+ 1387+namespace mindspore { 1388+namespace lite { 1389+class ThirdPartyModelParser : public converter::ModelParser { 1390+ public: 1391+ api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override; 1392+ 1393+ private: 1394+ STATUS InitConfig(const std::string &config_file); 1395+ api::FuncGraphPtr CreateFuncGraph(); 1396+ STATUS BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *op_inputs); 1397+ STATUS BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &op_inputs, 1398+ CNodePtr *operator_node); 1399+ STATUS BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node); 1400+ 1401+ std::string model_file_ = ""; 1402+ ThirdPartyModelParam param_; 1403+}; 1404+} // namespace lite 1405+} // namespace mindspore 1406+ 1407+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ 1408diff --git a/mindspore/lite/tools/converter/registry/model_parser_registry.cc b/mindspore/lite/tools/converter/registry/model_parser_registry.cc 1409index bbdafd96..c6337ea4 100644 1410--- a/mindspore/lite/tools/converter/registry/model_parser_registry.cc 1411+++ b/mindspore/lite/tools/converter/registry/model_parser_registry.cc 1412@@ -26,7 +26,7 @@ std::map<FmkType, ModelParserCreator> model_parser_room; 1413 } // namespace 1414 1415 ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) { 1416- if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) { 1417+ if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) { 1418 MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; 1419 return; 1420 } 1421@@ -34,7 +34,7 @@ ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator 1422 } 1423 1424 converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) { 1425- if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) { 1426+ if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) { 1427 MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; 1428 return nullptr; 1429 } 1430-- 14312.34.1 1432 1433