1From c87618bc9c440082b7ed6f804539b499ea2263ed Mon Sep 17 00:00:00 2001 2From: chengfeng27 <chengfeng27@huawei.com> 3Date: Thu, 30 May 2024 19:32:52 +0800 4Subject: add model version check 5 6--- 7 .../plugin/device/cpu/kernel/nnacl/kernel.c | 13 ++++++++++ 8 mindspore/lite/src/common/utils.cc | 26 +++++++++++++++++++ 9 mindspore/lite/src/common/utils.h | 7 +++++ 10 mindspore/lite/src/litert/c_api/model_c.cc | 13 +--------- 11 mindspore/lite/src/litert/lite_model.cc | 9 ++++--- 12 mindspore/lite/src/litert/scheduler.cc | 1 + 13 6 files changed, 54 insertions(+), 15 deletions(-) 14 15diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c 16index b86ab817..86a5d163 100644 17--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c 18+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c 19@@ -38,12 +38,22 @@ void Init_MSC_VER_kernels(void) { 20 return; 21 } 22 23+bool checkOpValid(int opType) { 24+ if (opType < PrimType_MIN || opType >= PrimType_MAX) { 25+ return false; 26+ } 27+ return true; 28+} 29+ 30 bool SupportKernelC(int opType, int dataType) { 31 Init_MSC_VER_kernels(); 32 const int length = 16; 33 if (REGIST_DT(dataType) < 0 || REGIST_DT(dataType) >= length) { 34 return false; 35 } 36+ if (!checkOpValid(opType)) { 37+ return false; 38+ } 39 KernelCreator creator = g_kernelCreatorRegistry[opType][REGIST_DT(dataType)]; 40 return creator != NULL; 41 } 42@@ -77,6 +87,9 @@ int NNACLCheckKernelBase(KernelBase *kernel_base) { 43 KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size, 44 int data_type, ExecEnv *env) { 45 Init_MSC_VER_kernels(); 46+ if (!checkOpValid(param->type_)) { 47+ return NULL; 48+ } 49 KernelCreator creator = g_kernelCreatorRegistry[param->type_][REGIST_DT(data_type)]; 50 if (creator == NULL) { 51 return NULL; 52diff --git a/mindspore/lite/src/common/utils.cc b/mindspore/lite/src/common/utils.cc 53index c8509976..e1699687 100644 54--- a/mindspore/lite/src/common/utils.cc 55+++ b/mindspore/lite/src/common/utils.cc 56@@ -195,6 +195,32 @@ std::vector<std::string> Tokenize(const std::string &src, const std::string &del 57 return tokens; 58 } 59 60+std::string GetShortVersionStr(const std::string &s) { 61+ std::string match_str = ""; 62+ std::regex e("\\d+(\\.\\d+){2}"); 63+ auto words_begin = std::sregex_iterator(s.begin(), s.end(), e); 64+ auto words_end = std::sregex_iterator(); 65+ if (words_begin != words_end) { 66+ std::smatch match = *words_begin; 67+ match_str = match.str(); 68+ } 69+ return match_str; 70+} 71+ 72+bool IsVersionGreaterThan(const std::string& str1, const std::string& str2) { 73+ auto str1_splits = StrSplit(str1, "."); 74+ auto str2_splits = StrSplit(str2, "."); 75+ size_t len1 = str1_splits.size(); 76+ size_t len2 = str2_splits.size(); 77+ size_t len = std::min(len1, len2); 78+ for (size_t i = 0; i < len; ++i) { 79+ if (str1_splits[i] != str2_splits[i]) { 80+ return std::stoi(str1_splits[i]) > std::stoi(str2_splits[i]); 81+ } 82+ } 83+ return len1 > len2; 84+} 85+ 86 #if defined(__ANDROID__) || defined(MS_COMPILE_OHOS) 87 uint32_t getHwCap(int hwcap_type) { 88 uint32_t ret = getauxval(hwcap_type); 89diff --git a/mindspore/lite/src/common/utils.h b/mindspore/lite/src/common/utils.h 90index c3f1d069..ecbe4af2 100644 91--- a/mindspore/lite/src/common/utils.h 92+++ b/mindspore/lite/src/common/utils.h 93@@ -25,6 +25,7 @@ 94 #include <cmath> 95 #include <string> 96 #include <utility> 97+#include <regex> 98 #include "src/common/log_adapter.h" 99 #include "tools/common/option.h" 100 #include "include/errorcode.h" 101@@ -213,6 +214,12 @@ enum RemoveSubStrMode { PREFIX, SUFFIX, ANY }; 102 // remove redundant character 103 std::string RemoveSubStr(const std::string &from, const std::string &sub_str, RemoveSubStrMode mode = ANY); 104 105+// match version: x.y.z 106+std::string GetShortVersionStr(const std::string &s); 107+ 108+// compare string 109+bool IsVersionGreaterThan(const std::string& str1, const std::string& str2); 110+ 111 template <typename T> 112 inline Option<T> GenericParseValue(const std::string &value) { 113 T ret; 114diff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc 115index cbbe2dbb..4f40b3d3 100644 116--- a/mindspore/lite/src/litert/c_api/model_c.cc 117+++ b/mindspore/lite/src/litert/c_api/model_c.cc 118@@ -259,8 +259,6 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl 119 120 std::vector<mindspore::MSTensor> ms_tensor_outputs; 121 122- bool all_has_data = false; 123- 124 size_t output_num; 125 (void)impl->GetOutputs(&output_num); 126 auto handle_num = outputs->handle_num; 127@@ -273,15 +271,6 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl 128 } 129 ms_tensor_outputs.push_back(*static_cast<mindspore::MSTensor *>(outputs->handle_list[i])); 130 } 131- 132- all_has_data = std::all_of(ms_tensor_outputs.begin(), ms_tensor_outputs.end(), [](const mindspore::MSTensor &t) { 133- return t.Data() != nullptr; 134- }); 135- 136- if (!all_has_data) { 137- ms_tensor_outputs.clear(); 138- } 139- 140 } 141 142 auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back); 143@@ -290,7 +279,7 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl 144 return static_cast<OH_AI_Status>(ret.StatusCode()); 145 } 146 147- if (handle_num == output_num && all_has_data) { 148+ if (handle_num == output_num) { 149 return OH_AI_STATUS_SUCCESS; 150 } 151 152diff --git a/mindspore/lite/src/litert/lite_model.cc b/mindspore/lite/src/litert/lite_model.cc 153index d32db7c8..006bc02c 100644 154--- a/mindspore/lite/src/litert/lite_model.cc 155+++ b/mindspore/lite/src/litert/lite_model.cc 156@@ -29,6 +29,7 @@ 157 #include "src/common/prim_util.h" 158 #include "src/common/graph_util.h" 159 #include "src/common/file_utils.h" 160+#include "src/common/utils.h" 161 #include "src/tensor.h" 162 #include "extendrt/mindir_loader/model_loader.h" 163 #include "src/common/mmap_utils.h" 164@@ -434,9 +435,11 @@ int LiteModel::GenerateModelByVersion() { 165 if(DeObfRegister::deobf_handle != nullptr) { 166 dlclose(DeObfRegister::deobf_handle); 167 } 168- if (this->graph_.version_ != Version()) { 169- MS_LOG(INFO) << "model version is " << this->graph_.version_ << ", inference version is " << Version() 170- << " not equal"; 171+ if (IsVersionGreaterThan(GetShortVersionStr(this->graph_.version_), GetShortVersionStr(Version()))) { 172+ MS_LOG(WARNING) << "The current model version "<< this->graph_.version_ 173+ << " is later than the inference engine version " << Version() 174+ << ". Use a converter tool whose version is earlier than or equal to " 175+ << "the inference engine version to convert the model."; 176 } 177 MS_LOG(INFO) << "MindSpore Lite inference version: " << Version(); 178 return status; 179diff --git a/mindspore/lite/src/litert/scheduler.cc b/mindspore/lite/src/litert/scheduler.cc 180index d6749471..bc2cf881 100644 181--- a/mindspore/lite/src/litert/scheduler.cc 182+++ b/mindspore/lite/src/litert/scheduler.cc 183@@ -1021,6 +1021,7 @@ int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std: 184 MS_CHECK_TRUE_MSG(op_parameter != nullptr, RET_ERROR, "op parameter is nullptr."); 185 auto op_type = op_parameter->type_; 186 if (!KernelRegistry::GetInstance()->SupportKernel(desc)) { 187+ MS_LOG(INFO) << "unsupport op_type: " << PrimitiveCurVersionTypeName(op_type) << ", data_type: " << desc.data_type; 188 return RET_NOT_SUPPORT; 189 } 190 kernel::KernelKey cpu_desc = desc; 191-- 1922.17.1 193 194