From c87618bc9c440082b7ed6f804539b499ea2263ed Mon Sep 17 00:00:00 2001 From: chengfeng27 Date: Thu, 30 May 2024 19:32:52 +0800 Subject: add model version check --- .../plugin/device/cpu/kernel/nnacl/kernel.c | 13 ++++++++++ mindspore/lite/src/common/utils.cc | 26 +++++++++++++++++++ mindspore/lite/src/common/utils.h | 7 +++++ mindspore/lite/src/litert/c_api/model_c.cc | 13 +--------- mindspore/lite/src/litert/lite_model.cc | 9 ++++--- mindspore/lite/src/litert/scheduler.cc | 1 + 6 files changed, 54 insertions(+), 15 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c index b86ab817..86a5d163 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c @@ -38,12 +38,22 @@ void Init_MSC_VER_kernels(void) { return; } +bool checkOpValid(int opType) { + if (opType < PrimType_MIN || opType >= PrimType_MAX) { + return false; + } + return true; +} + bool SupportKernelC(int opType, int dataType) { Init_MSC_VER_kernels(); const int length = 16; if (REGIST_DT(dataType) < 0 || REGIST_DT(dataType) >= length) { return false; } + if (!checkOpValid(opType)) { + return false; + } KernelCreator creator = g_kernelCreatorRegistry[opType][REGIST_DT(dataType)]; return creator != NULL; } @@ -77,6 +87,9 @@ int NNACLCheckKernelBase(KernelBase *kernel_base) { KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size, int data_type, ExecEnv *env) { Init_MSC_VER_kernels(); + if (!checkOpValid(param->type_)) { + return NULL; + } KernelCreator creator = g_kernelCreatorRegistry[param->type_][REGIST_DT(data_type)]; if (creator == NULL) { return NULL; diff --git a/mindspore/lite/src/common/utils.cc b/mindspore/lite/src/common/utils.cc index c8509976..e1699687 100644 --- a/mindspore/lite/src/common/utils.cc +++ b/mindspore/lite/src/common/utils.cc @@ -195,6 +195,32 @@ std::vector Tokenize(const std::string &src, const std::string &del return tokens; } +std::string GetShortVersionStr(const std::string &s) { + std::string match_str = ""; + std::regex e("\\d+(\\.\\d+){2}"); + auto words_begin = std::sregex_iterator(s.begin(), s.end(), e); + auto words_end = std::sregex_iterator(); + if (words_begin != words_end) { + std::smatch match = *words_begin; + match_str = match.str(); + } + return match_str; +} + +bool IsVersionGreaterThan(const std::string& str1, const std::string& str2) { + auto str1_splits = StrSplit(str1, "."); + auto str2_splits = StrSplit(str2, "."); + size_t len1 = str1_splits.size(); + size_t len2 = str2_splits.size(); + size_t len = std::min(len1, len2); + for (size_t i = 0; i < len; ++i) { + if (str1_splits[i] != str2_splits[i]) { + return std::stoi(str1_splits[i]) > std::stoi(str2_splits[i]); + } + } + return len1 > len2; +} + #if defined(__ANDROID__) || defined(MS_COMPILE_OHOS) uint32_t getHwCap(int hwcap_type) { uint32_t ret = getauxval(hwcap_type); diff --git a/mindspore/lite/src/common/utils.h b/mindspore/lite/src/common/utils.h index c3f1d069..ecbe4af2 100644 --- a/mindspore/lite/src/common/utils.h +++ b/mindspore/lite/src/common/utils.h @@ -25,6 +25,7 @@ #include #include #include +#include #include "src/common/log_adapter.h" #include "tools/common/option.h" #include "include/errorcode.h" @@ -213,6 +214,12 @@ enum RemoveSubStrMode { PREFIX, SUFFIX, ANY }; // remove redundant character std::string RemoveSubStr(const std::string &from, const std::string &sub_str, RemoveSubStrMode mode = ANY); +// match version: x.y.z +std::string GetShortVersionStr(const std::string &s); + +// compare string +bool IsVersionGreaterThan(const std::string& str1, const std::string& str2); + template inline Option GenericParseValue(const std::string &value) { T ret; diff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc index cbbe2dbb..4f40b3d3 100644 --- a/mindspore/lite/src/litert/c_api/model_c.cc +++ b/mindspore/lite/src/litert/c_api/model_c.cc @@ -259,8 +259,6 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl std::vector ms_tensor_outputs; - bool all_has_data = false; - size_t output_num; (void)impl->GetOutputs(&output_num); auto handle_num = outputs->handle_num; @@ -273,15 +271,6 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl } ms_tensor_outputs.push_back(*static_cast(outputs->handle_list[i])); } - - all_has_data = std::all_of(ms_tensor_outputs.begin(), ms_tensor_outputs.end(), [](const mindspore::MSTensor &t) { - return t.Data() != nullptr; - }); - - if (!all_has_data) { - ms_tensor_outputs.clear(); - } - } auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back); @@ -290,7 +279,7 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl return static_cast(ret.StatusCode()); } - if (handle_num == output_num && all_has_data) { + if (handle_num == output_num) { return OH_AI_STATUS_SUCCESS; } diff --git a/mindspore/lite/src/litert/lite_model.cc b/mindspore/lite/src/litert/lite_model.cc index d32db7c8..006bc02c 100644 --- a/mindspore/lite/src/litert/lite_model.cc +++ b/mindspore/lite/src/litert/lite_model.cc @@ -29,6 +29,7 @@ #include "src/common/prim_util.h" #include "src/common/graph_util.h" #include "src/common/file_utils.h" +#include "src/common/utils.h" #include "src/tensor.h" #include "extendrt/mindir_loader/model_loader.h" #include "src/common/mmap_utils.h" @@ -434,9 +435,11 @@ int LiteModel::GenerateModelByVersion() { if(DeObfRegister::deobf_handle != nullptr) { dlclose(DeObfRegister::deobf_handle); } - if (this->graph_.version_ != Version()) { - MS_LOG(INFO) << "model version is " << this->graph_.version_ << ", inference version is " << Version() - << " not equal"; + if (IsVersionGreaterThan(GetShortVersionStr(this->graph_.version_), GetShortVersionStr(Version()))) { + MS_LOG(WARNING) << "The current model version "<< this->graph_.version_ + << " is later than the inference engine version " << Version() + << ". Use a converter tool whose version is earlier than or equal to " + << "the inference engine version to convert the model."; } MS_LOG(INFO) << "MindSpore Lite inference version: " << Version(); return status; diff --git a/mindspore/lite/src/litert/scheduler.cc b/mindspore/lite/src/litert/scheduler.cc index d6749471..bc2cf881 100644 --- a/mindspore/lite/src/litert/scheduler.cc +++ b/mindspore/lite/src/litert/scheduler.cc @@ -1021,6 +1021,7 @@ int Scheduler::FindCpuKernel(const std::vector &in_tensors, const std: MS_CHECK_TRUE_MSG(op_parameter != nullptr, RET_ERROR, "op parameter is nullptr."); auto op_type = op_parameter->type_; if (!KernelRegistry::GetInstance()->SupportKernel(desc)) { + MS_LOG(INFO) << "unsupport op_type: " << PrimitiveCurVersionTypeName(op_type) << ", data_type: " << desc.data_type; return RET_NOT_SUPPORT; } kernel::KernelKey cpu_desc = desc; -- 2.17.1