From 03c31305a6a6c8a4b39d5fb734e312a747828672 Mon Sep 17 00:00:00 2001
From: chengfeng27 <chengfeng27@huawei.com>
Date: Sat, 1 Jun 2024 17:46:33 +0800
Subject: fix output DataSize 0, heap-buffer-overflow

---
 include/c_api/tensor_c.h                      |  15 ++
 mindspore/lite/BUILD.gn                       |   1 +
 mindspore/lite/src/litert/c_api/model_c.cc    |  40 ++++-
 mindspore/lite/src/litert/c_api/tensor_c.cc   |  32 ++++
 .../lite/src/litert/c_api/type_c_private.h    |   3 +
 .../src/litert/cxx_api/model/model_impl.cc    |  77 +++++++-
 .../litert/delegate/nnrt/nnrt_allocator.cc    | 168 ++++++++++++++++++
 .../src/litert/delegate/nnrt/nnrt_allocator.h |  64 +++++++
 .../litert/delegate/nnrt/nnrt_model_kernel.cc |  50 +++++-
 .../litert/delegate/nnrt/nnrt_model_kernel.h  |   3 +
 .../litert/kernel/cpu/nnacl/nnacl_kernel.cc   |   2 +-
 mindspore/lite/src/litert/mindrt_executor.cc  |  14 +-
 12 files changed, 458 insertions(+), 11 deletions(-)
 create mode 100644 mindspore/lite/src/litert/delegate/nnrt/nnrt_allocator.cc
 create mode 100644 mindspore/lite/src/litert/delegate/nnrt/nnrt_allocator.h

diff --git a/include/c_api/tensor_c.h b/include/c_api/tensor_c.h
index 6d2aaab6..2f641725 100644
--- a/include/c_api/tensor_c.h
+++ b/include/c_api/tensor_c.h
@@ -154,6 +154,21 @@ OH_AI_API int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor);
 /// \return The data size of the tensor.
 OH_AI_API size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor);

+/// \brief Obtain allocator for the tensor.
+///
+/// \param[in] tensor Tensor object handle.
+///
+/// \return A pointer to the allocator.
+OH_AI_API void *OH_AI_TensorGetAllocator(OH_AI_TensorHandle tensor);
+
+/// \brief Set allocator for the tensor.
+///
+/// \param[in] tensor Tensor object handle.
+/// \param[in] allocator A pointer to the allocator.
+///
+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
+OH_AI_API OH_AI_Status OH_AI_TensorSetAllocator(OH_AI_TensorHandle tensor, void *allocator);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn
index 5866e335..58ee5e51 100644
--- a/mindspore/lite/BUILD.gn
+++ b/mindspore/lite/BUILD.gn
@@ -443,6 +443,7 @@ ohos_shared_library("mindspore_lib") {
       "src/litert/delegate/nnrt/checker/primitive_check.cc",
      "src/litert/delegate/nnrt/nnrt_delegate.cc",
      "src/litert/delegate/nnrt/nnrt_model_kernel.cc",
+      "src/litert/delegate/nnrt/nnrt_allocator.cc",
     ]
     include_dirs += [
       "src/delegate/nnrt/include",
diff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc
index 9da52d76..20e1c227 100644
--- a/mindspore/lite/src/litert/c_api/model_c.cc
+++ b/mindspore/lite/src/litert/c_api/model_c.cc
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 #include "include/c_api/model_c.h"
+#include "type_c_private.h"
 #include <vector>
 #include <cstdint>
 #include "include/api/context.h"
@@ -37,6 +38,11 @@ public:
     for (auto out : outputs_train_) {
       delete out;
     }
+
+    // In the zero-copy scenario the user calls the allocator get/set functions, which populate the global
+    // allocator table. That table is never freed when a model is destroyed, so it keeps growing and leaks
+    // memory; clean it up here when the ModelC is destroyed.
+    CleanAllocatorTable();
   }

   MSTensor **GetInputs(size_t *input_num);
@@ -246,10 +252,42 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
   mindspore::MSKernelCallBack after_call_back = impl->TransCallBack(after);

   std::vector<mindspore::MSTensor> ms_tensor_outputs;
+
+  bool all_has_data = false;
+
+  size_t output_num;
+  (void)impl->GetOutputs(&output_num);
+  auto handle_num = outputs->handle_num;
+  if (handle_num == output_num) {
+    MS_LOG(INFO) << "use user provided output";
+    for (size_t i = 0; i < output_num; i++) {
+      if (outputs->handle_list[i] == nullptr) {
+        MS_LOG(ERROR) << "user provided output array handle_list[" << i << "] is nullptr";
+        return OH_AI_STATUS_LITE_NULLPTR;
+      }
+      ms_tensor_outputs.push_back(*static_cast<mindspore::MSTensor *>(outputs->handle_list[i]));
+    }
+
+    all_has_data = std::all_of(ms_tensor_outputs.begin(), ms_tensor_outputs.end(), [](const mindspore::MSTensor &t) {
+      return t.Data() != nullptr;
+    });
+
+    if (!all_has_data) {
+      ms_tensor_outputs.clear();
+    }
+
+  }
+
   auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back);
   if (!ret.IsOk()) {
     MS_LOG(ERROR) << "Predict fail, ret :" << ret;
+    return static_cast<OH_AI_Status>(ret.StatusCode());
   }
+
+  if (handle_num == output_num && all_has_data) {
+    return OH_AI_STATUS_SUCCESS;
+  }
+
   outputs->handle_list = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&outputs->handle_num));
   return static_cast<OH_AI_Status>(ret.StatusCode());
 }
@@ -345,7 +383,7 @@ char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num) {
   auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
   auto loss_name = impl->GetLossName();
   *num = loss_name.size();
-  char **name = static_cast<char **>(malloc(loss_name.size()));
+  char **name = static_cast<char **>(malloc(loss_name.size() * sizeof(char *)));
   if (name == nullptr) {
     MS_LOG(ERROR) << "Failed to malloc loss_name.";
     return nullptr;
diff --git a/mindspore/lite/src/litert/c_api/tensor_c.cc b/mindspore/lite/src/litert/c_api/tensor_c.cc
index 4b1e6aff..fc3814dd 100644
--- a/mindspore/lite/src/litert/c_api/tensor_c.cc
+++ b/mindspore/lite/src/litert/c_api/tensor_c.cc
@@ -13,11 +13,18 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <unordered_map>
 #include "include/c_api/tensor_c.h"
 #include "include/api/status.h"
 #include "src/tensor.h"
 #include "src/litert/cxx_api/tensor/tensor_impl.h"

+static std::unordered_map<void *, std::weak_ptr<mindspore::Allocator>> allocator_table;
+
+void CleanAllocatorTable() {
+  allocator_table.clear();
+}
+
 OH_AI_TensorHandle OH_AI_TensorCreate(const char *name, OH_AI_DataType type, const int64_t *shape, size_t shape_num,
                                       const void *data, size_t data_len) {
   if (name == nullptr || shape == nullptr) {
@@ -208,3 +215,28 @@ size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor) {
   auto impl = static_cast<mindspore::MSTensor *>(tensor);
   return impl->DataSize();
 }
+
+OH_AI_Status OH_AI_TensorSetAllocator(OH_AI_TensorHandle tensor, void *allocator) {
+  if (tensor == nullptr) {
+    MS_LOG(ERROR) << "param is nullptr.";
+    return OH_AI_STATUS_LITE_NULLPTR;
+  }
+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
+  if (allocator_table.count(allocator) == 0) {
+    MS_LOG(ERROR) << "the input allocator does not belong to framework";
+    return OH_AI_STATUS_LITE_PARAM_INVALID;
+  }
+  std::static_pointer_cast<mindspore::LiteTensorImpl>(impl->impl())->set_own_data(true);
+  impl->SetAllocator(allocator_table[allocator].lock());
+  return OH_AI_STATUS_SUCCESS;
+}
+
+void *OH_AI_TensorGetAllocator(const OH_AI_TensorHandle tensor) {
+  if (tensor == nullptr) {
+    MS_LOG(ERROR) << "param is nullptr.";
+    return nullptr;
+  }
+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
+  allocator_table[impl->allocator().get()] = impl->allocator();
+  return impl->allocator().get();
+}
diff --git a/mindspore/lite/src/litert/c_api/type_c_private.h b/mindspore/lite/src/litert/c_api/type_c_private.h
index 2d3b3883..1a76820d 100644
--- a/mindspore/lite/src/litert/c_api/type_c_private.h
+++ b/mindspore/lite/src/litert/c_api/type_c_private.h
@@ -36,5 +36,8 @@ struct NNRTDeviceDesc {

 #ifdef __cplusplus
 }
+
+void CleanAllocatorTable();
+
 #endif
 #endif  // MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_
diff --git a/mindspore/lite/src/litert/cxx_api/model/model_impl.cc b/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
index 78b1ca67..02533dc3 100644
--- a/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
+++ b/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
@@ -463,7 +463,60 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
           input->set_shape(truncate_shape);
 #endif
         }
-        input->set_data(user_input.MutableData());
+        if (user_input.allocator() == input->allocator()) {
+          input->set_data(user_input.MutableData());
+          input->set_own_data(false);
+        } else {
+          void *user_data = user_input.MutableData();
+          if (user_data == nullptr) {
+            MS_LOG(ERROR) << "user data is nullptr";
+            return kLiteNullptr;
+          }
+          void *input_data = input->MutableData();
+          if (input_data == nullptr) {
+            MS_LOG(ERROR) << "input data is nullptr";
+            return kLiteNullptr;
+          }
+          memcpy(input_data, user_data, input->Size());
+        }
+      }
+    }
+  }
+
+  auto ori_output_tensors = GetOutputs();
+  std::vector<bool> copy_output_data;
+  copy_output_data.resize(ori_output_tensors.size(), false);
+  if (outputs->empty()) {
+    MS_LOG(INFO) << "user provided output is empty";
+  } else if (outputs->size() != ori_output_tensors.size()) {
+    MS_LOG(ERROR) << "user provided output size is not equal to model's output size";
+    return kLiteError;
+  } else {
+    for (size_t i = 0; i < ori_output_tensors.size(); i++) {
+      auto ori_output = ori_output_tensors[i];
+      auto lite_impl = std::static_pointer_cast<LiteTensorImpl>(ori_output.impl());
+      MS_CHECK_TRUE_RET(lite_impl != nullptr, kLiteNullptr);
+      auto ori_out_tensor = static_cast<lite::Tensor *>(lite_impl->lite_tensor());
+      MS_CHECK_TRUE_RET(ori_out_tensor != nullptr, kLiteNullptr);
+
+      auto user_output = (*outputs)[i];
+      auto user_lite_impl = std::static_pointer_cast<LiteTensorImpl>(user_output.impl());
+      MS_CHECK_TRUE_RET(user_lite_impl != nullptr, kLiteNullptr);
+      auto user_out_tensor = user_lite_impl->lite_tensor();
+      if (ori_out_tensor == user_out_tensor) {
+        continue;
+      }
+
+      void *user_out_data = nullptr;
+      if (user_output.DataSize() > 0) {
+        user_out_data = user_output.MutableData();
+      }
+      if (ori_out_tensor->allocator() == user_output.allocator() && user_out_data != nullptr) {
+        MS_LOG(INFO) << "use user data";
+        ori_out_tensor->set_data(user_out_data);
+        ori_out_tensor->set_own_data(false);
+      } else {
+        copy_output_data[i] = true;
       }
     }
   }
@@ -474,6 +527,28 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
     return ret;
   }
   MS_LOG(DEBUG) << "Run graph success.";
+
+  for (size_t i = 0; i < copy_output_data.size(); i++) {
+    if (!copy_output_data[i]) {
+      continue;
+    }
+    auto ori_output = ori_output_tensors[i];
+    auto ori_out_data = ori_output.MutableData();
+    MS_CHECK_TRUE_RET(ori_out_data != nullptr, kLiteNullptr);
+    auto user_output = (*outputs)[i];
+    MS_CHECK_TRUE_RET(user_output.MutableData() != nullptr, kLiteNullptr);
+    if (user_output.DataSize() >= ori_output.DataSize()) {
+      memcpy(user_output.MutableData(), ori_out_data, ori_output.DataSize());
+    } else {
+      MS_LOG(ERROR) << "user out data size is less than model's output data size";
+      return kLiteError;
+    }
+  }
+
+  if (outputs->size() == ori_output_tensors.size()) {
+    return kSuccess;
+  }
+
   auto res = GetOutputs();
   if (res.empty()) {
     MS_LOG(DEBUG) << "Empty outputs.";
diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_allocator.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_allocator.cc
new file mode 100644
index 00000000..f79c1682
--- /dev/null
+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_allocator.cc
@@ -0,0 +1,168 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <atomic>
+#include <unordered_map>
+#include <map>
+#include <mutex>
+#include "src/litert/delegate/nnrt/nnrt_allocator.h"
+#include "src/common/log.h"
+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
+
+namespace mindspore {
+namespace lite {
+NNRTAllocator::~NNRTAllocator() {
+  std::lock_guard<std::mutex> locker(mutex_);
+  for (auto &it : allocated_list_) {
+    auto membuf = it.second;
+    if (memory_category_ == NNRT_INPUT) {
+      OH_NNExecutor_DestroyInputMemory(executor_, index_, &(membuf->memory_));
+    } else {
+      OH_NNExecutor_DestroyOutputMemory(executor_, index_, &(membuf->memory_));
+    }
+    free(membuf);
+  }
+  allocated_list_.clear();
+
+  for (auto &it : free_list_) {
+    auto membuf = it.second;
+    if (memory_category_ == NNRT_INPUT) {
+      OH_NNExecutor_DestroyInputMemory(executor_, index_, &(membuf->memory_));
+    } else {
+      OH_NNExecutor_DestroyOutputMemory(executor_, index_, &(membuf->memory_));
+    }
+    free(membuf);
+  }
+  free_list_.clear();
+}
+
+void *NNRTAllocator::Malloc(size_t size) {
+  std::lock_guard<std::mutex> locker(mutex_);
+  auto iter = free_list_.lower_bound(size);
+  if (iter != free_list_.end()) {
+    auto membuf = iter->second;
+    membuf->ref_count_ = 0;
+    (void)free_list_.erase(iter);
+    allocated_list_[membuf->memory_->data] = membuf;
+    return membuf->memory_->data;
+  }
+
+  auto membuf = new (std::nothrow) MemBuf();
+  if (membuf == nullptr) {
+    MS_LOG(ERROR) << "new Membuf failed.";
+    return nullptr;
+  }
+
+  membuf->ref_count_ = 0;
+  if (memory_category_ == NNRT_INPUT) {
+    membuf->memory_ = OH_NNExecutor_AllocateInputMemory(executor_, index_, size);
+  } else {
+    membuf->memory_ = OH_NNExecutor_AllocateOutputMemory(executor_, index_, size);
+  }
+
+  if (membuf->memory_ == nullptr) {
+    MS_LOG(ERROR) << "malloc OH_NN_Memory return nullptr";
+    return nullptr;
+  }
+  if (membuf->memory_->data == nullptr) {
+    MS_LOG(ERROR) << "malloc OH_NN_Memory return nullptr";
+    if (memory_category_ == NNRT_INPUT) {
+      OH_NNExecutor_DestroyInputMemory(executor_, index_, &(membuf->memory_));
+    } else {
+      OH_NNExecutor_DestroyOutputMemory(executor_, index_, &(membuf->memory_));
+    }
+    return nullptr;
+  }
+
+  allocated_list_[membuf->memory_->data] = membuf;
+  return membuf->memory_->data;
+}
+
+void NNRTAllocator::Free(void *ptr) {
+  if (ptr == nullptr) {
+    return;
+  }
+
+  std::lock_guard<std::mutex> locker(mutex_);
+  auto iter = allocated_list_.find(ptr);
+  if (iter == allocated_list_.end()) {
+    return;
+  }
+  auto membuf = iter->second;
+  membuf->ref_count_ = 0;
+  (void)allocated_list_.erase(iter);
+  (void)free_list_.insert(std::make_pair(membuf->memory_->length, membuf));
+}
+
+int NNRTAllocator::RefCount(void *ptr) {
+  if (ptr == nullptr) {
+    return -1;
+  }
+  std::lock_guard<std::mutex> locker(mutex_);
+  auto iter = allocated_list_.find(ptr);
+  if (iter != allocated_list_.end()) {
+    auto membuf = iter->second;
+    int ref_count = std::atomic_load(&membuf->ref_count_);
+    return ref_count;
+  }
+  return -1;
+}
+
+int NNRTAllocator::SetRefCount(void *ptr, int ref_count) {
+  if (ptr == nullptr) {
+    return -1;
+  }
+  std::lock_guard<std::mutex> locker(mutex_);
+  auto iter = allocated_list_.find(ptr);
+  if (iter != allocated_list_.end()) {
+    auto membuf = iter->second;
+    std::atomic_store(&membuf->ref_count_, ref_count);
+    return ref_count;
+  }
+  return -1;
+}
+
+int NNRTAllocator::DecRefCount(void *ptr, int ref_count) {
+  if (ptr == nullptr) {
+    return -1;
+  }
+  std::lock_guard<std::mutex> locker(mutex_);
+  auto iter = allocated_list_.find(ptr);
+  if (iter != allocated_list_.end()) {
+    auto membuf = iter->second;
+    auto ref = std::atomic_fetch_sub(&membuf->ref_count_, ref_count);
+    return ref;
+  }
+  return -1;
+}
+
+int NNRTAllocator::IncRefCount(void *ptr, int ref_count) {
+  if (ptr == nullptr) {
+    return -1;
+  }
+  std::lock_guard<std::mutex> locker(mutex_);
+  auto iter = allocated_list_.find(ptr);
+  if (iter != allocated_list_.end()) {
+    auto membuf = iter->second;
+    auto ref = std::atomic_fetch_add(&membuf->ref_count_, ref_count);
+    return ref;
+  }
+  return -1;
+}
+
+}  // namespace lite
+}  // namespace mindspore
\ No newline at end of file
diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_allocator.h b/mindspore/lite/src/litert/delegate/nnrt/nnrt_allocator.h
new file mode 100644
index 00000000..f6721369
--- /dev/null
+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_allocator.h
@@ -0,0 +1,64 @@
+/**
+* Copyright 2023 Huawei Technologies Co., Ltd
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_NNRT_NNRT_ALLOCATOR_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_NNRT_NNRT_ALLOCATOR_H_
+
+#include <vector>
+#include <map>
+#include <atomic>
+#include <unordered_map>
+#include <map>
+#include <mutex>
+#include "include/api/allocator.h"
+struct OH_NN_Memory;
+struct OH_NNExecutor;
+
+namespace mindspore {
+namespace lite {
+enum MemoryCategory { NNRT_INPUT, NNRT_OUTPUT };
+
+class NNRTAllocator : public Allocator {
+ public:
+  NNRTAllocator(OH_NNExecutor *executor, int index, MemoryCategory memory_category)
+      : index_(index), memory_category_(memory_category), executor_(executor) {}
+  ~NNRTAllocator() override;
+
+  void *Malloc(size_t size) override;
+  void Free(void *ptr) override;
+  int RefCount(void *ptr) override;
+  int SetRefCount(void *ptr, int ref_count) override;
+  int DecRefCount(void *ptr, int ref_count) override;
+  int IncRefCount(void *ptr, int ref_count) override;
+
+ private:
+  struct MemBuf {
+    std::atomic_int ref_count_{0};
+    OH_NN_Memory *memory_{nullptr};
+  };
+
+  int index_{0};
+  MemoryCategory memory_category_{NNRT_INPUT};
+  OH_NNExecutor *executor_{nullptr};
+  std::mutex mutex_;
+  // <membuf->memory_->data, membuf>
+  std::unordered_map<void *, MemBuf *> allocated_list_;
+  std::multimap<size_t, MemBuf *> free_list_;
+};
+
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_NNRT_NNRT_ALLOCATOR_H_
\ No newline at end of file
diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc
index 67443e08..f83632dd 100644
--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc
+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc
@@ -15,8 +15,33 @@
  */
 #include <include/errorcode.h>
 #include "nnrt_model_kernel.h"
-int mindspore::NNRTModelKernel::Prepare() { return 0; }
+#include "nnrt_allocator.h"
+#include "litert/cxx_api/tensor/tensor_impl.h"
+int mindspore::NNRTModelKernel::Prepare() {
+  for (size_t i = 0; i < inputs_.size(); i++) {
+    auto nnrt_allocator = std::make_shared<lite::NNRTAllocator>(oh_nn_executor, i, lite::NNRT_INPUT);
+    if (nnrt_allocator == nullptr) {
+      MS_LOG(ERROR) << "Create NNRTAllocator failed";
+      return lite::RET_NULL_PTR;
+    }
+    inputs_[i].SetAllocator(nnrt_allocator);
+  }
+  for (size_t i = 0; i < outputs_.size(); i++) {
+    auto nnrt_allocator = std::make_shared<lite::NNRTAllocator>(oh_nn_executor, i, lite::NNRT_OUTPUT);
+    if (nnrt_allocator == nullptr) {
+      MS_LOG(ERROR) << "Create NNRTAllocator failed";
+      return lite::RET_NULL_PTR;
+    }
+    outputs_[i].SetAllocator(nnrt_allocator);
+  }
+  return lite::RET_OK;
+}
+
 int mindspore::NNRTModelKernel::Execute() {
+  MS_CHECK_TRUE_RET(this->outputs().empty() != true, lite::RET_ERROR);
+  zero_copy_ = this->outputs()[Index0].allocator() != nullptr;
+
+
   lite::STATUS ret_val = PrepareInputs();
   if (ret_val != lite::RET_OK) {
     MS_LOG(ERROR) << "NNRTModelKernel PrepareInputs failed, STATUS is " << ret_val;
@@ -142,9 +167,17 @@ int mindspore::NNRTModelKernel::PrepareInputs() {
     oprend->dimensions = dimensions_list.data();
     oprend->quantParam = quant_param;
     oprend->type = OH_NN_TENSOR;
-    MS_LOG_INFO << "input tensor: " << tensor.Name() << ", data: " << (void *)tensor.MutableData() << ", size: " << tensor.DataSize();
-    OH_NN_ReturnCode ret_code =
-      OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize());
+    MS_LOG_INFO << "input tensor: " << tensor.Name() << ", data: " << (void *)tensor.MutableData()
+                << ", size: " << tensor.DataSize();
+
+    OH_NN_ReturnCode ret_code;
+    if (zero_copy_) {
+      OH_NN_Memory mem{tensor.MutableData(), tensor.DataSize()};
+      ret_code = OH_NNExecutor_SetInputWithMemory(oh_nn_executor, i, oprend, &mem);
+    } else {
+      ret_code = OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize());
+    }
+
     delete (oprend);

     if (!tmp_quant_param.empty()) {
@@ -165,7 +198,14 @@ int mindspore::NNRTModelKernel::TransferOutputs() {
   auto output_tensors = this->outputs();
   for (size_t i = 0; i < output_tensors.size(); i++) {
     auto tensor = output_tensors[i];
-    OH_NN_ReturnCode ret_code = OH_NNExecutor_SetOutput(oh_nn_executor, i, tensor.MutableData(), tensor.DataSize());
+
+    OH_NN_ReturnCode ret_code;
+    if (zero_copy_) {
+      OH_NN_Memory mem{tensor.MutableData(), tensor.DataSize()};
+      ret_code = OH_NNExecutor_SetOutputWithMemory(oh_nn_executor, i, &mem);
+    } else {
+      ret_code = OH_NNExecutor_SetOutput(oh_nn_executor, i, tensor.MutableData(), tensor.DataSize());
+    }
     if (ret_code != OH_NN_SUCCESS) {
       MS_LOG(ERROR) << "NNExecutor SetOutput failed, current out tensor is" << tensor.Name()
                     << ", OH_NN_ReturnCode = " << ret_code;
diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h
index ea15f7ca..4f2d4f19 100644
--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h
+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h
@@ -51,6 +51,9 @@ class NNRTModelKernel : public kernel::Kernel {

  protected:
   OH_NNExecutor *oh_nn_executor = nullptr;
+
+ private:
+  bool zero_copy_{false};
 };
 }  // namespace mindspore

diff --git a/mindspore/lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc b/mindspore/lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc
index 813a6467..6cedc8c9 100644
--- a/mindspore/lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc
+++ b/mindspore/lite/src/litert/kernel/cpu/nnacl/nnacl_kernel.cc
@@ -105,7 +105,7 @@ int NNACLKernel::OptimizeDataCopy() {

   if (input_tensor->allocator() == nullptr || input_tensor->allocator() != output_tensor->allocator() ||
       input_tensor->allocator() != ms_context_->allocator || /* runtime allocator */
-      op_parameter_->is_train_session_) {
+      op_parameter_->is_train_session_ || !output_tensor->own_data()) {
     return NNACLKernel::Run();
   }

diff --git a/mindspore/lite/src/litert/mindrt_executor.cc b/mindspore/lite/src/litert/mindrt_executor.cc
index e5cd720c..5c08cedf 100644
--- a/mindspore/lite/src/litert/mindrt_executor.cc
+++ b/mindspore/lite/src/litert/mindrt_executor.cc
@@ -295,14 +295,22 @@ void MindrtExecutor::FreeOutputTensor() {
     if (dst_tensor->data_type() == kNumberTypeGLUInt && src_tensor->data_type() == kNumberTypeGLUInt) {
       continue;
     }
-    if (dst_tensor->allocator() != nullptr) {
+
+    if ((dst_tensor->allocator() != nullptr && dst_tensor->own_data()) || dst_tensor->data() == nullptr) {
+      MS_LOG(DEBUG) << "free data";
       dst_tensor->FreeData();
-    } else {
-      if (dst_tensor->data_type() == src_tensor->data_type()) {
+    } else if (dst_tensor->data() != nullptr && dst_tensor->data_type() == src_tensor->data_type()) {
+      if (dst_tensor->allocator() == nullptr) {
         /* user set graph-output-tensor from outside */
+        MS_LOG(DEBUG) << "user set graph-output-tensor from outside";
         src_tensor->set_data(dst_tensor->data());
         src_tensor->set_own_data(false);
         src_tensor->set_allocator(nullptr);
+      } else if (dst_tensor->allocator() == src_tensor->allocator()) {
+        /* nnrt npu zero copy scene */
+        MS_LOG(DEBUG) << "zero copy data";
+        src_tensor->set_data(dst_tensor->data());
+        src_tensor->set_own_data(dst_tensor->own_data());
       }
     }
   }
--
2.17.1

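For reference, a minimal caller-side sketch (not part of the patch above) of the zero-copy output
path this change enables. It assumes "model" is already built for an NNRT device, "inputs" is
filled, and "user_outputs" holds caller-created tensors that match the model outputs in name, type
and shape but carry no data yet. The wrapper name PredictWithSharedOutputMemory is hypothetical;
the OH_AI_* calls are the existing C API plus the two functions this patch adds, and
OH_AI_TensorGetMutableData is used here on the assumption that it allocates from the tensor's
attached allocator.

#include <stddef.h>
#include "include/c_api/model_c.h"
#include "include/c_api/tensor_c.h"

OH_AI_Status PredictWithSharedOutputMemory(OH_AI_ModelHandle model, OH_AI_TensorHandleArray inputs,
                                           OH_AI_TensorHandleArray *user_outputs) {
  /* Framework-owned output tensors; after NNRTModelKernel::Prepare() they are expected to carry
   * allocators backed by NNRT device memory. */
  OH_AI_TensorHandleArray model_outputs = OH_AI_ModelGetOutputs(model);
  if (user_outputs->handle_num != model_outputs.handle_num) {
    return OH_AI_STATUS_LITE_PARAM_INVALID;
  }
  for (size_t i = 0; i < model_outputs.handle_num; ++i) {
    /* Fetching the allocator also registers it in the framework-side allocator table,
     * which is what makes the subsequent SetAllocator call legal. */
    void *allocator = OH_AI_TensorGetAllocator(model_outputs.handle_list[i]);
    if (allocator == NULL) {
      return OH_AI_STATUS_LITE_NULLPTR;
    }
    /* Attach the same allocator to the caller's tensor, then let the tensor allocate its
     * buffer from it so the device can write into that buffer directly. */
    OH_AI_Status ret = OH_AI_TensorSetAllocator(user_outputs->handle_list[i], allocator);
    if (ret != OH_AI_STATUS_SUCCESS) {
      return ret;
    }
    if (OH_AI_TensorGetMutableData(user_outputs->handle_list[i]) == NULL) {
      return OH_AI_STATUS_LITE_NULLPTR;
    }
  }
  /* Every user output now has data and the handle count matches the model's output count, so
   * per the model_c.cc change OH_AI_ModelPredict keeps the provided handle list and writes
   * results into these buffers instead of replacing them. */
  return OH_AI_ModelPredict(model, inputs, user_outputs, NULL, NULL);
}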