From c87618bc9c440082b7ed6f804539b499ea2263ed Mon Sep 17 00:00:00 2001
From: chengfeng27 <chengfeng27@huawei.com>
Date: Thu, 30 May 2024 19:32:52 +0800
Subject: add model version check

Warn when the model's converter version is newer than the inference engine
version, validate op types against the PrimType range before indexing the
kernel creator registry, log unsupported op types in the scheduler, and
drop the all-outputs-have-data precheck in OH_AI_ModelPredict.
---
 .../plugin/device/cpu/kernel/nnacl/kernel.c   | 13 ++++++++++
 mindspore/lite/src/common/utils.cc            | 26 +++++++++++++++++++
 mindspore/lite/src/common/utils.h             |  7 +++++
 mindspore/lite/src/litert/c_api/model_c.cc    | 13 +---------
 mindspore/lite/src/litert/lite_model.cc       |  9 ++++---
 mindspore/lite/src/litert/scheduler.cc        |  1 +
 6 files changed, 54 insertions(+), 15 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c
index b86ab817..86a5d163 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c
@@ -38,12 +38,22 @@ void Init_MSC_VER_kernels(void) {
   return;
 }
 
+bool checkOpValid(int opType) {
+  if (opType < PrimType_MIN || opType >= PrimType_MAX) {
+    return false;
+  }
+  return true;
+}
+
 bool SupportKernelC(int opType, int dataType) {
   Init_MSC_VER_kernels();
   const int length = 16;
   if (REGIST_DT(dataType) < 0 || REGIST_DT(dataType) >= length) {
     return false;
   }
+  if (!checkOpValid(opType)) {
+    return false;
+  }
   KernelCreator creator = g_kernelCreatorRegistry[opType][REGIST_DT(dataType)];
   return creator != NULL;
 }
@@ -77,6 +87,9 @@ int NNACLCheckKernelBase(KernelBase *kernel_base) {
 KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size,
                          int data_type, ExecEnv *env) {
   Init_MSC_VER_kernels();
+  if (!checkOpValid(param->type_)) {
+    return NULL;
+  }
   KernelCreator creator = g_kernelCreatorRegistry[param->type_][REGIST_DT(data_type)];
   if (creator == NULL) {
     return NULL;
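For context, a compile-only C++ sketch of the guard these kernel.c hunks introduce: the op type is range-checked against the primitive-type enum before it is used as a row index into the creator table, so an out-of-range type from a malformed model fails the lookup instead of reading past the array. The bounds, names, and table shape below are placeholders, not nnacl's real PrimType_MIN/PrimType_MAX or g_kernelCreatorRegistry, and the original file is plain C.

```cpp
// Placeholder bounds and registry shape, only for illustration; the real
// PrimType_MIN / PrimType_MAX and g_kernelCreatorRegistry live in nnacl.
enum { kPrimTypeMin = 0, kPrimTypeMax = 230, kDataTypeLen = 16 };
using KernelCreator = void *(*)();
static KernelCreator g_registry[kPrimTypeMax][kDataTypeLen] = {};

// Same shape as the added checkOpValid(): reject types outside [MIN, MAX).
static bool CheckOpValid(int op_type) {
  return op_type >= kPrimTypeMin && op_type < kPrimTypeMax;
}

// The table is indexed only after both indices pass validation, mirroring the
// patched SupportKernelC() / CreateKernel() flow.
static bool SupportKernel(int op_type, int data_type_index) {
  if (data_type_index < 0 || data_type_index >= kDataTypeLen) {
    return false;
  }
  if (!CheckOpValid(op_type)) {
    return false;
  }
  return g_registry[op_type][data_type_index] != nullptr;
}
```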
diff --git a/mindspore/lite/src/common/utils.cc b/mindspore/lite/src/common/utils.cc
index c8509976..e1699687 100644
--- a/mindspore/lite/src/common/utils.cc
+++ b/mindspore/lite/src/common/utils.cc
@@ -195,6 +195,32 @@ std::vector<std::string> Tokenize(const std::string &src, const std::string &del
   return tokens;
 }
 
+std::string GetShortVersionStr(const std::string &s) {
+  std::string match_str = "";
+  std::regex e("\\d+(\\.\\d+){2}");
+  auto words_begin = std::sregex_iterator(s.begin(), s.end(), e);
+  auto words_end = std::sregex_iterator();
+  if (words_begin != words_end) {
+    std::smatch match = *words_begin;
+    match_str = match.str();
+  }
+  return match_str;
+}
+
+bool IsVersionGreaterThan(const std::string& str1, const std::string& str2) {
+  auto str1_splits = StrSplit(str1, ".");
+  auto str2_splits = StrSplit(str2, ".");
+  size_t len1 = str1_splits.size();
+  size_t len2 = str2_splits.size();
+  size_t len = std::min(len1, len2);
+  for (size_t i = 0; i < len; ++i) {
+    if (str1_splits[i] != str2_splits[i]) {
+      return std::stoi(str1_splits[i]) > std::stoi(str2_splits[i]);
+    }
+  }
+  return len1 > len2;
+}
+
 #if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
 uint32_t getHwCap(int hwcap_type) {
   uint32_t ret = getauxval(hwcap_type);
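For illustration, a self-contained rendition of the two helpers added above, with a local Split() standing in for lite's StrSplit (an assumption made only so the sketch builds on its own; the mirrored function names and the version strings in main() are likewise made up). GetShortVersionStr() pulls the first x.y.z triple out of a longer version string, and IsVersionGreaterThan() compares the dot-separated components numerically at the first point where they differ.

```cpp
#include <algorithm>
#include <iostream>
#include <regex>
#include <string>
#include <vector>

// Local stand-in for lite's StrSplit(), only so this sketch is self-contained.
static std::vector<std::string> Split(const std::string &s, const std::string &delim) {
  std::vector<std::string> parts;
  size_t pos = 0, next;
  while ((next = s.find(delim, pos)) != std::string::npos) {
    parts.push_back(s.substr(pos, next - pos));
    pos = next + delim.size();
  }
  parts.push_back(s.substr(pos));
  return parts;
}

// Mirrors GetShortVersionStr(): first "x.y.z" match in the string, or "".
static std::string ShortVersion(const std::string &s) {
  static const std::regex pattern("\\d+(\\.\\d+){2}");
  std::smatch match;
  return std::regex_search(s, match, pattern) ? match.str() : "";
}

// Mirrors IsVersionGreaterThan(): numeric compare at the first differing component.
static bool VersionGreaterThan(const std::string &lhs, const std::string &rhs) {
  const auto a = Split(lhs, ".");
  const auto b = Split(rhs, ".");
  const size_t len = std::min(a.size(), b.size());
  for (size_t i = 0; i < len; ++i) {
    if (a[i] != b[i]) {
      return std::stoi(a[i]) > std::stoi(b[i]);
    }
  }
  return a.size() > b.size();
}

int main() {
  // Hypothetical version strings, not actual MindSpore Lite output.
  std::cout << ShortVersion("MindSpore Lite 2.1.0 build 20240530") << "\n";  // 2.1.0
  std::cout << VersionGreaterThan("2.3.0", "2.1.1") << "\n";  // 1: model newer, would warn
  std::cout << VersionGreaterThan("2.1.0", "2.1.1") << "\n";  // 0: model older, no warning
  return 0;
}
```

GenerateModelByVersion() in lite_model.cc (below) feeds both the model's recorded version and Version() through GetShortVersionStr() and warns only when the model's triple is strictly newer, so equal or older converter versions stay silent.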
diff --git a/mindspore/lite/src/common/utils.h b/mindspore/lite/src/common/utils.h
index c3f1d069..ecbe4af2 100644
--- a/mindspore/lite/src/common/utils.h
+++ b/mindspore/lite/src/common/utils.h
@@ -25,6 +25,7 @@
 #include <cmath>
 #include <string>
 #include <utility>
+#include <regex>
 #include "src/common/log_adapter.h"
 #include "tools/common/option.h"
 #include "include/errorcode.h"
@@ -213,6 +214,12 @@ enum RemoveSubStrMode { PREFIX, SUFFIX, ANY };
 // remove redundant character
 std::string RemoveSubStr(const std::string &from, const std::string &sub_str, RemoveSubStrMode mode = ANY);
 
+// extract the first x.y.z version substring
+std::string GetShortVersionStr(const std::string &s);
+
+// compare version strings: true if str1 is newer than str2
+bool IsVersionGreaterThan(const std::string& str1, const std::string& str2);
+
 template <typename T>
 inline Option<T> GenericParseValue(const std::string &value) {
   T ret;
diff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc
index cbbe2dbb..4f40b3d3 100644
--- a/mindspore/lite/src/litert/c_api/model_c.cc
+++ b/mindspore/lite/src/litert/c_api/model_c.cc
@@ -259,8 +259,6 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
 
   std::vector<mindspore::MSTensor> ms_tensor_outputs;
 
-  bool all_has_data = false;
-
   size_t output_num;
   (void)impl->GetOutputs(&output_num);
   auto handle_num = outputs->handle_num;
@@ -273,15 +271,6 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
       }
       ms_tensor_outputs.push_back(*static_cast<mindspore::MSTensor *>(outputs->handle_list[i]));
     }
-
-    all_has_data = std::all_of(ms_tensor_outputs.begin(), ms_tensor_outputs.end(), [](const mindspore::MSTensor &t) {
-      return t.Data() != nullptr;
-    });
-
-    if (!all_has_data) {
-      ms_tensor_outputs.clear();
-    }
-
   }
 
   auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back);
@@ -290,7 +279,7 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
     return static_cast<OH_AI_Status>(ret.StatusCode());
   }
 
-  if (handle_num == output_num && all_has_data) {
+  if (handle_num == output_num) {
     return OH_AI_STATUS_SUCCESS;
   }
 
diff --git a/mindspore/lite/src/litert/lite_model.cc b/mindspore/lite/src/litert/lite_model.cc
index d32db7c8..006bc02c 100644
--- a/mindspore/lite/src/litert/lite_model.cc
+++ b/mindspore/lite/src/litert/lite_model.cc
@@ -29,6 +29,7 @@
 #include "src/common/prim_util.h"
 #include "src/common/graph_util.h"
 #include "src/common/file_utils.h"
+#include "src/common/utils.h"
 #include "src/tensor.h"
 #include "extendrt/mindir_loader/model_loader.h"
 #include "src/common/mmap_utils.h"
@@ -434,9 +435,11 @@ int LiteModel::GenerateModelByVersion() {
   if(DeObfRegister::deobf_handle != nullptr) {
     dlclose(DeObfRegister::deobf_handle);
   }
-  if (this->graph_.version_ != Version()) {
-    MS_LOG(INFO) << "model version is " << this->graph_.version_ << ", inference version is " << Version()
-                 << " not equal";
+  if (IsVersionGreaterThan(GetShortVersionStr(this->graph_.version_), GetShortVersionStr(Version()))) {
+    MS_LOG(WARNING) << "The current model version " << this->graph_.version_
+                    << " is later than the inference engine version " << Version()
+                    << ". Use a converter tool whose version is earlier than or equal to "
+                    << "the inference engine version to convert the model.";
   }
   MS_LOG(INFO) << "MindSpore Lite inference version: " << Version();
   return status;
diff --git a/mindspore/lite/src/litert/scheduler.cc b/mindspore/lite/src/litert/scheduler.cc
index d6749471..bc2cf881 100644
--- a/mindspore/lite/src/litert/scheduler.cc
+++ b/mindspore/lite/src/litert/scheduler.cc
@@ -1021,6 +1021,7 @@ int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std:
   MS_CHECK_TRUE_MSG(op_parameter != nullptr, RET_ERROR, "op parameter is nullptr.");
   auto op_type = op_parameter->type_;
   if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
+    MS_LOG(INFO) << "unsupported op_type: " << PrimitiveCurVersionTypeName(op_type) << ", data_type: " << desc.data_type;
     return RET_NOT_SUPPORT;
   }
   kernel::KernelKey cpu_desc = desc;
-- 
2.17.1
