• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1From 3d1f369cb1ae95cc7180a002067b19961650b432 Mon Sep 17 00:00:00 2001
2From: chengfeng27 <chengfeng27@huawei.com>
3Date: Tue, 30 Apr 2024 16:38:09 +0800
4Subject: [PATCH] fix ocr gcn model crash
5
6---
7 .../ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c   | 13 +++++++++++++
8 mindspore/lite/src/litert/c_api/model_c.cc          | 13 +------------
9 mindspore/lite/src/litert/scheduler.cc              |  1 +
10 3 files changed, 15 insertions(+), 12 deletions(-)
11
12diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c
13index b86ab817..86a5d163 100644
14--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c
15+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel.c
16@@ -38,12 +38,22 @@ void Init_MSC_VER_kernels(void) {
17   return;
18 }
19
20+bool checkOpValid(int opType) {
21+  if (opType < PrimType_MIN || opType >= PrimType_MAX) {
22+    return false;
23+  }
24+  return true;
25+}
26+
27 bool SupportKernelC(int opType, int dataType) {
28   Init_MSC_VER_kernels();
29   const int length = 16;
30   if (REGIST_DT(dataType) < 0 || REGIST_DT(dataType) >= length) {
31     return false;
32   }
33+  if (!checkOpValid(opType)) {
34+    return false;
35+  }
36   KernelCreator creator = g_kernelCreatorRegistry[opType][REGIST_DT(dataType)];
37   return creator != NULL;
38 }
39@@ -77,6 +87,9 @@ int NNACLCheckKernelBase(KernelBase *kernel_base) {
40 KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size,
41                          int data_type, ExecEnv *env) {
42   Init_MSC_VER_kernels();
43+  if (!checkOpValid(param->type_)) {
44+    return NULL;
45+  }
46   KernelCreator creator = g_kernelCreatorRegistry[param->type_][REGIST_DT(data_type)];
47   if (creator == NULL) {
48     return NULL;
49diff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc
50index cbbe2dbb..4f40b3d3 100644
51--- a/mindspore/lite/src/litert/c_api/model_c.cc
52+++ b/mindspore/lite/src/litert/c_api/model_c.cc
53@@ -259,8 +259,6 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
54
55   std::vector<mindspore::MSTensor> ms_tensor_outputs;
56
57-  bool all_has_data = false;
58-
59   size_t output_num;
60   (void)impl->GetOutputs(&output_num);
61   auto handle_num = outputs->handle_num;
62@@ -273,15 +271,6 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
63       }
64       ms_tensor_outputs.push_back(*static_cast<mindspore::MSTensor *>(outputs->handle_list[i]));
65     }
66-
67-    all_has_data = std::all_of(ms_tensor_outputs.begin(), ms_tensor_outputs.end(), [](const mindspore::MSTensor &t) {
68-      return t.Data() != nullptr;
69-    });
70-
71-    if (!all_has_data) {
72-      ms_tensor_outputs.clear();
73-    }
74-
75   }
76
77   auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back);
78@@ -290,7 +279,7 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
79     return static_cast<OH_AI_Status>(ret.StatusCode());
80   }
81
82-  if (handle_num == output_num && all_has_data) {
83+  if (handle_num == output_num) {
84     return OH_AI_STATUS_SUCCESS;
85   }
86
87diff --git a/mindspore/lite/src/litert/scheduler.cc b/mindspore/lite/src/litert/scheduler.cc
88index d6749471..bc2cf881 100644
89--- a/mindspore/lite/src/litert/scheduler.cc
90+++ b/mindspore/lite/src/litert/scheduler.cc
91@@ -1021,6 +1021,7 @@ int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std:
92   MS_CHECK_TRUE_MSG(op_parameter != nullptr, RET_ERROR, "op parameter is nullptr.");
93   auto op_type = op_parameter->type_;
94   if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
95+    MS_LOG(INFO) << "unsupport op_type: " << PrimitiveCurVersionTypeName(op_type) << ", data_type: " << desc.data_type;
96     return RET_NOT_SUPPORT;
97   }
98   kernel::KernelKey cpu_desc = desc;
99--
1002.17.1
101
102