From 46c42f4387b8280a25b16592aa25179093d99fa2 Mon Sep 17 00:00:00 2001
From: z00574805 <z00574805@notesmail.huawei.com>
Date: Wed, 24 May 2023 11:10:00 +0800
Subject: [PATCH 3/5] xiaoyi-0003

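Add a built-in CustomGru operator for MindSpore Lite on ARM64: nnacl
fp32/fp16 kernels and shape inference, runtime CPU kernels, a train-time
GRU fusion pass, micro coder support, and a new
Serialization::ExportWeightsCollaborateWithMicro API that exports model
weights for micro deployment.

A minimal usage sketch of the new export API (illustrative only: the
model setup, the weight-file path, and the tensor name below are
assumptions, not part of this patch; per the implementation, the model
must come from a train session and is_inference must be true):

    #include "include/api/model.h"
    #include "include/api/serialization.h"

    mindspore::Model model;
    // ... build and train the model as usual ...
    auto status = mindspore::Serialization::ExportWeightsCollaborateWithMicro(
      model, mindspore::kMindIR,   // export format
      "gru_net.bin",               // path of the exported weight file (assumed)
      /* is_inference */ true,     // only `true` is currently supported
      /* enable_fp16 */ false,     // keep float weights in float32
      {"hidden_state"});           // changeable-shape weight names (assumed)
    if (status != mindspore::kSuccess) {
      // handle the export failure
    }
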
---
 .jenkins/check/config/filter_cppcheck.txt     |   1 +
 cmake/package_micro.cmake                     |   4 +-
 include/api/serialization.h                   |  26 +
 .../nnacl/base/minimal_filtering_generator.c  |   2 +-
 .../nnacl/base/minimal_filtering_generator.h  |   1 -
 .../cpu/kernel/nnacl/custom_gru_parameter.h   |  31 +
 .../cpu/kernel/nnacl/fp16/custom_gru_fp16.c   |  70 ++
 .../cpu/kernel/nnacl/fp16/custom_gru_fp16.h   |  32 +
 .../cpu/kernel/nnacl/fp32/custom_gru_fp32.c   |  71 ++
 .../cpu/kernel/nnacl/fp32/custom_gru_fp32.h   |  32 +
 .../device/cpu/kernel/nnacl/fp32/lstm_fp32.h  |   2 +-
 .../cpu/kernel/nnacl/infer/custom_gru_infer.c |  45 +
 .../cpu/kernel/nnacl/infer/custom_gru_infer.h |  30 +
 .../plugin/device/cpu/kernel/nnacl/op_base.h  |   1 +
 .../lite/include/registry/converter_context.h |   3 +-
 mindspore/lite/src/CMakeLists.txt             |   1 +
 mindspore/lite/src/common/graph_util.cc       |   7 +-
 .../common/ops/populate/custom_populate.cc    |  14 +
 .../lite/src/runtime/cxx_api/serialization.cc |  31 +
 .../kernel/cpu/base/group_convolution_base.cc |  34 +-
 .../cpu/base/group_convolution_creator.cc     |  24 +-
 .../cpu/base/group_convolution_creator.h      |   8 +-
 .../cpu/fp16/convolution_delegate_fp16.cc     |   2 +-
 .../kernel/cpu/fp16/custom_gru_fp16.cc        | 132 +++
 .../runtime/kernel/cpu/fp16/custom_gru_fp16.h |  40 +
 .../cpu/fp32/convolution_delegate_fp32.cc     |   2 +-
 .../kernel/cpu/fp32/custom_gru_fp32.cc        | 251 ++++++
 .../runtime/kernel/cpu/fp32/custom_gru_fp32.h |  51 ++
 .../cpu/int8/convolution_int8_creator.cc      |   2 +-
 mindspore/lite/src/runtime/lite_session.h     |   6 +
 mindspore/lite/src/train/graph_fusion.cc      |   7 +
 .../train/optimizer/fusion/gru_fusion_pass.cc | 809 ++++++++++++++++++
 .../train/optimizer/fusion/gru_fusion_pass.h  |  45 +
 mindspore/lite/src/train/static_allocator.h   |   6 +-
 mindspore/lite/src/train/train_export.cc      |  38 +
 mindspore/lite/src/train/train_export.h       |   1 +
 mindspore/lite/src/train/train_session.cc     |  37 +-
 mindspore/lite/src/train/train_session.h      |   3 +
 .../test/config_level0/micro/micro_arm64.cfg  |   7 +
 .../config_parser/config_file_parser.cc       |  13 +-
 .../config_parser/config_file_parser.h        |   2 +
 .../config_parser/micro_param_parser.cc       |  33 +
 .../config_parser/micro_param_parser.h        |   2 +
 mindspore/lite/tools/converter/converter.cc   |  18 +-
 .../converter_lite/converter_flags.cc         |   4 +-
 .../converter/micro/cmake/file_list.cmake     |   3 +
 .../micro/coder/allocator/allocator.cc        |  82 +-
 .../micro/coder/allocator/allocator.h         |  13 +-
 .../lite/tools/converter/micro/coder/coder.cc |  51 +-
 .../lite/tools/converter/micro/coder/coder.h  |   9 +-
 .../lite/tools/converter/micro/coder/config.h |  10 +
 .../tools/converter/micro/coder/context.h     |  25 +-
 .../generator/component/common_component.cc   |   5 +-
 .../generator/component/weight_component.cc   | 301 +++++--
 .../generator/component/weight_component.h    |   2 +-
 .../micro/coder/generator/generator.cc        |   2 +-
 .../coder/opcoders/base/reshape_base_coder.cc |   5 +
 .../coder/opcoders/base/stack_base_coder.cc   |  85 ++
 .../coder/opcoders/base/stack_base_coder.h    |  42 +
 .../opcoders/base/strided_slice_base_coder.cc |  21 +
 .../nnacl/fp16/custom_gru_fp16_coder.cc       |  34 +
 .../nnacl/fp16/custom_gru_fp16_coder.h        |  44 +
 .../nnacl/fp16/matmul_fp16_base_coder.cc      |  22 +-
 .../nnacl/fp16/matmul_fp16_base_coder.h       |   4 +-
 .../opcoders/nnacl/fp16/matmul_fp16_coder.h   |   4 +-
 .../fp32/convolution_depthwise_fp32_coder.cc  |  69 +-
 .../fp32/convolution_depthwise_fp32_coder.h   |   8 +-
 .../fp32/convolution_winograd_fp32_coder.cc   |  98 ++-
 .../fp32/convolution_winograd_fp32_coder.h    |  15 +-
 .../nnacl/fp32/custom_gru_fp32_coder.cc       | 214 +++++
 .../nnacl/fp32/custom_gru_fp32_coder.h        |  64 ++
 .../opcoders/nnacl/fp32/gather_fp32_coder.cc  |  59 +-
 .../opcoders/nnacl/fp32/gather_fp32_coder.h   |   2 +
 .../nnacl/fp32/matmul_fp32_base_coder.cc      |  41 +-
 .../nnacl/fp32/matmul_fp32_base_coder.h       |   2 +-
 .../micro/coder/opcoders/op_coder_builder.cc  |  13 +
 .../micro/coder/opcoders/op_coder_builder.h   |   4 +
 .../micro/coder/opcoders/op_coder_register.cc |   8 +-
 .../micro/coder/opcoders/op_coder_register.h  |  23 +-
 .../nnacl_serializer/nnacl_fp32_serializer.cc |   5 +
 .../nnacl_serializer/nnacl_fp32_serializer.h  |   2 +
 .../tools/converter/micro/coder/session.cc    |  49 +-
 .../tools/converter/micro/coder/session.h     |   2 +-
 .../converter/micro/coder/utils/type_cast.cc  |   3 +-
 .../converter/micro/coder/utils/type_cast.h   |   4 +-
 85 files changed, 3163 insertions(+), 267 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h
 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h
 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h
 create mode 100644 mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc
 create mode 100644 mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h

diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt
index 6aeb515a..6f6a08f5 100644
--- a/.jenkins/check/config/filter_cppcheck.txt
+++ b/.jenkins/check/config/filter_cppcheck.txt
@@ -56,6 +56,7 @@
 "mindspore/mindspore/lite/tools/converter/quantizer/quantize_util.cc"                                 "useStlAlgorithm"
 "mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/"                                          "unreadVariable"
 "mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/"                                              "unreadVariable"
+"mindspore/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc"                              "stlFindInsert"
 "mindspore/mindspore/lite/examples/quick_start_micro/"                                                "syntaxError"
 "mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental"                               "unreadVariable"
 "mindspore/mindspore/lite/python/src/pybind_module.cc"                                                "syntaxError"
diff --git a/cmake/package_micro.cmake b/cmake/package_micro.cmake
index 3c6da3db..0481e3c3 100644
--- a/cmake/package_micro.cmake
+++ b/cmake/package_micro.cmake
@@ -10,6 +10,8 @@ function(__install_micro_wrapper)
             COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
     install(DIRECTORY ${NNACL_DIR}/fp32 DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
             COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
+    install(DIRECTORY ${NNACL_DIR}/fp16 DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
+            COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
     install(DIRECTORY ${NNACL_DIR}/kernel DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
             COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
     install(DIRECTORY ${NNACL_DIR}/infer DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
@@ -34,4 +36,4 @@ function(__install_micro_codegen)
             COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
     install(TARGETS cmsis_nn ARCHIVE DESTINATION ${CODEGEN_ROOT_DIR}/third_party/lib
             COMPONENT ${RUNTIME_COMPONENT_NAME})
-endfunction()
\ No newline at end of file
+endfunction()
diff --git a/include/api/serialization.h b/include/api/serialization.h
index 1a0c1f57..76d5dbec 100644
--- a/include/api/serialization.h
+++ b/include/api/serialization.h
@@ -105,6 +105,21 @@ class MS_API Serialization {
                                    QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
                                    std::vector<std::string> output_tensor_name = {});
 
+  /// \brief Export a model's weights; for use with micro inference only.
+  ///
+  /// \param[in] model The model data.
+  /// \param[in] model_type The model file type.
+  /// \param[in] weight_file The path of the exported weight file.
+  /// \param[in] is_inference Whether to export weights of an inference model. Currently, only `true` is supported.
+  /// \param[in] enable_fp16 Whether to save float weights in float16 format.
+  /// \param[in] changeable_weights_name The names of the weight tensors whose shapes are changeable.
+  ///
+  /// \return Status.
+  inline static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                         const std::string &weight_file, bool is_inference = true,
+                                                         bool enable_fp16 = false,
+                                                         const std::vector<std::string> &changeable_weights_name = {});
+
  private:
  static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
                     const std::vector<char> &dec_mode);
@@ -119,6 +134,10 @@ class MS_API Serialization {
  static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                            QuantizationType quantization_type, bool export_inference_only,
                            const std::vector<std::vector<char>> &output_tensor_name);
+  static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                  const std::vector<char> &weight_file, bool is_inference,
+                                                  bool enable_fp16,
+                                                  const std::vector<std::vector<char>> &changeable_weights_name);
 };
 
 Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
@@ -150,5 +169,12 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff
                     VectorStringToChar(output_tensor_name));
 }
 
+Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                        const std::string &weight_file, bool is_inference,
+                                                        bool enable_fp16,
+                                                        const std::vector<std::string> &changeable_weights_name) {
+  return ExportWeightsCollaborateWithMicro(model, model_type, StringToChar(weight_file), is_inference, enable_fp16,
+                                           VectorStringToChar(changeable_weights_name));
+}
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_SERIALIZATION_H
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c
index 3796e47b..81bf8ddf 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c
@@ -16,9 +16,9 @@
 #include "nnacl/base/minimal_filtering_generator.h"
 #include <string.h>
 #include <math.h>
-#include "nnacl/fp32/winograd_utils.h"
 #include "nnacl/errorcode.h"
 #include "nnacl/intrinsics/ms_simd_instructions.h"
+#include "nnacl/fp32/pack_fp32.h"
 
 void Polynomial(const float *interval, float *m, int degree) {
   for (int i = 0; i < degree; ++i) {
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h
index 01b013e8..fc0fa0e6 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h
@@ -21,7 +21,6 @@
 #include <arm_neon.h>
 #endif
 #include <stdbool.h>
-#include "nnacl/pack.h"
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h
new file mode 100644
index 00000000..3bb8a444
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_
+#define MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_
+
+#include "nnacl/op_base.h"
+
+typedef struct CustomGruParameter {
+  // Primitive parameter
+  OpParameter op_parameter_;
+  // shape correlative
+  int num_step;
+  int batch_size;
+  int input_size;
+  int hidden_size;
+} CustomGruParameter;
+
+#endif  // MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c
new file mode 100644
index 00000000..6e754569
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c
@@ -0,0 +1,70 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/fp16/custom_gru_fp16.h"
+#include "nnacl/fp16/activation_fp16.h"
+#include "nnacl/fp16/arithmetic_fp16.h"
+#include "nnacl/fp16/matmul_fp16.h"
+
+void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input,
+                   const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden,
+                   const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param) {
+  int num_step = gru_param->num_step;
+  int batch_size = gru_param->batch_size;
+  int input_size = gru_param->input_size;
+  int hidden_size = gru_param->hidden_size;
+  int output_size = batch_size * hidden_size;
+  int double_output_size = output_size * C2NUM;
+  int col_align = UP_ROUND(hidden_size, C8NUM);
+  int weight_in_offset = col_align * input_size;
+  int weight_hidden_offset = col_align * hidden_size;
+  float16_t *input_gate = buffer[1];
+  float16_t *hidden_gate = buffer[C3NUM];
+  for (int i = 0; i < num_step; ++i) {
+    if (batch_size != 1) {
+      RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatmulBaseFp16Neon(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size,
+                           bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size,
+                           OutType_Nhwc);
+      }
+      RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatmulBaseFp16Neon(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                           bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size,
+                           OutType_Nhwc);
+      }
+    } else {
+      for (int j = 0; j < C3NUM; ++j) {
+        VecMatmulFp16(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size,
+                      bias_input + j * col_align, ActType_No, input_size, hidden_size);
+        VecMatmulFp16(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                      bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size);
+      }
+    }
+    ElementAddFp16(input_gate, hidden_gate, input_gate, double_output_size);
+    SigmoidFp16(input_gate, input_gate, double_output_size);
+    ElementMulFp16(input_gate, hidden_gate + double_output_size, input_gate, output_size);
+    ElementAddFp16(input_gate, input_gate + double_output_size, input_gate, output_size);
+    TanhFp16(input_gate, input_gate, output_size);
+    ElementSubFp16(init_h, input_gate, hidden_gate, output_size);
+    ElementMulFp16(input_gate + output_size, hidden_gate, hidden_gate, output_size);
+    ElementAddFp16(input_gate, hidden_gate, output, output_size);
+    init_h = output;
+    output += output_size;
+  }
+}
+#endif
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h
new file mode 100644
index 00000000..67008f03
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_
+#define MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_
+#ifdef ENABLE_ARM64
+#include "nnacl/custom_gru_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input,
+                   const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden,
+                   const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif
+#endif  // MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c
new file mode 100644
index 00000000..caeece4a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c
@@ -0,0 +1,71 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/fp32/custom_gru_fp32.h"
+#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl/fp32/pack_fp32.h"
+
+void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden,
+               const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4],
+               const CustomGruParameter *gru_param) {
+  int num_step = gru_param->num_step;
+  int batch_size = gru_param->batch_size;
+  int input_size = gru_param->input_size;
+  int hidden_size = gru_param->hidden_size;
+  int output_size = batch_size * hidden_size;
+  int double_output_size = output_size * C2NUM;
+  int col_align = UP_ROUND(hidden_size, C8NUM);
+  int weight_in_offset = col_align * input_size;
+  int weight_hidden_offset = col_align * hidden_size;
+  float *input_gate = buffer[1];
+  float *hidden_gate = buffer[C3NUM];
+  for (int i = 0; i < num_step; ++i) {
+    if (batch_size != 1) {
+      RowMajor2Col12Major(input + i * batch_size * input_size, buffer[0], batch_size, input_size);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatMulOpt(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size,
+                  bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size,
+                  OutType_Nhwc);
+      }
+      RowMajor2Col12Major(init_h, buffer[C2NUM], batch_size, hidden_size);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatMulOpt(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                  bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size,
+                  OutType_Nhwc);
+      }
+    } else {
+      for (int j = 0; j < C3NUM; ++j) {
+        MatVecMulFp32Neon64(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size,
+                            bias_input + j * col_align, ActType_No, input_size, hidden_size, col_align);
+        MatVecMulFp32Neon64(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                            bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size, col_align);
+      }
+    }
+    ElementAdd(input_gate, hidden_gate, input_gate, double_output_size);
+    Sigmoid(input_gate, double_output_size, input_gate);
+    ElementMul(input_gate, hidden_gate + double_output_size, input_gate, output_size);
+    ElementAdd(input_gate, input_gate + double_output_size, input_gate, output_size);
+    Tanh(input_gate, output_size, input_gate);
+    ElementSub(init_h, input_gate, hidden_gate, output_size);
+    ElementMul(input_gate + output_size, hidden_gate, hidden_gate, output_size);
+    ElementAdd(input_gate, hidden_gate, output, output_size);
+    init_h = output;
+    output += output_size;
+  }
+}
+#endif
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h
new file mode 100644
index 00000000..576726c5
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_
+#define MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_
+#ifdef ENABLE_ARM64
+#include "nnacl/custom_gru_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden,
+               const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4],
+               const CustomGruParameter *gru_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif
+#endif  // MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h
index b608e1e0..8e217d02 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h
@@ -43,7 +43,7 @@ int ElementOptMulAcc(const float *input0, const float input1, float *output, con
 
 void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
                   const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state,
-                  float *buffer[6], const LstmParameter *lstm_param);
+                  float *buffer[C6NUM], const LstmParameter *lstm_param);
 
 void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias,
           const float *state_bias, float *hidden_state, float *cell_state, float *buffer[7],
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c
new file mode 100644
index 00000000..060d04cf
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl/infer/custom_gru_infer.h"
+#include "nnacl/infer/infer_register.h"
+
+int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                        OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C6NUM, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+  SetDataTypeFormat(output, input);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  if (input->shape_size_ != C3NUM) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  SetShapeTensor(output, input);
+  const TensorC *weight_in = inputs[1];
+  if (weight_in->shape_size_ != C2NUM || weight_in->shape_[0] % C3NUM != 0) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  output->shape_[C2NUM] = weight_in[0].shape_[0] / C3NUM;
+  return NNACL_OK;
+}
+
+REG_INFER(CustomGru, PrimType_Inner_CustomGru, CustomGruInferShape)
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h
new file mode 100644
index 00000000..830150d5
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_CUSTOM_GRU_INFER_H
+#define MINDSPORE_NNACL_CUSTOM_GRU_INFER_H
+#include "nnacl/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                        OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_CUSTOM_GRU_INFER_H
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
index 5876bdf6..8c219212 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
@@ -520,6 +520,7 @@ enum PrimType {
   PrimType_Inner_ShapeFusion = 10003,
   PrimType_Inner_GraphKernel = 10004,
   PrimType_Inner_ThirdPartyModel = 10005,
+  PrimType_Inner_CustomGru = 10006,
   PrimType_InnerOpMax,
   PrimType_InnerOpMin = PrimType_Inner_ToFormat
 };
diff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h
index dd6e6d08..0d2de256 100644
--- a/mindspore/lite/include/registry/converter_context.h
+++ b/mindspore/lite/include/registry/converter_context.h
@@ -34,7 +34,8 @@ enum MS_API FmkType : int {
   kFmkTypeTflite = 4,
   kFmkTypePytorch = 5,
   kFmkTypeThirdParty = 6,
-  kFmkTypeEnd = 7,  // For range check purpose, valid range: [0, kFmkTypeEnd)
+  kFmkTypeMsLite = 7,
+  kFmkTypeEnd = 8,  // For range check purpose, valid range: [0, kFmkTypeEnd)
 };
 
 /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index 48e0fe7c..a80656f2 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -380,6 +380,7 @@ set(TRAIN_SRC
         ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/common/fusion_utils.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/gru_fusion_pass.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_activation_fusion_pass.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc
diff --git a/mindspore/lite/src/common/graph_util.cc b/mindspore/lite/src/common/graph_util.cc
index 1688fb79..d5cb2a98 100644
--- a/mindspore/lite/src/common/graph_util.cc
+++ b/mindspore/lite/src/common/graph_util.cc
@@ -23,6 +23,7 @@
 #include "src/common/log_adapter.h"
 #include "src/common/version_manager.h"
 #include "include/errorcode.h"
+#include "nnacl/op_base.h"
 
 namespace mindspore {
 namespace lite {
@@ -86,9 +87,9 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, size_t tensor
 
 // only support op_type from current schema
 bool IsPackedOp(int op_type) {
-  static const std::vector<int> packed_ops = {schema::PrimitiveType_Conv2DFusion,
-                                              schema::PrimitiveType_Conv2dTransposeFusion,
-                                              schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMulFusion};
+  static const std::vector<int> packed_ops = {
+    schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion,
+    schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMulFusion, PrimType::PrimType_Inner_CustomGru};
   return IsContain(packed_ops, op_type);
 }
 
diff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc
index f1506ece..391a587b 100644
--- a/mindspore/lite/src/common/ops/populate/custom_populate.cc
+++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc
@@ -17,10 +17,22 @@
 #include "src/common/log_adapter.h"
 #include "src/tensor.h"
 #include "nnacl/custom_parameter.h"
+#include "nnacl/custom_gru_parameter.h"
 using mindspore::schema::PrimitiveType_Custom;
 
 namespace mindspore {
 namespace lite {
+OpParameter *CreateCustomGruParameter() {
+  auto *param = static_cast<CustomGruParameter *>(malloc(sizeof(CustomGruParameter)));
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "malloc CustomGruParameter failed.";
+    return nullptr;
+  }
+  memset(param, 0, sizeof(CustomGruParameter));
+  param->op_parameter_.type_ = PrimType_Inner_CustomGru;
+  return reinterpret_cast<OpParameter *>(param);
+}
+
 OpParameter *PopulateCustomParameter(const void *prim) {
   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
   auto primitive = static_cast<const schema::Primitive *>(prim);
@@ -62,6 +74,8 @@ OpParameter *PopulateCustomParameter(const void *prim) {
     // Just use the attr_data pointer to save the prim directly, the inner value is parsed as necessary.
     param->attr_data[0] = static_cast<char *>(const_cast<void *>(prim));
     return reinterpret_cast<OpParameter *>(param);
+  } else if (type == "CustomGRU") {
+    return CreateCustomGruParameter();
   } else {
     MS_LOG(ERROR) << "Unsupported custom type: " << type;
   }
diff --git a/mindspore/lite/src/runtime/cxx_api/serialization.cc b/mindspore/lite/src/runtime/cxx_api/serialization.cc
index 8405f4b2..ddf69d23 100644
--- a/mindspore/lite/src/runtime/cxx_api/serialization.cc
+++ b/mindspore/lite/src/runtime/cxx_api/serialization.cc
@@ -212,4 +212,35 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
 
   return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
 }
+
+Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                        const std::vector<char> &weight_file, bool is_inference,
+                                                        bool enable_fp16,
+                                                        const std::vector<std::vector<char>> &changeable_weights_name) {
+  if (model.impl_ == nullptr) {
+    MS_LOG(ERROR) << "Model implement is null.";
+    return kLiteUninitializedObj;
+  }
+  if (!model.impl_->IsTrainModel()) {
+    MS_LOG(ERROR) << "Model is not TrainModel.";
+    return kLiteError;
+  }
+  if (model_type != kMindIR && model_type != kMindIR_Lite) {
+    MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
+    return kLiteParamInvalid;
+  }
+  if (model.impl_->session_ == nullptr) {
+    MS_LOG(ERROR) << "Model session is nullptr.";
+    return kLiteError;
+  }
+  if (!is_inference) {
717+    MS_LOG(ERROR) << "Currently, can only export inference-model's weights.";
+    return kLiteNotSupport;
+  }
+  auto ret = model.impl_->session_->ExportWeightsCollaborateWithMicro(CharToString(weight_file), lite::MT_INFERENCE,
+                                                                      lite::FT_FLATBUFFERS, enable_fp16,
+                                                                      VectorCharToString(changeable_weights_name));
+
+  return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
+}
 }  // namespace mindspore
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
index b5370ddd..0352ad19 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
@@ -35,6 +35,8 @@ int GroupConvolutionBaseCPUKernel::Prepare() {
       return ret;
     }
   }
+  conv_param_->input_channel_ *= group_num_;
+  conv_param_->output_channel_ *= group_num_;
   // if infer shape is done, resize func will be invoked in sub kernels
   return RET_OK;
 }
@@ -99,11 +101,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
       auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front();
       CHECK_NULL_RETURN(sub_kernel_in_tensor);
       sub_kernel_in_tensor->set_shape(in_shape);
-      ret = sub_kernel_in_tensor->MallocData();
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "sub kernel in tensor malloc data failed.";
-        return ret;
-      }
       // out
       auto out_tensor = out_tensors_.front();
       CHECK_NULL_RETURN(out_tensor);
@@ -113,11 +110,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
       for (auto tensor : sub_kernel_out_tensors) {
         CHECK_NULL_RETURN(tensor);
         tensor->set_shape(out_shape);
-        ret = tensor->MallocData();
-        if (ret != RET_OK) {
-          MS_LOG(ERROR) << "sub kernel out tensor malloc data failed.";
-          return ret;
-        }
       }
     }
     ret = ReSize();
@@ -177,7 +169,22 @@ int GroupConvolutionBaseCPUKernel::Run() {
   ori_out_data_ = out_tensors_[0]->data();
   CHECK_NULL_RETURN(ori_out_data_);
   for (int i = 0; i < group_num_; ++i) {
-    // first, separate group conv input into several parts. This step must be in runtime stage.
+    // first, malloc data for sub_kernel's tensors
+    auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front();
+    ret = sub_kernel_in_tensor->MallocData();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "sub kernel in tensor malloc data failed.";
+      return ret;
+    }
+    auto sub_kernel_out_tensors = group_convs_.at(i)->out_tensors();
+    for (auto tensor : sub_kernel_out_tensors) {
+      ret = tensor->MallocData();
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "sub kernel out tensor malloc data failed.";
+        return ret;
+      }
+    }
+    // second, separate group conv input into several parts. This step must be in runtime stage.
     ret = SeparateInput(i);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Separate input failed.";
@@ -195,6 +202,11 @@ int GroupConvolutionBaseCPUKernel::Run() {
       MS_LOG(ERROR) << "Concat output failed.";
       return ret;
     }
+    // free data
+    sub_kernel_in_tensor->FreeData();
+    for (auto tensor : sub_kernel_out_tensors) {
+      tensor->FreeData();
+    }
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
index fc78a887..81b0aac2 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
@@ -53,15 +53,6 @@ void FreeCurrentConv(ConvParameter *conv_param, std::vector<lite::Tensor *> *new
   }
 }
 
-static inline lite::Tensor *TensorMalloc(lite::Tensor *tensor) {
-  if (tensor->MallocData() != lite::RET_OK) {
-    delete tensor;
-    MS_LOG(ERROR) << "malloc tensor data failed.";
-    return nullptr;
-  }
-  return tensor;
-}
-
 lite::Tensor *CreateConstTensor(const lite::Tensor *tensor, const std::vector<int> &shape, const int index) {
   auto new_tensor =
     new (std::nothrow) lite::Tensor(tensor->data_type(), shape, mindspore::NHWC, lite::Category::CONST_TENSOR);
@@ -87,7 +78,7 @@ lite::Tensor *CreateConstTensor(const lite::Tensor *tensor, const std::vector<in
   return new_tensor;
 }
 
-lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred) {
+lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info) {
   auto tensor = new (std::nothrow) lite::Tensor();
   if (tensor == nullptr) {
     MS_LOG(ERROR) << "new tensor failed.";
@@ -97,11 +88,7 @@ lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred) {
   tensor->set_format(tensor_info.format_);
   tensor->set_category(tensor_info.tensor_type_);
   tensor->set_shape(tensor_info.shape_);
-
-  if (inferred) {
-    // set shape of out tensor
-    return TensorMalloc(tensor);
-  }
+  tensor->set_allocator(tensor_info.allocator_);
   return tensor;
 }
 
@@ -129,7 +116,8 @@ void GroupConvCreator::FreeGroupConvs() {
 }
 
 int GroupConvCreator::NewInputTensor(std::vector<lite::Tensor *> *tensors) {
-  auto in_tensor = CreateVarTensor({input_shape_, mindspore::NHWC, data_type_, lite::Category::VAR, true}, infered_);
+  auto allocator = ms_context_ != nullptr ? ms_context_->allocator : nullptr;
+  auto in_tensor = CreateVarTensor({input_shape_, allocator, mindspore::NHWC, data_type_, lite::Category::VAR, true});
   if (in_tensor == nullptr) {
     return lite::RET_ERROR;
   }
@@ -138,7 +126,9 @@ int GroupConvCreator::NewInputTensor(std::vector<lite::Tensor *> *tensors) {
 }
 
 int GroupConvCreator::NewOutputTensor(std::vector<lite::Tensor *> *tensors, const lite::Tensor *output) const {
-  auto out_tensor = CreateVarTensor({output_shape_, output->format(), data_type_, output->category(), false}, infered_);
+  auto allocator = ms_context_ != nullptr ? ms_context_->allocator : nullptr;
+  auto out_tensor =
+    CreateVarTensor({output_shape_, allocator, output->format(), data_type_, output->category(), false});
   if (out_tensor == nullptr) {
     return lite::RET_ERROR;
   }
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h
index b4a4f768..27aa0cc8 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h
@@ -22,10 +22,12 @@
 #include "src/runtime/lite_kernel.h"
 #include "nnacl/conv_parameter.h"
 #include "src/runtime/tensor_category.h"
+#include "include/api/allocator.h"
 
 namespace mindspore::kernel {
 struct TensorInfo {
   std::vector<int> shape_;
+  AllocatorPtr allocator_;
   mindspore::Format format_;
   TypeId data_type_;
   lite::Category tensor_type_;
@@ -35,8 +37,9 @@ struct TensorInfo {
 class GroupConvCreator {
  public:
   GroupConvCreator(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs, OpParameter *op_parameter,
-                   bool is_quant, TypeId data_type)
-      : origin_inputs_(std::move(inputs)),
+                   bool is_quant, TypeId data_type, const lite::InnerContext *ctx)
+      : ms_context_(ctx),
+        origin_inputs_(std::move(inputs)),
         origin_outputs_(std::move(outputs)),
         is_quant_(is_quant),
         data_type_(data_type) {
@@ -64,6 +67,7 @@ class GroupConvCreator {
   int NewOutputTensor(std::vector<lite::Tensor *> *tensors, const lite::Tensor *output) const;
 
  private:
+  const lite::InnerContext *ms_context_ = nullptr;
   std::vector<lite::Tensor *> origin_inputs_;
   std::vector<lite::Tensor *> origin_outputs_;
   std::vector<kernel::LiteKernel *> group_convs_;
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc
index 7aa823b0..17ba38ff 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc
@@ -202,7 +202,7 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor
                                                   const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                                   const InnerContext *ctx) {
   auto *group_conv_creator =
-    new (std::nothrow) GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat16);
+    new (std::nothrow) GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat16, ctx);
   if (group_conv_creator == nullptr) {
     MS_LOG(ERROR) << "new GroupConvCreator fail";
     free(op_parameter);
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc
new file mode 100644
index 00000000..7851eecb
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc
@@ -0,0 +1,132 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/runtime/kernel/cpu/fp16/custom_gru_fp16.h"
+#include <algorithm>
+#include "src/runtime/kernel_registry.h"
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#include "src/runtime/pack_weight_manager.h"
+#include "nnacl/custom_gru_parameter.h"
+#include "nnacl/fp16/custom_gru_fp16.h"
+#include "nnacl/fp16/matmul_fp16.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NOT_SUPPORT;
+using mindspore::lite::RET_OK;
+
+namespace mindspore::kernel {
+int CustomGruFp16CPUKernel::InitWeightAndBias() {
+  auto weight_shape = in_tensors_[1]->shape();
+  auto hidden_size = weight_shape[0] / C3NUM;
+  auto col_align = UP_ROUND(hidden_size, col_tile_);
+  auto weight_in_pack_size = static_cast<size_t>(col_align * weight_shape[1]) * sizeof(float16_t);
+  bool is_packed = false;
+  weight_in_ = lite::PackWeightManager::GetInstance()->GetPackData(
+    in_tensors_[SECOND_INPUT]->data(), static_cast<size_t>(weight_in_pack_size * C3NUM), &is_packed);
+  MS_CHECK_TRUE_MSG(weight_in_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-in failed.");
+  if (!is_packed) {
+    auto weight_in_src = static_cast<const float16_t *>(in_tensors_[SECOND_INPUT]->data());
+    for (int i = 0; i < C3NUM; ++i) {
+      RowMajor2Col8MajorFp16(weight_in_src + i * hidden_size * weight_shape[1],
+                             static_cast<float16_t *>(weight_in_) + i * col_align * weight_shape[1], hidden_size,
+                             weight_shape[1], false);
+    }
+  }
+  auto weight_hidden_pack_size = static_cast<size_t>(col_align * hidden_size) * sizeof(float16_t);
+  is_packed = false;
+  weight_hidden_ = lite::PackWeightManager::GetInstance()->GetPackData(
+    in_tensors_[THIRD_INPUT]->data(), static_cast<size_t>(weight_hidden_pack_size * C3NUM), &is_packed);
+  MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-hidden failed.");
+  if (!is_packed) {
+    auto weight_hidden_src = static_cast<const float16_t *>(in_tensors_[THIRD_INPUT]->data());
+    for (int i = 0; i < C3NUM; ++i) {
+      RowMajor2Col8MajorFp16(weight_hidden_src + i * hidden_size * weight_shape[1],
+                             static_cast<float16_t *>(weight_hidden_) + i * col_align * weight_shape[1], hidden_size,
+                             hidden_size, false);
+    }
+  }
+  auto bias_pack_size = static_cast<size_t>(col_align) * sizeof(float16_t);
+  auto bias = reinterpret_cast<float16_t *>(malloc(bias_pack_size * C6NUM));
+  if (bias == nullptr) {
+    MS_LOG(ERROR) << "malloc for packing bias failed.";
+    return lite::RET_NULL_PTR;
+  }
+  (void)memset(bias, 0, bias_pack_size * C6NUM);
+  bias_in_ = bias;
+  bias_hidden_ = bias + col_align * C3NUM;
+  auto bias_in_src = static_cast<const float16_t *>(in_tensors_[FOURTH_INPUT]->data());
+  for (int i = 0; i < C3NUM; ++i) {
+    (void)memcpy(bias + i * col_align, bias_in_src + i * hidden_size, hidden_size * sizeof(float16_t));
+  }
+  auto bias_hidden_src = static_cast<const float16_t *>(in_tensors_[FIFTH_INPUT]->data());
+  for (int i = 0; i < C3NUM; ++i) {
+    (void)memcpy(bias + (C3NUM + i) * col_align, bias_hidden_src + i * hidden_size, hidden_size * sizeof(float16_t));
+  }
+  if (in_tensors_[SIXTH_INPUT]->IsConst()) {
+    init_h_ = malloc(in_tensors_[SIXTH_INPUT]->Size());
+    MS_CHECK_TRUE_MSG(init_h_ != nullptr, lite::RET_NULL_PTR, "malloc for init-h failed.");
+    (void)memcpy(init_h_, in_tensors_[SIXTH_INPUT]->data(), in_tensors_[SIXTH_INPUT]->Size());
+  }
+  return RET_OK;
+}
+
+int CustomGruFp16CPUKernel::Run() {
+  auto input = reinterpret_cast<float16_t *>(in_tensors_[FIRST_INPUT]->data());
+  if (input == nullptr) {
+    MS_LOG(ERROR) << "Built-in CustomGru's first-input is nullptr." << name_;
+    return lite::RET_NULL_PTR;
+  }
+  if (!in_tensors_[SIXTH_INPUT]->IsConst()) {
+    init_h_ = in_tensors_[SIXTH_INPUT]->data();
+  }
+  if (init_h_ == nullptr) {
+    MS_LOG(ERROR) << "Built-in CustomGru's sixth-input is nullptr." << name_;
1017+    return lite::RET_NULL_PTR;
1018+  }
1019+  auto output = reinterpret_cast<float16_t *>(out_tensors_.front()->data());
1020+  if (output == nullptr) {
1021+    MS_LOG(ERROR) << "Built-in CustomGru's output is nullptr." << name_;
1022+    return lite::RET_NULL_PTR;
1023+  }
1024+  MallocRunBuffer(sizeof(float16_t));
1025+  if (run_buffer_ == nullptr) {
1026+    MS_LOG(ERROR) << "malloc running buffer failed." << name_;
1027+    return lite::RET_NULL_PTR;
1028+  }
1029+  auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
1030+  auto row_align = UP_ROUND(param->batch_size, row_tile_);
1031+  auto run_buffer = reinterpret_cast<float16_t *>(run_buffer_);
1032+  float16_t *buffer[C4NUM] = {
1033+    run_buffer, run_buffer + row_align * param->input_size,
1034+    run_buffer + row_align * param->input_size + param->batch_size * param->hidden_size * C3NUM,
1035+    run_buffer + row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C3NUM};
1036+  CustomGruFp16(output, input, static_cast<float16_t *>(weight_in_), static_cast<float16_t *>(weight_hidden_),
1037+                static_cast<float16_t *>(bias_in_), static_cast<float16_t *>(bias_hidden_),
1038+                static_cast<float16_t *>(init_h_), buffer, param);
1039+  if (ms_context_->allocator != nullptr) {
1040+    ms_context_->allocator->Free(run_buffer_);
1041+  } else {
1042+    free(run_buffer_);
1043+  }
1044+  run_buffer_ = nullptr;
1045+  return RET_OK;
1046+}
1047+
1048+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimType_Inner_CustomGru, LiteKernelCreator<CustomGruFp16CPUKernel>)
1049+}  // namespace mindspore::kernel
1050+#endif
1051diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h
1052new file mode 100644
1053index 00000000..d7ed313a
1054--- /dev/null
1055+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h
1056@@ -0,0 +1,40 @@
1057+/**
1058+ * Copyright 2023 Huawei Technologies Co., Ltd
1059+ *
1060+ * Licensed under the Apache License, Version 2.0 (the "License");
1061+ * you may not use this file except in compliance with the License.
1062+ * You may obtain a copy of the License at
1063+ *
1064+ * http://www.apache.org/licenses/LICENSE-2.0
1065+ *
1066+ * Unless required by applicable law or agreed to in writing, software
1067+ * distributed under the License is distributed on an "AS IS" BASIS,
1068+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1069+ * See the License for the specific language governing permissions and
1070+ * limitations under the License.
1071+ */
1072+
1073+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_
1074+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_
1075+#ifdef ENABLE_ARM64
1076+#include <vector>
1077+#include "src/runtime/kernel/cpu/fp32/custom_gru_fp32.h"
1078+
1079+namespace mindspore::kernel {
1080+class CustomGruFp16CPUKernel : public CustomGruCPUKernel {
1081+ public:
1082+  CustomGruFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
1083+                         const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
1084+      : CustomGruCPUKernel(parameter, inputs, outputs, ctx) {
1085+    row_tile_ = C4NUM;
1086+    col_tile_ = C8NUM;
1087+  }
1088+  ~CustomGruFp16CPUKernel() override = default;
1089+  int Run() override;
1090+
1091+ protected:
1092+  int InitWeightAndBias() override;
1093+};
1094+}  // namespace mindspore::kernel
1095+#endif
1096+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_
1097diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc
1098index bbbf8488..3514e5b4 100644
1099--- a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc
1100+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc
1101@@ -359,7 +359,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
1102 kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
1103                                                   const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
1104                                                   const lite::InnerContext *ctx) {
1105-  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat32);
1106+  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat32, ctx);
1107   auto group_kernel = new (std::nothrow) GroupConvolutionFp32CPUKernel(
1108     op_parameter, inputs, outputs, ctx, group_conv_creator, reinterpret_cast<ConvParameter *>(op_parameter)->group_);
1109   if (group_kernel == nullptr) {
1110diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc
1111new file mode 100644
1112index 00000000..c85a1283
1113--- /dev/null
1114+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc
1115@@ -0,0 +1,251 @@
1116+/**
1117+ * Copyright 2023 Huawei Technologies Co., Ltd
1118+ *
1119+ * Licensed under the Apache License, Version 2.0 (the "License");
1120+ * you may not use this file except in compliance with the License.
1121+ * You may obtain a copy of the License at
1122+ *
1123+ * http://www.apache.org/licenses/LICENSE-2.0
1124+ *
1125+ * Unless required by applicable law or agreed to in writing, software
1126+ * distributed under the License is distributed on an "AS IS" BASIS,
1127+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1128+ * See the License for the specific language governing permissions and
1129+ * limitations under the License.
1130+ */
1131+#ifdef ENABLE_ARM64
1132+#include "src/runtime/kernel/cpu/fp32/custom_gru_fp32.h"
1133+#include <algorithm>
1134+#include "src/runtime//kernel_registry.h"
1135+#include "include/errorcode.h"
1136+#include "src/common/log_adapter.h"
1137+#include "src/runtime/pack_weight_manager.h"
1138+#include "nnacl/fp32/pack_fp32.h"
1139+#include "nnacl/custom_gru_parameter.h"
1140+#include "nnacl/fp32/custom_gru_fp32.h"
1141+
1142+using mindspore::lite::KernelRegistrar;
1143+using mindspore::lite::RET_ERROR;
1144+using mindspore::lite::RET_NOT_SUPPORT;
1145+using mindspore::lite::RET_OK;
1146+
1147+namespace mindspore::kernel {
1148+CustomGruCPUKernel::~CustomGruCPUKernel() {
1149+  if (weight_in_) {
1150+    lite::PackWeightManager::GetInstance()->Free(weight_in_);
1151+    weight_in_ = nullptr;
1152+  }
1153+  if (weight_hidden_) {
1154+    lite::PackWeightManager::GetInstance()->Free(weight_hidden_);
1155+    weight_hidden_ = nullptr;
1156+  }
1157+  if (bias_in_) {
1158+    free(bias_in_);
1159+    bias_in_ = nullptr;
1160+    bias_hidden_ = nullptr;
1161+  }
1162+  if (in_tensors_.size() > SIXTH_INPUT && in_tensors_[SIXTH_INPUT]->IsConst() && init_h_ != nullptr) {
1163+    free(init_h_);
1164+    init_h_ = nullptr;
1165+  }
1166+}
1167+
1168+int CustomGruCPUKernel::Prepare() {
1169+  CHECK_LESS_RETURN(in_tensors_.size(), C6NUM);
1170+  CHECK_LESS_RETURN(out_tensors_.size(), 1);
1171+  if (in_tensors_[FIRST_INPUT]->IsConst()) {
1172+    MS_LOG(ERROR) << "Built-in CustomGru first-input must be a variable." << name_;
1173+    return RET_NOT_SUPPORT;
1174+  }
1175+  for (size_t i = 1; i < C5NUM; ++i) {
1176+    if (!in_tensors_[i]->IsConst()) {
1177+      MS_LOG(ERROR) << "Built-in CustomGru only supports variable first and last inputs; other inputs must be const." << name_;
1178+      return RET_NOT_SUPPORT;
1179+    }
1180+  }
1181+  if (InitParameter() != RET_OK) {
1182+    MS_LOG(ERROR) << "Init Built-in CustomGru Parameter failed." << name_;
1183+    return RET_ERROR;
1184+  }
1185+  if (InitWeightAndBias() != RET_OK) {
1186+    MS_LOG(ERROR) << "Init Built-in CustomGru Weight and bias failed." << name_;
1187+    return RET_ERROR;
1188+  }
1189+  if (!InferShapeDone()) {
1190+    return RET_OK;
1191+  }
1192+  return ReSize();
1193+}
1194+
1195+int CustomGruCPUKernel::InitParameter() {
1196+  auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
1197+  thread_num_ = 1;
1198+  auto weight_in_shape = in_tensors_[1]->shape();
1199+  auto weight_hidden_shape = in_tensors_[C2NUM]->shape();
1200+  if (weight_in_shape.size() != C2NUM || weight_hidden_shape.size() != C2NUM) {
1201+    MS_LOG(ERROR) << "Built-in CustomGru's weight must be 2D." << name_;
1202+    return RET_ERROR;
1203+  }
1204+  if (weight_in_shape[0] != weight_hidden_shape[0]) {
1205+    MS_LOG(ERROR) << "Built-in CustomGru's weight-in and weight-hidden first-dim must be same." << name_;
1206+    return RET_ERROR;
1207+  }
1208+  if (weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) {
1209+    MS_LOG(ERROR) << "Built-in CustomGru's weight-hidden first-dim must be 3 * second-dim." << name_;
1210+    return RET_ERROR;
1211+  }
1212+  auto bias_in_shape = in_tensors_[C3NUM]->shape();
1213+  auto bias_hidden_shape = in_tensors_[C4NUM]->shape();
1214+  if (bias_in_shape.size() != 1) {
1215+    MS_LOG(ERROR) << "Built-in CustomGru's bias must be 1D." << name_;
1216+    return RET_ERROR;
1217+  }
1218+  if (bias_in_shape != bias_hidden_shape) {
1219+    MS_LOG(ERROR) << "Built-in CustomGru's bias-in and bias-hidden must have same shape." << name_;
1220+    return RET_ERROR;
1221+  }
1222+  if (bias_in_shape.back() != weight_in_shape.front()) {
1223+    MS_LOG(ERROR) << "Built-in CustomGru's bias-in shape doesn't match the first-dim of the weight." << name_;
1224+    return RET_ERROR;
1225+  }
1226+  if (bias_in_shape.front() % C3NUM != 0) {
1227+    MS_LOG(ERROR) << "The first-dim of CustomGru's weight must be 3 * hidden.";
1228+    return RET_ERROR;
1229+  }
1230+  param->input_size = weight_in_shape.back();
1231+  param->hidden_size = bias_in_shape.front() / C3NUM;
1232+  return RET_OK;
1233+}
1234+
1235+int CustomGruCPUKernel::InitWeightAndBias() {
1236+  auto weight_shape = in_tensors_[1]->shape();
1237+  auto hidden_size = weight_shape[0] / C3NUM;
1238+  auto col_align = UP_ROUND(hidden_size, col_tile_);
1239+  auto weight_in_pack_size = static_cast<size_t>(col_align * weight_shape[1]) * sizeof(float);
1240+  bool is_packed = false;
1241+  weight_in_ = lite::PackWeightManager::GetInstance()->GetPackData(
1242+    in_tensors_[SECOND_INPUT]->data(), static_cast<size_t>(weight_in_pack_size * C3NUM), &is_packed);
1243+  MS_CHECK_TRUE_MSG(weight_in_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-in failed.");
1244+  if (!is_packed) {
1245+    auto weight_in_src = static_cast<const float *>(in_tensors_[SECOND_INPUT]->data());
1246+    for (int i = 0; i < C3NUM; ++i) {
1247+      RowMajor2Col8Major(weight_in_src + i * hidden_size * weight_shape[1],
1248+                         static_cast<float *>(weight_in_) + i * col_align * weight_shape[1], hidden_size,
1249+                         weight_shape[1]);
1250+    }
1251+  }
1252+  auto weight_hidden_pack_size = static_cast<size_t>(col_align * hidden_size) * sizeof(float);
1253+  is_packed = false;
1254+  weight_hidden_ = lite::PackWeightManager::GetInstance()->GetPackData(
1255+    in_tensors_[THIRD_INPUT]->data(), static_cast<size_t>(weight_hidden_pack_size * C3NUM), &is_packed);
1256+  MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-hidden failed.");
1257+  if (!is_packed) {
1258+    auto weight_hidden_src = static_cast<const float *>(in_tensors_[THIRD_INPUT]->data());
1259+    for (int i = 0; i < C3NUM; ++i) {
1260+      RowMajor2Col8Major(weight_hidden_src + i * hidden_size * hidden_size,
1261+                         static_cast<float *>(weight_hidden_) + i * col_align * hidden_size, hidden_size,
1262+                         hidden_size);
1263+    }
1264+  }
1265+  auto bias_pack_size = static_cast<size_t>(col_align) * sizeof(float);
1266+  auto bias = reinterpret_cast<float *>(malloc(bias_pack_size * C6NUM));
1267+  if (bias == nullptr) {
1268+    MS_LOG(ERROR) << "malloc for packing bias failed.";
1269+    return lite::RET_NULL_PTR;
1270+  }
1271+  (void)memset(bias, 0, bias_pack_size * C6NUM);
1272+  bias_in_ = bias;
1273+  bias_hidden_ = bias + col_align * C3NUM;
1274+  auto bias_in_src = static_cast<const float *>(in_tensors_[FOURTH_INPUT]->data());
1275+  for (int i = 0; i < C3NUM; ++i) {
1276+    (void)memcpy(bias + i * col_align, bias_in_src + i * hidden_size, hidden_size * sizeof(float));
1277+  }
1278+  auto bias_hidden_src = static_cast<const float *>(in_tensors_[FIFTH_INPUT]->data());
1279+  for (int i = 0; i < C3NUM; ++i) {
1280+    (void)memcpy(bias + (C3NUM + i) * col_align, bias_hidden_src + i * hidden_size, hidden_size * sizeof(float));
1281+  }
1282+  if (in_tensors_[SIXTH_INPUT]->IsConst()) {
1283+    init_h_ = malloc(in_tensors_[SIXTH_INPUT]->Size());
1284+    MS_CHECK_TRUE_MSG(init_h_ != nullptr, lite::RET_NULL_PTR, "malloc for init-h failed.");
1285+    (void)memcpy(init_h_, in_tensors_[SIXTH_INPUT]->data(), in_tensors_[SIXTH_INPUT]->Size());
1286+  }
1287+  return RET_OK;
1288+}
1289+
1290+int CustomGruCPUKernel::ReSize() {
1291+  auto in_shape = in_tensors_.front()->shape();
1292+  if (in_shape.size() != C3NUM) {
1293+    MS_LOG(ERROR) << "Built-in CustomGru's first-input must be 3D." << name_;
1294+    return RET_ERROR;
1295+  }
1296+  auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
1297+  param->num_step = in_shape[0];
1298+  param->batch_size = in_shape[1];
1299+  if (in_shape.back() != param->input_size) {
1300+    MS_LOG(ERROR) << "Built-in CustomGru's fisrt-input don't match its weight." << name_;
1301+    return RET_ERROR;
1302+  }
1303+  return RET_OK;
1304+}
1305+
1306+void CustomGruCPUKernel::MallocRunBuffer(size_t data_type_size) {
1307+  if (run_buffer_ != nullptr) {
1308+    return;
1309+  }
1310+  auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
1311+  auto row_align = UP_ROUND(param->batch_size, row_tile_);
1312+  auto run_buffer_size =
1313+    (row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C6NUM) *
1314+    data_type_size;
1315+  if (ms_context_->allocator != nullptr) {
1316+    run_buffer_ = ms_context_->allocator->Malloc(run_buffer_size);
1317+  } else {
1318+    run_buffer_ = malloc(run_buffer_size);
1319+  }
1320+}
1321+
1322+int CustomGruCPUKernel::Run() {
1323+  auto input = reinterpret_cast<float *>(in_tensors_[FIRST_INPUT]->data());
1324+  if (input == nullptr) {
1325+    MS_LOG(ERROR) << "Built-in CustomGru's first-input is nullptr." << name_;
1326+    return lite::RET_NULL_PTR;
1327+  }
1328+  if (!in_tensors_[SIXTH_INPUT]->IsConst()) {
1329+    init_h_ = in_tensors_[SIXTH_INPUT]->data();
1330+  }
1331+  if (init_h_ == nullptr) {
1332+    MS_LOG(ERROR) << "Built-in CustomGru's sixth-input is nullptr." << name_;
1333+    return lite::RET_NULL_PTR;
1334+  }
1335+  auto output = reinterpret_cast<float *>(out_tensors_.front()->data());
1336+  if (output == nullptr) {
1337+    MS_LOG(ERROR) << "Built-in CustomGru's output is nullptr." << name_;
1338+    return lite::RET_NULL_PTR;
1339+  }
1340+  MallocRunBuffer(sizeof(float));
1341+  if (run_buffer_ == nullptr) {
1342+    MS_LOG(ERROR) << "malloc running buffer failed." << name_;
1343+    return lite::RET_NULL_PTR;
1344+  }
1345+  auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
1346+  auto row_align = UP_ROUND(param->batch_size, row_tile_);
1347+  auto run_buffer = reinterpret_cast<float *>(run_buffer_);
1348+  float *buffer[C4NUM] = {
1349+    run_buffer, run_buffer + row_align * param->input_size,
1350+    run_buffer + row_align * param->input_size + param->batch_size * param->hidden_size * C3NUM,
1351+    run_buffer + row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C3NUM};
1352+  CustomGru(output, input, static_cast<float *>(weight_in_), static_cast<float *>(weight_hidden_),
1353+            static_cast<float *>(bias_in_), static_cast<float *>(bias_hidden_), static_cast<float *>(init_h_), buffer,
1354+            param);
1355+  if (ms_context_->allocator != nullptr) {
1356+    ms_context_->allocator->Free(run_buffer_);
1357+  } else {
1358+    free(run_buffer_);
1359+  }
1360+  run_buffer_ = nullptr;
1361+  return RET_OK;
1362+}
1363+
1364+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomGru, LiteKernelCreator<CustomGruCPUKernel>)
1365+}  // namespace mindspore::kernel
1366+#endif
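
InitWeightAndBias() above packs both bias vectors into a single zero-initialized buffer of six col_align-wide gate slots, so the kernel reads one aligned row per gate. A standalone sketch of that layout (the helper name and std::vector usage are illustrative; gate order follows the Split slots used by the fusion pass):

    #include <cstring>
    #include <vector>

    // Slots 0-2 hold the zero-padded input biases, slots 3-5 the hidden biases.
    std::vector<float> PackGruBias(const float *bias_in, const float *bias_hidden,
                                   int hidden_size, int col_align) {
      std::vector<float> packed(static_cast<size_t>(col_align) * 6, 0.0f);
      for (int i = 0; i < 3; ++i) {
        std::memcpy(packed.data() + i * col_align, bias_in + i * hidden_size,
                    hidden_size * sizeof(float));
        std::memcpy(packed.data() + (3 + i) * col_align, bias_hidden + i * hidden_size,
                    hidden_size * sizeof(float));
      }
      return packed;
    }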
1367diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h
1368new file mode 100644
1369index 00000000..e70e408f
1370--- /dev/null
1371+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h
1372@@ -0,0 +1,51 @@
1373+/**
1374+ * Copyright 2023 Huawei Technologies Co., Ltd
1375+ *
1376+ * Licensed under the Apache License, Version 2.0 (the "License");
1377+ * you may not use this file except in compliance with the License.
1378+ * You may obtain a copy of the License at
1379+ *
1380+ * http://www.apache.org/licenses/LICENSE-2.0
1381+ *
1382+ * Unless required by applicable law or agreed to in writing, software
1383+ * distributed under the License is distributed on an "AS IS" BASIS,
1384+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1385+ * See the License for the specific language governing permissions and
1386+ * limitations under the License.
1387+ */
1388+
1389+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_
1390+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_
1391+#ifdef ENABLE_ARM64
1392+#include <vector>
1393+#include "src/runtime/lite_kernel.h"
1394+
1395+namespace mindspore::kernel {
1396+class CustomGruCPUKernel : public LiteKernel {
1397+ public:
1398+  CustomGruCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
1399+                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
1400+      : LiteKernel(parameter, inputs, outputs, ctx) {}
1401+  ~CustomGruCPUKernel() override;
1402+  int Prepare() override;
1403+  int ReSize() override;
1404+  int Run() override;
1405+
1406+ private:
1407+  int InitParameter();
1408+
1409+ protected:
1410+  void MallocRunBuffer(size_t data_type_size);
1411+  virtual int InitWeightAndBias();
1412+  int row_tile_{C12NUM};
1413+  int col_tile_{C8NUM};
1414+  void *weight_in_{nullptr};
1415+  void *weight_hidden_{nullptr};
1416+  void *bias_in_{nullptr};
1417+  void *bias_hidden_{nullptr};
1418+  void *init_h_{nullptr};
1419+  void *run_buffer_{nullptr};
1420+};
1421+}  // namespace mindspore::kernel
1422+#endif
1423+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_
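
For reference, the tensor order the kernel expects, as enforced by Prepare(), InitParameter() and ReSize(). The constants below are illustrative only (the kernel uses FIRST_INPUT through SIXTH_INPUT), and the init_h shape is an assumption, since that tensor is copied verbatim and never shape-checked:

    enum CustomGruInputSketch {
      kGruX = 0,             // [num_step, batch, input_size], must be variable
      kGruWeightInput = 1,   // [3 * hidden, input_size], const
      kGruWeightHidden = 2,  // [3 * hidden, hidden], const
      kGruBiasInput = 3,     // [3 * hidden], const
      kGruBiasHidden = 4,    // [3 * hidden], const
      kGruInitH = 5,         // presumably [batch, hidden]; const or variable
    };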
1424diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc
1425index ef03171a..0ff780c7 100644
1426--- a/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc
1427+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc
1428@@ -107,7 +107,7 @@ kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor
1429                   << conv_param->input_channel_;
1430     return nullptr;
1431   }
1432-  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, true, kNumberTypeInt8);
1433+  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, true, kNumberTypeInt8, ctx);
1434   if (group_conv_creator == nullptr) {
1435     MS_LOG(ERROR) << "group_conv_creator is nullptr.";
1436     return nullptr;
1437diff --git a/mindspore/lite/src/runtime/lite_session.h b/mindspore/lite/src/runtime/lite_session.h
1438index 2fdb1eb7..d5a672bb 100644
1439--- a/mindspore/lite/src/runtime/lite_session.h
1440+++ b/mindspore/lite/src/runtime/lite_session.h
1441@@ -106,6 +106,12 @@ class MS_API LiteSession {
1442                      std::vector<std::string> out_put_tensor_name = {}) {
1443     return mindspore::lite::RET_ERROR;
1444   }
1445+  virtual int ExportWeightsCollaborateWithMicro(const std::string &file_name,
1446+                                                lite::ModelType model_type = lite::MT_TRAIN,
1447+                                                lite::FormatType = lite::FT_FLATBUFFERS, bool enable_fp16 = false,
1448+                                                const std::vector<std::string> &changeable_weights_name = {}) {
1449+    return mindspore::lite::RET_ERROR;
1450+  }
1451   virtual int UpdateWeights(std::vector<lite::Tensor *> new_weights) { return mindspore::lite::RET_ERROR; }
1452   virtual std::vector<lite::Tensor *> GetFeatureMaps() const {
1453     std::vector<lite::Tensor *> features;
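
The new virtual defaults to RET_ERROR here; a train session is expected to override it. A hypothetical call site, assuming the session type is mindspore::lite::LiteSession from this header (the file and tensor names are invented for illustration):

    int ExportForMicro(mindspore::lite::LiteSession *session) {
      return session->ExportWeightsCollaborateWithMicro(
        "net_weights.bin", mindspore::lite::MT_TRAIN, mindspore::lite::FT_FLATBUFFERS,
        /*enable_fp16=*/true, {"gru.weight_hidden"});
    }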
1454diff --git a/mindspore/lite/src/train/graph_fusion.cc b/mindspore/lite/src/train/graph_fusion.cc
1455index 1af44e45..03f5675e 100644
1456--- a/mindspore/lite/src/train/graph_fusion.cc
1457+++ b/mindspore/lite/src/train/graph_fusion.cc
1458@@ -22,6 +22,7 @@
1459 #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
1460 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
1461 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
1462+#include "src/train/optimizer/fusion/gru_fusion_pass.h"
1463 #include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h"
1464 #include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h"
1465
1466@@ -41,6 +42,12 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) {
1467     MS_LOG(ERROR) << "graph is nullptr.";
1468     return RET_ERROR;
1469   }
1470+  auto gru_fusion = std::make_shared<GruFusionPass>();
1471+  MS_CHECK_TRUE_MSG(gru_fusion != nullptr, RET_NULL_PTR, "Create GruFusion object failed.");
1472+  if (gru_fusion->Run(graph) != RET_OK) {
1473+    MS_LOG(ERROR) << "Do GruFusion failed.";
1474+    return RET_ERROR;
1475+  }
1476   auto old_nodes = GetGraphNodes(*graph);
1477   Optimizer fusion_optimizer;
1478   fusion_optimizer.AddPass(new (std::nothrow) ReshapeGatherReshapeFusionPass());
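
Ordering matters in this hunk: the GRU fusion runs before the generic Optimizer passes because its pattern matches a bare MatMul feeding a separate bias Add (kMatmul0/kAdd3 and kMatmul1/kAdd5 below); once MatmulBiasaddFusionPass merges those pairs, the cell can no longer be recognized. A sketch of adding another standalone pre-pass in the same style, with a hypothetical pass name:

    auto pre_pass = std::make_shared<MyPrePass>();  // MyPrePass is illustrative
    MS_CHECK_TRUE_MSG(pre_pass != nullptr, RET_NULL_PTR, "Create MyPrePass object failed.");
    if (pre_pass->Run(graph) != RET_OK) {
      MS_LOG(ERROR) << "Do MyPrePass failed.";
      return RET_ERROR;
    }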
1479diff --git a/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc
1480new file mode 100644
1481index 00000000..435686e5
1482--- /dev/null
1483+++ b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc
1484@@ -0,0 +1,809 @@
1485+/**
1486+ * Copyright 2023 Huawei Technologies Co., Ltd
1487+ *
1488+ * Licensed under the Apache License, Version 2.0 (the "License");
1489+ * you may not use this file except in compliance with the License.
1490+ * You may obtain a copy of the License at
1491+ *
1492+ * http://www.apache.org/licenses/LICENSE-2.0
1493+ *
1494+ * Unless required by applicable law or agreed to in writing, software
1495+ * distributed under the License is distributed on an "AS IS" BASIS,
1496+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1497+ * See the License for the specific language governing permissions and
1498+ * limitations under the License.
1499+ */
1500+
1501+#include "src/train/optimizer/fusion/gru_fusion_pass.h"
1502+#include <algorithm>
1503+#include <map>
1504+#include <memory>
1505+#include <set>
1506+#include <string>
1507+#include <utility>
1508+#include <vector>
1509+#include "src/common/log_adapter.h"
1510+#include "include/errorcode.h"
1511+#include "nnacl/op_base.h"
1512+
1513+namespace mindspore {
1514+namespace lite {
1515+namespace {
1516+constexpr size_t kSplitOutSize = 3;
1517+constexpr uint32_t kAdd0 = 0;
1518+constexpr uint32_t kAdd1 = 1;
1519+constexpr uint32_t kAdd2 = 2;
1520+constexpr uint32_t kAdd3 = 3;
1521+constexpr uint32_t kAdd4 = 4;
1522+constexpr uint32_t kAdd5 = 5;
1523+constexpr uint32_t kSub = 6;
1524+constexpr uint32_t kMul0 = 7;
1525+constexpr uint32_t kMul1 = 8;
1526+constexpr uint32_t kTanh = 9;
1527+constexpr uint32_t kSigmoid0 = 10;
1528+constexpr uint32_t kSigmoid1 = 11;
1529+constexpr uint32_t kSplit0 = 12;
1530+constexpr uint32_t kSplit1 = 13;
1531+constexpr uint32_t kMatmul0 = 14;
1532+constexpr uint32_t kMatmul1 = 15;
1533+constexpr uint32_t kInputH = 16;
1534+constexpr uint32_t kInputI = 17;
1535+constexpr auto kCustomGRU = "CustomGRU";
1536+
1537+bool CheckCommon(schema::MetaGraphT *graph, uint32_t node_index, schema::PrimitiveType type, size_t in_nums,
1538+                 size_t out_nums) {
1539+  if (graph->nodes.size() <= node_index) {
1540+    return false;
1541+  }
1542+  const auto &node = graph->nodes[node_index];
1543+  if (node == nullptr || node->primitive == nullptr) {
1544+    return false;
1545+  }
1546+  const auto &value = node->primitive->value;
1547+  if (value.type != type) {
1548+    return false;
1549+  }
1550+  if (value.value == nullptr) {
1551+    return false;
1552+  }
1553+  if ((in_nums > 0 && node->inputIndex.size() != in_nums) || node->outputIndex.size() != out_nums) {
1554+    return false;
1555+  }
1556+  return std::all_of(node->inputIndex.begin(), node->inputIndex.end(),
1557+                     [&graph](uint32_t tensor_index) { return graph->allTensors.size() > tensor_index; }) &&
1558+         std::all_of(node->outputIndex.begin(), node->outputIndex.end(),
1559+                     [&graph](uint32_t tensor_index) { return graph->allTensors.size() > tensor_index; });
1560+}
1561+
1562+template <schema::PrimitiveType T, typename P>
1563+bool CheckArithmetic(schema::MetaGraphT *graph, uint32_t node_index) {
1564+  if (!CheckCommon(graph, node_index, T, kInputSize1, 1)) {
1565+    return false;
1566+  }
1567+  const auto &node = graph->nodes[node_index];
1568+  const auto &value = node->primitive->value;
1569+  const auto arith_attr = static_cast<const P *>(value.value);
1570+  if (arith_attr->activation_type != schema::ActivationType_NO_ACTIVATION) {
1571+    return false;
1572+  }
1573+  auto tensor_indexes = node->inputIndex;
1574+  (void)tensor_indexes.insert(tensor_indexes.end(), node->outputIndex.begin(), node->outputIndex.end());
1575+  std::vector<int> shape;
1576+  for (size_t i = 0; i < tensor_indexes.size(); ++i) {
1577+    if (i == 0) {
1578+      shape = graph->allTensors[tensor_indexes[i]]->dims;
1579+      continue;
1580+    }
1581+    if (graph->allTensors[tensor_indexes[i]]->dims != shape) {
1582+      return false;
1583+    }
1584+  }
1585+  return true;
1586+}
1587+
1588+template <schema::ActivationType T>
1589+bool CheckActivation(schema::MetaGraphT *graph, uint32_t node_index) {
1590+  if (!CheckCommon(graph, node_index, schema::PrimitiveType_Activation, 1, 1)) {
1591+    return false;
1592+  }
1593+  const auto &value = graph->nodes[node_index]->primitive->value;
1594+  const auto act_attr = static_cast<const schema::ActivationT *>(value.value);
1595+  if (act_attr->activation_type != T) {
1596+    return false;
1597+  }
1598+  return true;
1599+}
1600+
1601+bool CheckBiasAdd(schema::MetaGraphT *graph, uint32_t node_index) {
1602+  if (!CheckCommon(graph, node_index, schema::PrimitiveType_AddFusion, kInputSize1, 1) &&
1603+      !CheckCommon(graph, node_index, schema::PrimitiveType_BiasAdd, kInputSize1, 1)) {
1604+    return false;
1605+  }
1606+  const auto &node = graph->nodes[node_index];
1607+  const auto &value = node->primitive->value;
1608+  if (value.type == schema::PrimitiveType_AddFusion) {
1609+    const auto add_attr = static_cast<const schema::AddFusionT *>(value.value);
1610+    if (add_attr->activation_type != schema::ActivationType_NO_ACTIVATION) {
1611+      return false;
1612+    }
1613+  }
1614+  auto in_shape0 = graph->allTensors[node->inputIndex[0]]->dims;
1615+  auto in_shape1 = graph->allTensors[node->inputIndex[1]]->dims;
1616+  if (in_shape1.size() != 1 || in_shape0.empty() || in_shape0.back() != in_shape1.back()) {
1617+    return false;
1618+  }
1619+  return true;
1620+}
1621+
1622+bool CheckMatmul(schema::MetaGraphT *graph, uint32_t node_index) {
1623+  if (!CheckCommon(graph, node_index, schema::PrimitiveType_MatMulFusion, kInputSize1, 1)) {
1624+    return false;
1625+  }
1626+  const auto &node = graph->nodes[node_index];
1627+  const auto &value = node->primitive->value;
1628+  const auto matmul_attr = static_cast<const schema::MatMulFusionT *>(value.value);
1629+  if (matmul_attr->activation_type != schema::ActivationType_NO_ACTIVATION) {
1630+    return false;
1631+  }
1632+  auto out_shape = graph->allTensors[node->outputIndex.front()]->dims;
1633+  return out_shape.size() == kInputSize1;
1634+}
1635+
1636+bool CheckSplit(schema::MetaGraphT *graph, uint32_t node_index) {
1637+  if (!CheckCommon(graph, node_index, schema::PrimitiveType_Split, 1, kSplitOutSize)) {
1638+    return false;
1639+  }
1640+  const auto &node = graph->nodes[node_index];
1641+  if (node->inputIndex.size() != 1 || node->outputIndex.size() != kSplitOutSize) {
1642+    return false;
1643+  }
1644+  auto in_shape0 = graph->allTensors[node->inputIndex[0]]->dims;
1645+  auto out_shape0 = graph->allTensors[node->outputIndex[0]]->dims;
1646+  auto out_shape1 = graph->allTensors[node->outputIndex[1]]->dims;
1647+  auto out_shape2 = graph->allTensors[node->outputIndex[kInputSize1]]->dims;
1648+  if (out_shape0 != out_shape1 || out_shape0 != out_shape2) {
1649+    return false;
1650+  }
1651+  if (in_shape0.empty() || out_shape0.empty()) {
1652+    return false;
1653+  }
1654+  if (in_shape0.back() != (out_shape0.back() + out_shape1.back() + out_shape2.back())) {
1655+    return false;
1656+  }
1657+  return true;
1658+}
1659+
1660+bool CheckStack(schema::MetaGraphT *graph, uint32_t node_index) {
1661+  if (!CheckCommon(graph, node_index, schema::PrimitiveType_Stack, 0, 1)) {
1662+    return false;
1663+  }
1664+  const auto &node = graph->nodes[node_index];
1665+  const auto &value = node->primitive->value;
1666+  const auto stack_attr = static_cast<const schema::StackT *>(value.value);
1667+  auto out_shape = graph->allTensors[node->outputIndex.front()]->dims;
1668+  if (out_shape.empty()) {
1669+    return false;
1670+  }
1671+  auto axis = stack_attr->axis;
1672+  if (axis < 0) {
1673+    axis += static_cast<int64_t>(out_shape.size());
1674+  }
1675+  return axis == 0;
1676+}
1677+
1678+bool CheckSqueeze(schema::MetaGraphT *graph, uint32_t node_index) {
1679+  if (!CheckCommon(graph, node_index, schema::PrimitiveType_Squeeze, 0, 1)) {
1680+    return false;
1681+  }
1682+  const auto &node = graph->nodes[node_index];
1683+  if (node->inputIndex.size() != 1 && node->inputIndex.size() != kInputSize1) {
1684+    return false;
1685+  }
1686+  int axis = 0;
1687+  if (node->inputIndex.size() == kInputSize1) {
1688+    const auto &data = graph->allTensors[node->inputIndex[1]]->data;
1689+    if (data.size() != sizeof(int)) {
1690+      return false;
1691+    }
1692+    axis = *(reinterpret_cast<const int *>(data.data()));
1693+  } else {
1694+    const auto &value = node->primitive->value;
1695+    const auto squeeze_attr = static_cast<const schema::SqueezeT *>(value.value);
1696+    if (squeeze_attr->axis.size() != 1) {
1697+      return false;
1698+    }
1699+    axis = squeeze_attr->axis.front();
1700+  }
1701+  auto in_shape = graph->allTensors[node->inputIndex[0]]->dims;
1702+  if (in_shape.empty()) {
1703+    return false;
1704+  }
1705+  if (axis < 0) {
1706+    axis += static_cast<int>(in_shape.size());
1707+  }
1708+  return axis == 0;
1709+}
1710+
1711+std::vector<int> GetStridedSlicePoints(const schema::TensorT *tensor, int64_t mask) {
1712+  if (tensor->data.empty()) {
1713+    return {};
1714+  }
1715+  auto origin_data = reinterpret_cast<const int *>(tensor->data.data());
1716+  size_t num = tensor->data.size() / sizeof(int);
1717+  std::vector<int> data;
1718+  for (size_t i = 0; i < num; ++i) {
1719+    bool ineffective = (mask & (1 << i));
1720+    int cur_point = ineffective ? 0 : origin_data[i];
1721+    data.push_back(cur_point);
1722+  }
1723+  return data;
1724+}
1725+
1726+bool CheckStridedSlice(schema::MetaGraphT *graph, uint32_t node_index, int batch_position) {
1727+  if (!CheckCommon(graph, node_index, schema::PrimitiveType_StridedSlice, C4NUM, 1)) {
1728+    return false;
1729+  }
1730+  const auto &node = graph->nodes[node_index];
1731+  const auto &step_tensor = graph->allTensors[node->inputIndex.back()];
1732+  if (!step_tensor->data.empty()) {
1733+    const auto data = reinterpret_cast<int *>(step_tensor->data.data());
1734+    auto size = step_tensor->data.size() / sizeof(int);
1735+    if (std::any_of(data, data + size, [](int val) { return val != 1; })) {
1736+      return false;
1737+    }
1738+  }
1739+  auto in_shape = graph->allTensors[node->inputIndex.front()]->dims;
1740+  auto out_shape = graph->allTensors[node->outputIndex.back()]->dims;
1741+  if (in_shape.size() != out_shape.size() || in_shape.empty()) {
1742+    return false;
1743+  }
1744+  for (size_t i = 1; i < in_shape.size(); ++i) {
1745+    if (in_shape[i] != out_shape[i]) {
1746+      return false;
1747+    }
1748+  }
1749+  const auto &value = node->primitive->value;
1750+  const auto strided_slice_attr = static_cast<const schema::StridedSliceT *>(value.value);
1751+  if (strided_slice_attr->ellipsis_mask != 0 || strided_slice_attr->new_axis_mask != 0 ||
1752+      strided_slice_attr->shrink_axis_mask != 0) {
1753+    return false;
1754+  }
1755+  auto begin = GetStridedSlicePoints(graph->allTensors[node->inputIndex[1]].get(), strided_slice_attr->begin_mask);
1756+  if (begin.empty()) {
1757+    return false;
1758+  }
1759+  return begin.front() == batch_position;
1760+}
1761+
1762+bool CheckGruCell(schema::MetaGraphT *graph, uint32_t node_index) {
1763+  if (!CheckCommon(graph, node_index, schema::PrimitiveType_Custom, C6NUM, 1)) {
1764+    return false;
1765+  }
1766+  const auto &node = graph->nodes[node_index];
1767+  const auto &value = node->primitive->value;
1768+  const auto gru_attr = static_cast<const schema::CustomT *>(value.value);
1769+  return gru_attr->type == kCustomGRU;
1770+}
1771+
1772+std::unique_ptr<schema::CustomT> CreateCustom() {
1773+  auto ConvertToAttr = [](const std::string &key, const std::vector<uint8_t> &value) {
1774+    auto attr = std::make_unique<schema::AttributeT>();
1775+    attr->name = key;
1776+    attr->data = value;
1777+    return attr;
1778+  };
1779+  auto attrs = std::make_unique<schema::CustomT>();
1780+  MS_CHECK_TRUE_MSG(attrs != nullptr, nullptr, "Create CustomT failed.");
1781+  attrs->type = kCustomGRU;
1782+  std::vector<uint8_t> transpose_a{false};
1783+  std::vector<uint8_t> transpose_b{true};
1784+  std::vector<uint8_t> built_in{true};
1785+
1786+  attrs->attr.push_back(ConvertToAttr("transpose_a", transpose_a));
1787+  attrs->attr.push_back(ConvertToAttr("transpose_b", transpose_b));
1788+  attrs->attr.push_back(ConvertToAttr("builtin", built_in));
1789+  return attrs;
1790+}
1791+
1792+struct InNodeInfo {
1793+  int node_index;
1794+  std::vector<uint32_t> in_indexes;
1795+};
1796+
1797+struct OutNodeInfo {
1798+  int node_index;
1799+  uint32_t out_index;
1800+};
1801+
1802+struct DescendingOrder {
1803+  bool operator()(uint32_t left, uint32_t right) const { return left > right; }
1804+};
1805+}  // namespace
1806+
1807+class LinkInfoManager {
1808+ public:
1809+  explicit LinkInfoManager(schema::MetaGraphT *graph) : graph_{graph} {
1810+    auto &all_nodes = graph->nodes;
1811+    for (int node_index = 0; node_index < static_cast<int>(all_nodes.size()); ++node_index) {
1812+      auto in_indexes = all_nodes[node_index]->inputIndex;
1813+      for (uint32_t index = 0; index < static_cast<uint32_t>(in_indexes.size()); ++index) {
1814+        if (link_info_manager_.find(in_indexes[index]) == link_info_manager_.end()) {
1815+          link_info_manager_[in_indexes[index]] = std::make_pair(std::vector<InNodeInfo>{}, OutNodeInfo{-1, 0});
1816+        }
1817+        auto &in_infos = link_info_manager_[in_indexes[index]].first;
1818+        auto iter = in_infos.begin();
1819+        for (; iter != in_infos.end(); ++iter) {
1820+          if (iter->node_index == node_index) {
1821+            break;
1822+          }
1823+        }
1824+        if (iter != in_infos.end()) {
1825+          iter->in_indexes.push_back(index);
1826+        } else {
1827+          in_infos.push_back({node_index, {index}});
1828+        }
1829+      }
1830+
1831+      auto out_indexes = all_nodes[node_index]->outputIndex;
1832+      for (uint32_t index = 0; index < static_cast<uint32_t>(out_indexes.size()); ++index) {
1833+        link_info_manager_[out_indexes[index]].second = OutNodeInfo{node_index, index};
1834+      }
1835+    }
1836+  }
1837+
1838+  const auto &GetLinkInfos() const { return link_info_manager_; }
1839+
1840+  void Replace(uint32_t node_index, std::unique_ptr<CNodeT> node) { graph_->nodes[node_index].swap(node); }
1841+
1842+  void AddDeleteNodes(const std::set<uint32_t> &node_indexes) {
1843+    delete_nodes_.insert(node_indexes.begin(), node_indexes.end());
1844+  }
1845+
1846+  void UpdateMetaGraph() {
1847+    auto &main_graph = graph_->subGraph.front();
1848+    for (auto node_index : delete_nodes_) {
1849+      graph_->nodes.erase(graph_->nodes.begin() + node_index);
1850+    }
1851+    main_graph->nodeIndices.clear();
1852+    for (uint32_t index = 0; index < static_cast<uint32_t>(graph_->nodes.size()); ++index) {
1853+      main_graph->nodeIndices.push_back(index);
1854+    }
1855+    std::map<uint32_t, uint32_t> tensor_maps;
1856+    BuildTensorMap(&tensor_maps);
1857+    auto UpdateTensorIndex = [&tensor_maps](std::vector<uint32_t> *origin) {
1858+      auto origin_indexes = *origin;
1859+      origin->clear();
1860+      (void)std::transform(origin_indexes.begin(), origin_indexes.end(), std::back_inserter(*origin),
1861+                           [&tensor_maps](uint32_t origin_index) { return tensor_maps[origin_index]; });
1862+    };
1863+    UpdateTensorIndex(&graph_->inputIndex);
1864+    for (auto &node : graph_->nodes) {
1865+      UpdateTensorIndex(&node->inputIndex);
1866+      UpdateTensorIndex(&node->outputIndex);
1867+    }
1868+    UpdateTensorIndex(&graph_->outputIndex);
1869+    main_graph->inputIndices = graph_->inputIndex;
1870+    main_graph->outputIndices = graph_->outputIndex;
1871+    main_graph->tensorIndices.clear();
1872+    for (uint32_t index = 0; index < static_cast<uint32_t>(tensor_maps.size()); ++index) {
1873+      main_graph->tensorIndices.push_back(index);
1874+    }
1875+    std::vector<std::unique_ptr<TensorT>> tensors;
1876+    graph_->allTensors.swap(tensors);
1877+    graph_->allTensors.resize(tensor_maps.size());
1878+    for (auto &tensor_map : tensor_maps) {
1879+      graph_->allTensors[tensor_map.second].swap(tensors[tensor_map.first]);
1880+    }
1881+  }
1882+
1883+ private:
1884+  void BuildTensorMap(std::map<uint32_t, uint32_t> *tensor_maps) {
1885+    uint32_t new_index = 0;
1886+    auto InsertElements = [tensor_maps, &new_index](const std::vector<uint32_t> &indexes) mutable {
1887+      for (auto index : indexes) {
1888+        if (tensor_maps->find(index) != tensor_maps->end()) {
1889+          continue;
1890+        }
1891+        (*tensor_maps)[index] = new_index++;
1892+      }
1893+    };
1894+    InsertElements(graph_->inputIndex);
1895+    for (auto &node : graph_->nodes) {
1896+      InsertElements(node->inputIndex);
1897+      InsertElements(node->outputIndex);
1898+    }
1899+    InsertElements(graph_->outputIndex);
1900+  }
1901+
1902+  schema::MetaGraphT *graph_{nullptr};
1903+  std::set<uint32_t, DescendingOrder> delete_nodes_;  // descending: later nodes erased first, so earlier indices stay valid
1904+  // tensor_index, <in_node_infos, out_node_info>
1905+  std::map<uint32_t, std::pair<std::vector<InNodeInfo>, OutNodeInfo>> link_info_manager_;
1906+};
1907+
1908+class GruCellFusion {
1909+ public:
1910+  GruCellFusion() = default;
1911+  ~GruCellFusion() = default;
1912+  STATUS Run(schema::MetaGraphT *graph) {
1913+    MS_ASSERT(graph != nullptr);
1914+    MS_ASSERT(graph->subGraph.size() == 1);
1915+    link_info_manager_ = std::make_shared<LinkInfoManager>(graph);
1916+    graph_ = graph;
1917+    DefinePattern();
1918+    for (uint32_t node_index = 0; node_index < static_cast<uint32_t>(graph->nodes.size()); ++node_index) {
1919+      if (!MatchPattern(node_index)) {
1920+        continue;
1921+      }
1922+      if (CreateCustomGruCell() != RET_OK) {
1923+        MS_LOG(ERROR) << "Create Custom-Gru failed.";
1924+        return RET_ERROR;
1925+      }
1926+    }
1927+    link_info_manager_->UpdateMetaGraph();
1928+    return RET_OK;
1929+  }
1930+
1931+ private:
1932+  struct NodeInfo {
1933+    struct InTensorInfo {
1934+      bool is_const{false};
1935+      uint32_t node_index_{0};
1936+      uint32_t tensor_index_{0};
1937+    };
1938+    struct OutTensorInfo {
1939+      uint32_t node_index_{0};
1940+      uint32_t tensor_index_{0};
1941+    };
1942+    bool (*checker)(schema::MetaGraphT *graph, uint32_t node_index);
1943+    std::vector<InTensorInfo> in_infos;
1944+    std::vector<OutTensorInfo> out_infos;
1945+  };
1946+
1947+  void DefinePattern() {
1948+    int match_order = 0;
1949+    pattern_[{match_order++, kAdd0}] = {
1950+      CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>, {{false, kTanh, 0}, {false, kMul0, 0}}, {}};
1951+    pattern_[{match_order++, kTanh}] = {
1952+      CheckActivation<schema::ActivationType_TANH>, {{false, kAdd1, 0}}, {{kSub, 1}, {kAdd0, 0}}};
1953+    pattern_[{match_order++, kMul0}] = {CheckArithmetic<schema::PrimitiveType_MulFusion, schema::MulFusionT>,
1954+                                        {{false, kSigmoid0, 0}, {false, kSub, 0}},
1955+                                        {{kAdd0, 1}}};
1956+    pattern_[{match_order++, kAdd1}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>,
1957+                                        {{false, kSplit0, 2}, {false, kMul1, 0}},
1958+                                        {{kTanh, 0}}};
1959+    pattern_[{match_order++, kSub}] = {CheckArithmetic<schema::PrimitiveType_SubFusion, schema::SubFusionT>,
1960+                                       {{false, kInputH, 0}, {false, kTanh, 0}},
1961+                                       {{kMul0, 1}}};
1962+    pattern_[{match_order++, kSigmoid0}] = {
1963+      CheckActivation<schema::ActivationType_SIGMOID>, {{false, kAdd2, 0}}, {{kMul0, 0}}};
1964+    pattern_[{match_order++, kSplit0}] = {CheckSplit, {{false, kAdd3, 0}}, {{kAdd4, 0}, {kAdd2, 0}, {kAdd1, 0}}};
1965+    pattern_[{match_order++, kMul1}] = {CheckArithmetic<schema::PrimitiveType_MulFusion, schema::MulFusionT>,
1966+                                        {{false, kSigmoid1, 0}, {false, kSplit1, 2}},
1967+                                        {{kAdd1, 1}}};
1968+    pattern_[{match_order++, kAdd2}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>,
1969+                                        {{false, kSplit0, 1}, {false, kSplit1, 1}},
1970+                                        {{kSigmoid0, 0}}};
1971+    pattern_[{match_order++, kSigmoid1}] = {
1972+      CheckActivation<schema::ActivationType_SIGMOID>, {{false, kAdd4, 0}}, {{kMul1, 0}}};
1973+    pattern_[{match_order++, kAdd3}] = {CheckBiasAdd, {{false, kMatmul0, 0}, {true}}, {{kSplit0, 0}}};
1974+    pattern_[{match_order++, kSplit1}] = {CheckSplit, {{false, kAdd5, 0}}, {{kAdd4, 1}, {kAdd2, 1}, {kMul1, 1}}};
1975+    pattern_[{match_order++, kAdd4}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>,
1976+                                        {{false, kSplit0, 0}, {false, kSplit1, 0}},
1977+                                        {{kSigmoid1, 0}}};
1978+    pattern_[{match_order++, kAdd5}] = {CheckBiasAdd, {{false, kMatmul1, 0}, {true}}, {{kSplit1, 0}}};
1979+    pattern_[{match_order++, kMatmul0}] = {CheckMatmul, {{false, kInputI, 0}, {true}}, {{kAdd3, 0}}};
1980+    pattern_[{match_order++, kMatmul1}] = {CheckMatmul, {{false, kInputH, 0}, {true}}, {{kAdd5, 0}}};
1981+  }
1982+
1983+  bool FillRealPattern(uint32_t node_index, std::map<uint32_t, NodeInfo> *real_pattern) {
1984+    const auto &link_infos = link_info_manager_->GetLinkInfos();
1985+    if (real_pattern->find(node_index) != real_pattern->end()) {
1986+      return false;
1987+    }
1988+    real_pattern->insert({node_index, {nullptr}});
1989+    auto in_tensor_indexes = graph_->nodes[node_index]->inputIndex;
1990+    for (auto tensor_index : in_tensor_indexes) {
1991+      if (link_infos.find(tensor_index) == link_infos.end()) {
1992+        return false;
1993+      }
1994+      const auto &tensor_out_info = link_infos.at(tensor_index).second;
1995+      if (tensor_out_info.node_index < 0) {
1996+        real_pattern->at(node_index).in_infos.push_back({true});
1997+      } else {
1998+        real_pattern->at(node_index)
1999+          .in_infos.push_back({false, static_cast<uint32_t>(tensor_out_info.node_index), tensor_out_info.out_index});
2000+      }
2001+    }
2002+    auto out_tensor_indexes = graph_->nodes[node_index]->outputIndex;
2003+    for (auto tensor_index : out_tensor_indexes) {
2004+      if (link_infos.find(tensor_index) == link_infos.end()) {
2005+        return false;
2006+      }
2007+      const auto &in_tensor_out_info = link_infos.at(tensor_index).first;
2008+      for (const auto &in_node_info : in_tensor_out_info) {
2009+        for (auto index : in_node_info.in_indexes) {
2010+          real_pattern->at(node_index).out_infos.push_back({static_cast<uint32_t>(in_node_info.node_index), index});
2011+        }
2012+      }
2013+    }
2014+    return true;
2015+  }
2016+
2017+  bool CheckPattern(const std::map<uint32_t, NodeInfo> &real_pattern,
2018+                    const std::pair<int, uint32_t> &pattern_node_index) {
2019+    const auto &real_in_infos = real_pattern.at(real_node_map_.at(pattern_node_index.second)).in_infos;
2020+    const auto &virtual_in_infos = pattern_.at(pattern_node_index).in_infos;
2021+    if (real_in_infos.size() != virtual_in_infos.size()) {
2022+      return false;
2023+    }
2024+    for (size_t i = 0; i < virtual_in_infos.size(); ++i) {
2025+      if (virtual_in_infos[i].is_const) {
2026+        if (!real_in_infos[i].is_const) {
2027+          return false;
2028+        }
2029+        continue;
2030+      }
2031+      if (virtual_in_infos[i].tensor_index_ != real_in_infos[i].tensor_index_) {
2032+        return false;
2033+      }
2034+      if (real_node_map_.find(virtual_in_infos[i].node_index_) == real_node_map_.end()) {
2035+        real_node_map_.insert({virtual_in_infos[i].node_index_, real_in_infos[i].node_index_});
2036+      } else if (real_node_map_.at(virtual_in_infos[i].node_index_) != real_in_infos[i].node_index_) {
2037+        return false;
2038+      }
2039+    }
2040+    const auto &real_out_infos = real_pattern.at(real_node_map_.at(pattern_node_index.second)).out_infos;
2041+    const auto &virtual_out_infos = pattern_.at(pattern_node_index).out_infos;
2042+    if (virtual_out_infos.empty()) {
2043+      return true;
2044+    }
2045+    if (real_out_infos.size() != virtual_out_infos.size()) {
2046+      return false;
2047+    }
2048+    for (size_t i = 0; i < virtual_out_infos.size(); ++i) {
2049+      if (virtual_out_infos[i].tensor_index_ != real_out_infos[i].tensor_index_) {
2050+        return false;
2051+      }
2052+      if (real_node_map_.find(virtual_out_infos[i].node_index_) == real_node_map_.end()) {
2053+        real_node_map_.insert({virtual_out_infos[i].node_index_, real_out_infos[i].node_index_});
2054+      } else if (real_node_map_.at(virtual_out_infos[i].node_index_) != real_out_infos[i].node_index_) {
2055+        return false;
2056+      }
2057+    }
2058+    return true;
2059+  }
2060+
2061+  bool CheckClosure(const std::map<uint32_t, uint32_t> &node_map) {
2062+    std::set<uint32_t> real_nodes;
2063+    (void)std::for_each(node_map.begin(), node_map.end(),
2064+                        [&real_nodes](std::pair<uint32_t, uint32_t> pair) { real_nodes.insert(pair.second); });
2065+    if (real_nodes.size() != node_map.size()) {
2066+      return false;
2067+    }
2068+    const auto &link_infos = link_info_manager_->GetLinkInfos();
2069+    for (uint32_t start = kAdd1; start <= kMatmul1; ++start) {
2070+      if (node_map.find(start) == node_map.end()) {
2071+        return false;
2072+      }
2073+      const auto &node = graph_->nodes[node_map.at(start)];
2074+      auto out_tensor_indexes = node->outputIndex;
2075+      for (auto out_index : out_tensor_indexes) {
2076+        if (link_infos.find(out_index) == link_infos.end()) {
2077+          return false;
2078+        }
2079+        for (const auto &in_node_info : link_infos.at(out_index).first) {
2080+          if (real_nodes.find(in_node_info.node_index) == real_nodes.end()) {
2081+            return false;
2082+          }
2083+        }
2084+      }
2085+    }
2086+    return true;
2087+  }
2088+
2089+  bool MatchPattern(uint32_t add_index) {
2090+    real_node_map_.clear();
2091+    real_node_map_[kAdd0] = add_index;
2092+    std::map<uint32_t, NodeInfo> real_pattern;
2093+    for (const auto &pair : pattern_) {
2094+      if (real_node_map_.find(pair.first.second) == real_node_map_.end()) {
2095+        return false;
2096+      }
2097+      auto node_index = real_node_map_[pair.first.second];
2098+      if (!pair.second.checker(graph_, node_index)) {
2099+        return false;
2100+      }
2101+      if (!FillRealPattern(node_index, &real_pattern)) {
2102+        return false;
2103+      }
2104+      if (!CheckPattern(real_pattern, pair.first)) {
2105+        return false;
2106+      }
2107+    }
2108+    auto weight_hidden_index = graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[1];
2109+    auto weight_hidden_shape = graph_->allTensors[weight_hidden_index]->dims;
2110+    if (weight_hidden_shape.size() != C2NUM || weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) {
2111+      return false;
2112+    }
2113+    return CheckClosure(real_node_map_);
2114+  }
2115+
2116+  STATUS CreateCustomGruCell() {
2117+    std::vector<uint32_t> inputs;
2118+    inputs.push_back(graph_->nodes[real_node_map_[kMatmul0]]->inputIndex[0]);  // x
2119+    inputs.push_back(graph_->nodes[real_node_map_[kMatmul0]]->inputIndex[1]);  // weight_input
2120+    inputs.push_back(graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[1]);  // weight_hidden
2121+    inputs.push_back(graph_->nodes[real_node_map_[kAdd3]]->inputIndex[1]);     // bias_input
2122+    inputs.push_back(graph_->nodes[real_node_map_[kAdd5]]->inputIndex[1]);     // bias_hidden
2123+    inputs.push_back(graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[0]);  // init_h
2124+    auto outputs = graph_->nodes[real_node_map_[kAdd0]]->outputIndex;
2125+    auto attrs = CreateCustom();
2126+    MS_CHECK_TRUE_RET(attrs != nullptr, RET_NULL_PTR);
2127+    auto prim_t = std::make_unique<schema::PrimitiveT>();
2128+    MS_CHECK_TRUE_MSG(prim_t != nullptr, RET_ERROR, "Create PrimitiveT failed.");
2129+    prim_t->value.type = schema::PrimitiveType_Custom;
2130+    prim_t->value.value = attrs.release();
2131+    auto custom_gru = std::make_unique<schema::CNodeT>();
2132+    MS_CHECK_TRUE_MSG(custom_gru != nullptr, RET_ERROR, "Create Custom-Gru failed.");
2133+    custom_gru->name = graph_->nodes[real_node_map_[kAdd0]]->name;
2134+    custom_gru->inputIndex = inputs;
2135+    custom_gru->outputIndex = outputs;
2136+    custom_gru->primitive = std::move(prim_t);
2137+    link_info_manager_->Replace(real_node_map_[kAdd0], std::move(custom_gru));
2138+    std::set<uint32_t> delete_nodes;
2139+    for (uint32_t i = kAdd1; i <= kMatmul1; ++i) {
2140+      delete_nodes.insert(real_node_map_[i]);
2141+    }
2142+    link_info_manager_->AddDeleteNodes(delete_nodes);
2143+    return RET_OK;
2144+  }
2145+
2146+  std::map<std::pair<int, uint32_t>, NodeInfo> pattern_;
2147+  std::map<uint32_t, uint32_t> real_node_map_;
2148+  schema::MetaGraphT *graph_{nullptr};
2149+  std::shared_ptr<LinkInfoManager> link_info_manager_{nullptr};
2150+};
2151+
2152+STATUS GruFusionPass::Run(schema::MetaGraphT *graph) {
2153+#ifndef ENABLE_ARM64
2154+  return RET_OK;
2155+#endif
2156+  if (graph == nullptr) {
2157+    MS_LOG(ERROR) << "graph is a nullptr.";
2158+    return RET_NULL_PTR;
2159+  }
2160+  if (graph->subGraph.size() != 1) {
2161+    return RET_OK;
2162+  }
2163+  if (FuseToGruCell(graph) != RET_OK) {
2164+    return RET_ERROR;
2165+  }
2166+  return FuseGruCell(graph);
2167+}
2168+
2169+STATUS GruFusionPass::FuseToGruCell(schema::MetaGraphT *graph) {
2170+  GruCellFusion gru_cell_fusion{};
2171+  if (gru_cell_fusion.Run(graph) != RET_OK) {
2172+    MS_LOG(ERROR) << "Fuse GruCell failed.";
2173+    return RET_ERROR;
2174+  }
2175+  return RET_OK;
2176+}
2177+
2178+STATUS GruFusionPass::FuseGruCell(schema::MetaGraphT *graph) {
2179+  link_info_manager_ = std::make_shared<LinkInfoManager>(graph);
2180+  for (uint32_t i = 0; i < static_cast<uint32_t>(graph->nodes.size()); ++i) {
2181+    if (!CheckStack(graph, i)) {
2182+      continue;
2183+    }
2184+    std::vector<uint32_t> strided_slices;
2185+    std::vector<uint32_t> squeezes;
2186+    std::vector<uint32_t> gru_cells;
2187+    if (!MatchPattern(graph, i, &strided_slices, &squeezes, &gru_cells)) {
2188+      continue;
2189+    }
2190+    if (CreateGru(graph, i, strided_slices, squeezes, gru_cells) != RET_OK) {
2191+      MS_LOG(ERROR) << "Fuse GruCell failed.";
2192+      return RET_ERROR;
2193+    }
2194+  }
2195+  link_info_manager_->UpdateMetaGraph();
2196+  link_info_manager_ = nullptr;
2197+  return RET_OK;
2198+}
2199+
2200+bool GruFusionPass::MatchPattern(schema::MetaGraphT *graph, uint32_t stack_index, std::vector<uint32_t> *strided_slices,
2201+                                std::vector<uint32_t> *squeezes, std::vector<uint32_t> *gru_cells) {
2202+  auto &link_infos = link_info_manager_->GetLinkInfos();
2203+  auto &stack_node = graph->nodes[stack_index];
2204+  int batch_point = 0;
2205+  auto CommonCheck = [&link_infos](uint32_t tensor_index) {
2206+    if (link_infos.find(tensor_index) == link_infos.end()) {
2207+      return std::make_pair(false, 0);
2208+    }
2209+    const auto &in_node_info = link_infos.at(tensor_index).first;
2210+    if (in_node_info.size() != 1 || in_node_info.front().in_indexes.size() != 1) {
2211+      return std::make_pair(false, 0);
2212+    }
2213+    auto node_index = link_infos.at(tensor_index).second.node_index;
2214+    if (node_index < 0) {
2215+      return std::make_pair(false, 0);
2216+    }
2217+    return std::make_pair(true, node_index);
2218+  };
2219+  for (auto tensor_index : stack_node->inputIndex) {
2220+    auto check_info = CommonCheck(tensor_index);
2221+    if (!check_info.first || !CheckGruCell(graph, check_info.second)) {
2222+      return false;
2223+    }
2224+    gru_cells->push_back(check_info.second);
2225+    auto &gru_cell_node = graph->nodes[check_info.second];
2226+    check_info = CommonCheck(gru_cell_node->inputIndex.front());
2227+    if (!check_info.first || !CheckSqueeze(graph, check_info.second)) {
2228+      return false;
2229+    }
2230+    squeezes->push_back(check_info.second);
2231+    auto &squeeze_node = graph->nodes[check_info.second];
2232+    check_info = CommonCheck(squeeze_node->inputIndex.front());
2233+    if (!check_info.first || !CheckStridedSlice(graph, check_info.second, batch_point)) {
2234+      return false;
2235+    }
2236+    strided_slices->push_back(check_info.second);
2237+    ++batch_point;
2238+  }
2239+  if (strided_slices->empty()) {
2240+    return false;
2241+  }
2242+  uint32_t input_index = graph->nodes[strided_slices->front()]->inputIndex.front();
2243+  if (std::any_of(strided_slices->begin(), strided_slices->end(), [input_index, graph](uint32_t strided_slice) {
2244+        return graph->nodes[strided_slice]->inputIndex.front() != input_index;
2245+      })) {
2246+    return false;
2247+  }
2248+  auto in_shape = graph->allTensors[input_index]->dims;
2249+  if (in_shape.empty() || in_shape.front() != batch_point) {
2250+    return false;
2251+  }
2252+  return CheckGruCellConnection(graph, *gru_cells);
2253+}
2254+
2255+bool GruFusionPass::CheckGruCellConnection(schema::MetaGraphT *graph, const std::vector<uint32_t> &gru_cells) {
2256+  auto &first_node = graph->nodes[gru_cells.front()];
2257+  if (first_node->inputIndex.size() != C6NUM) {
2258+    return false;
2259+  }
2260+  auto init_h = first_node->outputIndex.front();
2261+  for (size_t i = 1; i < gru_cells.size(); ++i) {
2262+    auto &node = graph->nodes[gru_cells[i]];
2263+    if (node->inputIndex.size() != first_node->inputIndex.size()) {
2264+      return false;
2265+    }
2266+    for (size_t j = 1; j < C5NUM; ++j) {
2267+      if (node->inputIndex[j] != first_node->inputIndex[j]) {
2268+        return false;
2269+      }
2270+    }
2271+    if (node->inputIndex[C5NUM] != init_h) {
2272+      return false;
2273+    }
2274+    init_h = node->outputIndex.front();
2275+  }
2276+  return true;
2277+}
2278+
2279+STATUS GruFusionPass::CreateGru(schema::MetaGraphT *graph, uint32_t stack_index,
2280+                                const std::vector<uint32_t> &strided_slices, const std::vector<uint32_t> &squeezes,
2281+                                const std::vector<uint32_t> &gru_cells) {
2282+  auto &gru_cell_node = graph->nodes[gru_cells.front()];
2283+  gru_cell_node->inputIndex[0] = graph->nodes[strided_slices.front()]->inputIndex[0];
2284+  gru_cell_node->outputIndex[0] = graph->nodes[stack_index]->outputIndex[0];
2285+  std::set<uint32_t> delete_node{stack_index};
2286+  (void)delete_node.insert(strided_slices.begin(), strided_slices.end());
2287+  (void)delete_node.insert(squeezes.begin(), squeezes.end());
2288+  (void)delete_node.insert(gru_cells.begin() + 1, gru_cells.end());
2289+  link_info_manager_->AddDeleteNodes(delete_node);
2290+  return RET_OK;
2291+}
2292+}  // namespace lite
2293+}  // namespace mindspore
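
To make the matched subgraph concrete: per element, the sixteen-node pattern computes a standard GRU cell, with gate order reset/update/candidate across the Split outputs. A scalar reference written only to document the pattern (the real computation lives in nnacl's CustomGru kernels):

    #include <cmath>

    float SigmoidRef(float x) { return 1.0f / (1.0f + std::exp(-x)); }

    // xg[i]: slot i of Split(Wx * x + bx); hg[i]: slot i of Split(Wh * h + bh).
    float GruCellRef(const float xg[3], const float hg[3], float h) {
      float r = SigmoidRef(xg[0] + hg[0]);     // kAdd4 -> kSigmoid1 (reset gate)
      float z = SigmoidRef(xg[1] + hg[1]);     // kAdd2 -> kSigmoid0 (update gate)
      float n = std::tanh(xg[2] + r * hg[2]);  // kMul1 -> kAdd1 -> kTanh (candidate)
      return n + z * (h - n);                  // kSub -> kMul0 -> kAdd0 (new state)
    }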
2294diff --git a/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h
2295new file mode 100644
2296index 00000000..5e2b705d
2297--- /dev/null
2298+++ b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h
2299@@ -0,0 +1,45 @@
2300+/**
2301+ * Copyright 2023 Huawei Technologies Co., Ltd
2302+ *
2303+ * Licensed under the Apache License, Version 2.0 (the "License");
2304+ * you may not use this file except in compliance with the License.
2305+ * You may obtain a copy of the License at
2306+ *
2307+ * http://www.apache.org/licenses/LICENSE-2.0
2308+ *
2309+ * Unless required by applicable law or agreed to in writing, software
2310+ * distributed under the License is distributed on an "AS IS" BASIS,
2311+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2312+ * See the License for the specific language governing permissions and
2313+ * limitations under the License.
2314+ */
2315+
2316+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_GRU_FUSION_PASS_H_
2317+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_GRU_FUSION_PASS_H_
2318+
2319+#include <memory>
2320+#include <vector>
2321+#include "tools/converter/optimizer.h"
2322+
2323+namespace mindspore {
2324+namespace lite {
2325+class LinkInfoManager;
2326+class GruFusionPass : public GraphPass {
2327+ public:
2328+  GruFusionPass() = default;
2329+  ~GruFusionPass() override = default;
2330+  STATUS Run(schema::MetaGraphT *graph) override;
2331+
2332+ private:
2333+  STATUS FuseToGruCell(schema::MetaGraphT *graph);
2334+  STATUS FuseGruCell(schema::MetaGraphT *graph);
2335+  bool MatchPattern(schema::MetaGraphT *graph, uint32_t stack_index, std::vector<uint32_t> *strided_slices,
2336+                   std::vector<uint32_t> *squeezes, std::vector<uint32_t> *gru_cells);
2337+  bool CheckGruCellConnection(schema::MetaGraphT *graph, const std::vector<uint32_t> &gru_cells);
2338+  STATUS CreateGru(schema::MetaGraphT *graph, uint32_t stack_index, const std::vector<uint32_t> &strided_slices,
2339+                   const std::vector<uint32_t> &squeezes, const std::vector<uint32_t> &gru_cells);
2340+  std::shared_ptr<LinkInfoManager> link_info_manager_{nullptr};
2341+};
2342+}  // namespace lite
2343+}  // namespace mindspore
2344+#endif  // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_GRU_FUSION_PASS_H_
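
A minimal driver for the pass declared here, in the same way graph_fusion.cc invokes it (a sketch; constructing the MetaGraphT is elided, and the pass is a no-op off ARM64 or on multi-subgraph models):

    #include "src/train/optimizer/fusion/gru_fusion_pass.h"

    int FuseGruSketch(mindspore::schema::MetaGraphT *graph) {
      mindspore::lite::GruFusionPass pass;
      return pass.Run(graph);  // RET_OK on success or when there is nothing to fuse
    }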
2345diff --git a/mindspore/lite/src/train/static_allocator.h b/mindspore/lite/src/train/static_allocator.h
2346index d78e13ba..bd80651d 100644
2347--- a/mindspore/lite/src/train/static_allocator.h
2348+++ b/mindspore/lite/src/train/static_allocator.h
2349@@ -40,12 +40,12 @@ class StaticAllocator : public Allocator {
2350     if (ptr == nullptr) return STATIC_ALLOCATION;
2351     char *ptrc = reinterpret_cast<char *>(ptr);
2352     char *bufc = reinterpret_cast<char *>(start_buf_);
2353-    return ((ptrc < bufc) || (ptrc - bufc >= static_cast<ptrdiff_t>(size_)) ? 1 : 0);
2354+    return ((ptrc < bufc) || (ptrc >= bufc + size_)) ? 1 : 0;
2355   }
2356
2357  private:
2358-  void *start_buf_;
2359-  size_t size_;
2360+  void *start_buf_{nullptr};
2361+  size_t size_{0};
2362   size_t total_size_ = 0;
2363 };
2364 };      // namespace mindspore
2365diff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc
2366index 7e504c4e..008de7c5 100644
2367--- a/mindspore/lite/src/train/train_export.cc
2368+++ b/mindspore/lite/src/train/train_export.cc
2369@@ -30,6 +30,10 @@
2370 #include "src/train/graph_fusion.h"
2371 #include "src/train/graph_dropout.h"
2372 #include "src/runtime/weight_decoder.h"
2373+#include "src/runtime/kernel/cpu/fp16/fp16_op_handler.h"
2374+#ifndef ENABLE_ARM
2375+#include "base/float16.h"
2376+#endif
2377
2378 namespace mindspore {
2379 namespace lite {
2380@@ -645,6 +649,40 @@ int TrainExport::SaveToBuffer() {
2381   return RET_OK;
2382 }
2383
2384+int TrainExport::SaveWeightsToFile(bool enable_fp16, const std::vector<std::string> &changeable_weights_name) {
2385+  const auto &all_tensors = meta_graph_->allTensors;
2386+  std::ofstream weights(file_name_, std::ios::out | std::ios::trunc | std::ios::binary);
2387+  for (auto &tensor : all_tensors) {
2388+    MS_CHECK_TRUE_MSG(tensor != nullptr, RET_NULL_PTR, "Found a null tensor in the graph.");
2389+    if (tensor->data.empty()) {
2390+      continue;
2391+    }
2392+    if (std::find(changeable_weights_name.begin(), changeable_weights_name.end(), tensor->name) !=
2393+        changeable_weights_name.end()) {
2394+      auto shape = tensor->dims;
2395+      weights.write(reinterpret_cast<const char *>(shape.data()), shape.size() * sizeof(uint32_t));
2396+    }
2397+    if (!enable_fp16 || tensor->dataType != kNumberTypeFloat32) {
2398+      weights.write(reinterpret_cast<const char *>(tensor->data.data()), tensor->data.size());
2399+    } else {
2400+      std::vector<uint16_t> data_fp16(tensor->data.size() / sizeof(float));
2401+#ifndef ENABLE_ARM
2402+      auto fp32_data = reinterpret_cast<const float *>(tensor->data.data());
2403+      auto fp16_data = reinterpret_cast<float16 *>(data_fp16.data());
2404+      CHECK_NULL_RETURN(fp32_data);
2405+      CHECK_NULL_RETURN(fp16_data);
2406+      for (size_t j = 0; j < data_fp16.size(); ++j) {
2407+        fp16_data[j] = float16(fp32_data[j]);
2408+      }
2409+#else
2410+      Float32ToFloat16_fp16_handler(tensor->data.data(), data_fp16.data(), data_fp16.size(), true);
2411+#endif
2412+      weights.write(reinterpret_cast<const char *>(data_fp16.data()), data_fp16.size() * sizeof(uint16_t));
2413+    }
2414+  }
2415+  weights.close();
2416+  return RET_OK;
2417+}
2418
2419 bool TrainExport::IsInputTensor(const schema::TensorT &t) {
2420   int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
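For context, a minimal reader-side sketch of the blob SaveWeightsToFile emits; the helper below is hypothetical and not part of this patch. Tensors are concatenated in allTensors order, a changeable weight is preceded by its int32 dims, and fp32 payloads may have been narrowed to fp16:

  #include <cstdint>
  #include <cstring>
  #include <vector>

  struct WeightView {
    std::vector<int32_t> dims;
    const uint8_t *data;
    size_t bytes;
  };

  // dim_num and elem_size must be known out-of-band; the generated micro Init
  // code derives them at code-generation time.
  WeightView ReadChangeableWeight(const uint8_t *&cursor, size_t dim_num, size_t elem_size) {
    WeightView v;
    v.dims.resize(dim_num);
    std::memcpy(v.dims.data(), cursor, dim_num * sizeof(int32_t));  // shape comes first
    cursor += dim_num * sizeof(int32_t);
    size_t elems = 1;
    for (int32_t d : v.dims) elems *= static_cast<size_t>(d);
    v.data = cursor;              // raw payload immediately follows the dims
    v.bytes = elems * elem_size;  // elem_size is 2 if enable_fp16 narrowed fp32 data
    cursor += v.bytes;
    return v;
  }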
2421diff --git a/mindspore/lite/src/train/train_export.h b/mindspore/lite/src/train/train_export.h
2422index 8e802021..d6f81187 100644
2423--- a/mindspore/lite/src/train/train_export.h
2424+++ b/mindspore/lite/src/train/train_export.h
2425@@ -52,6 +52,7 @@ class TrainExport {
2426   int ExportInit(const std::string model_name, std::string version);
2427   int SaveToFile();
2428   int SaveToBuffer();
2429+  int SaveWeightsToFile(bool enable_fp16 = false, const std::vector<std::string> &changeable_weights_name = {});
2430   void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; }
2431   int LoadModel(void *buf, size_t buf_size);
2432   int AddTransformNode();
2433diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
2434index b40ff8c2..2f9aa99b 100644
2435--- a/mindspore/lite/src/train/train_session.cc
2436+++ b/mindspore/lite/src/train/train_session.cc
2437@@ -1233,10 +1233,45 @@ int TrainSession::Export(Buffer *model_buffer, ModelType model_type, Quantizatio
2438   return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name);
2439 }
2440
2441+int TrainSession::ExportWeightsCollaborateWithMicro(const std::string &file_name, lite::ModelType model_type,
2442+                                                    FormatType format, bool enable_fp16,
2443+                                                    const std::vector<std::string> &changeable_weights_name) {
2444+  MS_CHECK_FALSE_MSG(file_name.empty(), RET_ERROR, "File name cannot be empty");
2445+  struct stat path_type;
2446+  if (stat(file_name.c_str(), &path_type) == RET_OK) {
2447+    if (path_type.st_mode & S_IFDIR) {
2448+      MS_LOG(ERROR) << "Destination must be a file path, but a directory was given";
2449+      return RET_ERROR;
2450+    }
2451+  }
2452+  MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "Format must be `FT_FLATBUFFERS`");
2453+  MS_CHECK_FALSE_MSG(model_type != mindspore::lite::MT_INFERENCE, RET_ERROR,
2454+                     "Currently, only an inference model's weights can be exported.");
2455+  int status = Eval();
2456+  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Eval failed");
2457+
2458+  TrainExport texport(file_name);
2459+  status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_);
2460+  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export");
2461+
2462+  status = texport.ExportNet(inference_kernels_, tensors_, eval_output_tensor_names_, model_.get(), QT_DEFAULT);
2463+  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network.");
2464+  status = texport.TrainModelDrop();
2465+  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed.");
2466+  status = texport.TrainModelFusion();
2467+  TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed.");
2468+  status = texport.SaveWeightsToFile(enable_fp16, changeable_weights_name);
2469+  if (status != RET_OK) {
2470+    MS_LOG(ERROR) << "Failed to save to " << file_name;
2471+    return status;
2472+  }
2473+  return RET_OK;
2474+}
2475+
2476 std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const {
2477   std::vector<lite::Tensor *> features;
2478   for (auto cur_tensor : this->tensors_) {
2479-    if (cur_tensor->category() ==lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) {
2480+    if (cur_tensor->category() == lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) {
2481       features.push_back(cur_tensor);
2482     }
2483   }
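A sketch of a call site for the new export entry point (session setup and error handling abbreviated; the tensor name is a placeholder):

  std::vector<std::string> changeable = {"embedding_table"};  // hypothetical name
  int ret = train_session->ExportWeightsCollaborateWithMicro(
      "net.bin",                      // destination must be a file path, not a directory
      mindspore::lite::MT_INFERENCE,  // only inference-model weights are supported
      mindspore::lite::FT_FLATBUFFERS,
      /*enable_fp16=*/true, changeable);
  if (ret != mindspore::lite::RET_OK) {
    // handle export failure
  }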
2484diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h
2485index 5acff82a..edcab32d 100644
2486--- a/mindspore/lite/src/train/train_session.h
2487+++ b/mindspore/lite/src/train/train_session.h
2488@@ -106,6 +106,9 @@ class TrainSession : virtual public lite::LiteSession {
2489              std::vector<std::string> out_put_tensor_name = {}) override;
2490   int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType,
2491              std::vector<std::string> out_put_tensor_name = {}) override;
2492+  int ExportWeightsCollaborateWithMicro(const std::string &file_name, lite::ModelType model_type, FormatType,
2493+                                        bool enable_fp16,
2494+                                        const std::vector<std::string> &changeable_weights_name) override;
2495
2496   std::vector<lite::Tensor *> GetFeatureMaps() const override;
2497
2498diff --git a/mindspore/lite/test/config_level0/micro/micro_arm64.cfg b/mindspore/lite/test/config_level0/micro/micro_arm64.cfg
2499index 0375ebf7..765549ab 100644
2500--- a/mindspore/lite/test/config_level0/micro/micro_arm64.cfg
2501+++ b/mindspore/lite/test/config_level0/micro/micro_arm64.cfg
2502@@ -25,3 +25,10 @@ support_parallel=false
2503
2504 # enable debug
2505 debug_mode=false
2506+
2507+# false indicates that only the required weights are saved. When collaborating with lite-train, this parameter must be true.
2508+keep_original_weight=false
2509+
2510+# the names of the weight tensors whose shape is changeable; currently only embedding tables support shape changes.
2511+# this parameter is used to collaborate with lite-train. If set, `keep_original_weight` must be true.
2512+#changeable_weights_name=name0,name1
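Taken together, a converter configuration exercising the new options might look like this (a sketch; the section name follows the converter's micro-parameter parsing, and the tensor name is a placeholder):

  [micro_param]
  enable_micro=true
  target=ARM64
  codegen_mode=Inference
  keep_original_weight=true
  changeable_weights_name=embedding_table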
2513diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
2514index ea263e64..4c6bd237 100644
2515--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
2516+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
2517@@ -221,11 +221,14 @@ int ConfigFileParser::ParseAclOptionCfgString(const std::map<std::string, std::m
2518 int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
2519   if (maps.find(kMicroParam) != maps.end()) {
2520     const auto &map = maps.at(kMicroParam);
2521-    std::map<std::string, std::string &> parse_map{{"target", micro_param_string_.target},
2522-                                                   {"codegen_mode", micro_param_string_.codegen_mode},
2523-                                                   {"debug_mode", micro_param_string_.debug_mode},
2524-                                                   {"support_parallel", micro_param_string_.support_parallel},
2525-                                                   {"enable_micro", micro_param_string_.enable_micro}};
2526+    std::map<std::string, std::string &> parse_map{
2527+      {"target", micro_param_string_.target},
2528+      {"codegen_mode", micro_param_string_.codegen_mode},
2529+      {"debug_mode", micro_param_string_.debug_mode},
2530+      {"support_parallel", micro_param_string_.support_parallel},
2531+      {"enable_micro", micro_param_string_.enable_micro},
2532+      {"keep_original_weight", micro_param_string_.keep_original_weight},
2533+      {"changeable_weights_name", micro_param_string_.changeable_weights_name}};
2534     return SetMapData(map, parse_map, kMicroParam);
2535   }
2536   return RET_OK;
2537diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
2538index 0ada406e..8854e5f7 100644
2539--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h
2540+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
2541@@ -86,6 +86,8 @@ struct MicroParamString {
2542   std::string support_parallel;
2543   std::string debug_mode;
2544   std::string enable_micro;
2545+  std::string keep_original_weight;
2546+  std::string changeable_weights_name;
2547 };
2548
2549 struct ThirdPartyModelString {
2550diff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc
2551index 310b2398..559bee8b 100644
2552--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc
2553+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc
2554@@ -61,6 +61,23 @@ STATUS MicroParamParser::ParseEnableMicro(const std::string &enable_micro, micro
2555   return RET_OK;
2556 }
2557
2558+STATUS MicroParamParser::ParseKeepOriginalWeight(const std::string &save_all_weights, micro::MicroParam *micro_param) {
2559+  MS_LOG(DEBUG) << "Micro keep_original_weight: " << save_all_weights;
2560+  micro_param->keep_original_weight = false;  // default
2561+  bool is_keep_original_weight;
2562+  if (ConvertBool(save_all_weights, &is_keep_original_weight)) {
2563+    micro_param->keep_original_weight = is_keep_original_weight;
2564+  }
2565+  return RET_OK;
2566+}
2567+
2568+STATUS MicroParamParser::ParseChangeableWeightsName(const std::string &changeable_weights_name,
2569+                                                    micro::MicroParam *micro_param) {
2570+  MS_LOG(DEBUG) << "Micro record changeable weights name: " << changeable_weights_name;
2571+  micro_param->changeable_weights_name = changeable_weights_name;
2572+  return RET_OK;
2573+}
2574+
2575 STATUS MicroParamParser::ParseMicroParam(const MicroParamString &micro_param_string, micro::MicroParam *micro_param) {
2576   CHECK_NULL_RETURN(micro_param);
2577   if (!micro_param_string.target.empty()) {
2578@@ -93,6 +110,22 @@ STATUS MicroParamParser::ParseMicroParam(const MicroParamString &micro_param_str
2579       return RET_INPUT_PARAM_INVALID;
2580     }
2581   }
2582+  if (!micro_param_string.keep_original_weight.empty()) {
2583+    if (ParseKeepOriginalWeight(micro_param_string.keep_original_weight, micro_param) != RET_OK) {
2584+      MS_LOG(ERROR) << "Parse keep_original_weight val: " << micro_param_string.keep_original_weight;
2585+      return RET_INPUT_PARAM_INVALID;
2586+    }
2587+  }
2588+  if (!micro_param_string.changeable_weights_name.empty()) {
2589+    if (!micro_param->keep_original_weight) {
2590+      MS_LOG(ERROR) << "When changeable_weights_name is set, keep_original_weight must be true.";
2591+      return RET_INPUT_PARAM_INVALID;
2592+    }
2593+    if (ParseChangeableWeightsName(micro_param_string.changeable_weights_name, micro_param) != RET_OK) {
2594+      MS_LOG(ERROR) << "Parse changeable_weights_name val: " << micro_param_string.changeable_weights_name;
2595+      return RET_INPUT_PARAM_INVALID;
2596+    }
2597+  }
2598   return RET_OK;
2599 }
2600 }  // namespace lite
2601diff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.h b/mindspore/lite/tools/converter/config_parser/micro_param_parser.h
2602index 860182af..93a30b39 100644
2603--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.h
2604+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.h
2605@@ -33,6 +33,8 @@ class MicroParamParser {
2606   STATUS ParseCodeGenMode(const std::string &codegen_mode, micro::MicroParam *micro_param);
2607   STATUS ParseSupportParallel(const std::string &support_parallel, micro::MicroParam *micro_param);
2608   STATUS ParseDebugMode(const std::string &debug_mode, micro::MicroParam *micro_param);
2609+  STATUS ParseKeepOriginalWeight(const std::string &save_all_weights, micro::MicroParam *micro_param);
2610+  STATUS ParseChangeableWeightsName(const std::string &changeable_weights_name, micro::MicroParam *micro_param);
2611 };
2612 }  // namespace lite
2613 }  // namespace mindspore
2614diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
2615index 944ed29c..6177d379 100644
2616--- a/mindspore/lite/tools/converter/converter.cc
2617+++ b/mindspore/lite/tools/converter/converter.cc
2618@@ -186,6 +186,9 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara>
2619     }
2620   }
2621
2622+  if (param->fmk_type == FmkType::kFmkTypeMsLite) {
2623+    return nullptr;
2624+  }
2625   auto graph = BuildFuncGraph(param);
2626   if (graph == nullptr) {
2627     MS_LOG(ERROR) << "Parser/Import model return nullptr";
2628@@ -557,7 +560,7 @@ int CheckFmkType(const std::shared_ptr<ConverterPara> &param) {
2629   }
2630   const std::set kValidFmkTypes = {FmkType::kFmkTypeTf,        FmkType::kFmkTypeCaffe,  FmkType::kFmkTypeOnnx,
2631                                    FmkType::kFmkTypeMs,        FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch,
2632-                                   FmkType::kFmkTypeThirdParty};
2633+                                   FmkType::kFmkTypeThirdParty, FmkType::kFmkTypeMsLite};
2634   if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) {
2635     MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|THIRDPARTY"
2636                   << ", but got " << param->fmk_type;
2637@@ -780,6 +783,14 @@ int RunConverter(const std::shared_ptr<ConverterPara> &param, void **model_data,
2638   NotSupportOp::GetInstance()->PrintOps();
2639   status = ReturnCode::GetSingleReturnCode()->status_code();
2640   if (meta_graph == nullptr) {
2641+    if (param->fmk_type == FmkType::kFmkTypeMsLite && param->microParam.enable_micro) {
2642+      status = micro::Coder::MicroSourceCodeGeneration(param->model_file, param->output_file, param->microParam,
2643+                                                       param->weight_fp16);
2644+      if (status != RET_OK) {
2645+        CONVERTER_LOG_ERROR("MICRO CODEGEN FAILED:" << status << " " << GetErrorInfo(status));
2646+      }
2647+      return status;
2648+    }
2649     CONVERTER_LOG_ERROR("CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status));
2650     status = RET_ERROR;
2651     return status;
2652@@ -797,9 +808,8 @@ int RunConverter(const std::shared_ptr<ConverterPara> &param, void **model_data,
2653   }
2654
2655   if (param->microParam.enable_micro) {
2656-    status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, param->output_file, param->microParam.codegen_mode,
2657-                                                     param->microParam.target, param->microParam.support_parallel,
2658-                                                     param->microParam.debug_mode, param->weight_fp16);
2659+    status =
2660+      micro::Coder::MicroSourceCodeGeneration(*meta_graph, param->output_file, param->microParam, param->weight_fp16);
2661     if (status != RET_OK) {
2662       delete meta_graph;
2663       CONVERTER_LOG_ERROR("MICRO CODEGEN FAILED:" << status << " " << GetErrorInfo(status));
2664diff --git a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
2665index 595b59ed..e30994cc 100644
2666--- a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
2667+++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
2668@@ -29,7 +29,7 @@ using mindspore::lite::RET_INPUT_PARAM_INVALID;
2669 using mindspore::lite::RET_OK;
2670
2671 Flags::Flags() {
2672-  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
2673+  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX | MSLITE", "");
2674   AddFlag(&Flags::modelFile, "modelFile",
2675           "Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
2676   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
2677@@ -120,7 +120,7 @@ int Flags::InitFmk() {
2678   // value check not here, it is in converter c++ API's CheckValueParam method.
2679   std::map<std::string, FmkType> StrToEnumFmkTypeMap = {
2680     {"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs},       {"TFLITE", kFmkTypeTflite},        {"ONNX", kFmkTypeOnnx},
2681-    {"TF", kFmkTypeTf},       {"PYTORCH", kFmkTypePytorch}, {"THIRDPARTY", kFmkTypeThirdParty}};
2682+    {"TF", kFmkTypeTf},       {"PYTORCH", kFmkTypePytorch}, {"THIRDPARTY", kFmkTypeThirdParty}, {"MSLITE", kFmkTypeMsLite}};
2683   if (StrToEnumFmkTypeMap.find(this->fmkIn) != StrToEnumFmkTypeMap.end()) {
2684     this->fmk = StrToEnumFmkTypeMap.at(this->fmkIn);
2685   } else {
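With MSLITE registered as a framework type, micro code generation can start from an already-converted .ms model. An illustrative invocation (file names are placeholders; the micro options come from the config file sketched earlier, passed via the converter's existing configFile flag):

  ./converter_lite --fmk=MSLITE --modelFile=model.ms --outputFile=micro_model --configFile=micro.cfg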
2686diff --git a/mindspore/lite/tools/converter/micro/cmake/file_list.cmake b/mindspore/lite/tools/converter/micro/cmake/file_list.cmake
2687index 9ae54538..589ee81a 100644
2688--- a/mindspore/lite/tools/converter/micro/cmake/file_list.cmake
2689+++ b/mindspore/lite/tools/converter/micro/cmake/file_list.cmake
2690@@ -53,6 +53,7 @@ set(CODER_OPCODERS_SRC
2691         ${MICRO_DIR}/coder/opcoders/base/reduce_base_coder.cc
2692         ${MICRO_DIR}/coder/opcoders/base/resize_base_coder.cc
2693         ${MICRO_DIR}/coder/opcoders/base/reshape_base_coder.cc
2694+        ${MICRO_DIR}/coder/opcoders/base/stack_base_coder.cc
2695         ${MICRO_DIR}/coder/opcoders/base/softmax_base_coder.cc
2696         ${MICRO_DIR}/coder/opcoders/base/detection_post_process_base_coder.cc
2697         ${MICRO_DIR}/coder/opcoders/base/strided_slice_base_coder.cc
2698@@ -71,6 +72,7 @@ set(CODER_OPCODERS_SRC
2699         ${MICRO_DIR}/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc
2700         ${MICRO_DIR}/coder/opcoders/nnacl/fp16/avg_pooling_fp16_coder.cc
2701         ${MICRO_DIR}/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc
2702+        ${MICRO_DIR}/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
2703         ${MICRO_DIR}/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc
2704         ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_coder.cc
2705         ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc
2706@@ -90,6 +92,7 @@ set(CODER_OPCODERS_SRC
2707         ${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc
2708         ${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc
2709         ${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
2710+        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
2711         ${MICRO_DIR}/coder/opcoders/nnacl/fp32/full_connection_fp32_coder.cc
2712         ${MICRO_DIR}/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
2713         ${MICRO_DIR}/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc
2714diff --git a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc
2715index 9c5839b4..be314ed6 100644
2716--- a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc
2717+++ b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc
2718@@ -79,18 +79,26 @@ void MemoryAllocator::Free() {
2719       iter++;
2720     }
2721   }
2722+  for (auto &item : auxiliary_weights_) {
2723+    delete item.second.first;
2724+  }
2725   malloc_weights_addr_.clear();
2726   for (auto &item : allocated_) {
2727     free(item);
2728     item = nullptr;
2729   }
2730   allocated_.clear();
2731+  origin_weights_.clear();
2732+  auxiliary_weights_.clear();
2733 }
2734
2735 std::map<Tensor *, std::string> MemoryAllocator::tensors_map() const {
2736   std::map<Tensor *, std::string> res;
2737   res.insert(tensors_addr_.begin(), tensors_addr_.end());
2738   res.insert(malloc_weights_addr_.begin(), malloc_weights_addr_.end());
2739+  (void)std::for_each(
2740+    auxiliary_weights_.begin(), auxiliary_weights_.end(),
2741+    [&res](const std::pair<Tensor *, std::pair<Tensor *, std::string>> &item) { res.insert(item.second); });
2742   return res;
2743 }
2744
2745@@ -121,17 +129,25 @@ void MemoryAllocator::AssignGraphInputs(const std::vector<Tensor *> &inputs) {
2746   }
2747 }
2748
2749-void MemoryAllocator::RecordOriginWeightsAddr(const std::vector<std::unique_ptr<OperatorCoder>> &nodes) {
2750-  for (const auto &node : nodes) {
2751-    std::vector<Tensor *> inputs = node->input_tensors();
2752-    for (const auto &tensor : inputs) {
2753-      if (tensor->category() == lite::Category::CONST_TENSOR || tensor->category() == lite::Category::CONST_SCALAR) {
2754-        std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_);
2755-        origin_weights_addr_.insert(std::make_pair(tensor, runtime_addr));
2756-        weight_index_++;
2757+int MemoryAllocator::RecordOriginWeightsAddr(const std::vector<Tensor *> &all_tensors,
2758+                                             const std::string &changeable_weights_name) {
2759+  std::vector<std::string> weights_name;
2760+  if (!changeable_weights_name.empty()) {
2761+    weights_name = StrSplit(changeable_weights_name, ",");
2762+  }
2763+  for (const auto &tensor : all_tensors) {
2764+    if (tensor->category() == lite::Category::CONST_TENSOR || tensor->category() == lite::Category::CONST_SCALAR) {
2765+      if (std::find(weights_name.begin(), weights_name.end(), tensor->tensor_name()) != weights_name.end()) {
2766+        if (RecordChangeableWeights(tensor) != RET_OK) {
2767+          return RET_ERROR;
2768+        }
2769       }
2770+      std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_++);
2771+      origin_weights_addr_.insert(std::make_pair(tensor, runtime_addr));
2772+      origin_weights_.push_back(tensor);
2773     }
2774   }
2775+  return RET_OK;
2776 }
2777
2778 int MemoryAllocator::AssignTensors(const std::vector<std::unique_ptr<OperatorCoder>> &nodes) {
2779@@ -150,9 +166,13 @@ int MemoryAllocator::AssignTensors(const std::vector<std::unique_ptr<OperatorCod
2780 }
2781
2782 int MemoryAllocator::Assign(const std::vector<Tensor *> &inputs,
2783-                            const std::vector<std::unique_ptr<OperatorCoder>> &nodes) {
2784+                            const std::vector<std::unique_ptr<OperatorCoder>> &nodes,
2785+                            const std::vector<Tensor *> &all_tensors, const std::string &changeable_weights_name) {
2786   AssignGraphInputs(inputs);
2787-  RecordOriginWeightsAddr(nodes);
2788+  if (RecordOriginWeightsAddr(all_tensors, changeable_weights_name) != RET_OK) {
2789+    MS_LOG(ERROR) << "RecordOriginWeightsAddr failed.";
2790+    return RET_ERROR;
2791+  }
2792   return AssignTensors(nodes);
2793 }
2794
2795@@ -163,4 +183,46 @@ void MemoryAllocator::MarkSharedWeight(const Tensor *src, void *pack_weight) {
2796 void *MemoryAllocator::GetSharedWeightAddr(const Tensor *src) {
2797   return shared_pack_weights_.find(src) == shared_pack_weights_.end() ? nullptr : shared_pack_weights_[src];
2798 }
2799+
2800+int MemoryAllocator::RecordChangeableWeights(Tensor *src) {
2801+  MS_ASSERT(src != nullptr);
2802+  auto variable_str = GetAuxiliaryWeight(src);
2803+  if (!variable_str.empty()) {
2804+    return RET_OK;
2805+  }
2806+  if (!src->IsConst()) {
2807+    MS_LOG(ERROR) << "Currently, the tensor must be a constant.";
2808+    return RET_NOT_SUPPORT;
2809+  }
2810+  auto shape = src->shape();
2811+  auto shape_tensor = new (std::nothrow)
2812+    Tensor(kNumberTypeInt32, {static_cast<int>(shape.size())}, src->format(), Category::CONST_TENSOR);
2813+  if (shape_tensor == nullptr) {
2814+    MS_LOG(ERROR) << "Failed to create the auxiliary shape tensor.";
2815+    return RET_NULL_PTR;
2816+  }
2817+  auto data = shape_tensor->MutableData();
2818+  if (data == nullptr) {
2819+    MS_LOG(ERROR) << "Failed to allocate data for the auxiliary shape tensor.";
2820+    delete shape_tensor;
2821+    return RET_NULL_PTR;
2822+  }
2823+  if (memcpy_s(data, shape_tensor->Size(), shape.data(), shape.size() * sizeof(int)) != EOK) {
2824+    MS_LOG(ERROR) << "Failed to copy shape data into the auxiliary shape tensor.";
2825+    delete shape_tensor;
2826+    return RET_ERROR;
2827+  }
2828+  shape_tensor->set_tensor_name(src->tensor_name() + "_shape");
2829+  std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_++);
2830+  auxiliary_weights_[src] = std::make_pair(shape_tensor, runtime_addr);
2831+  return RET_OK;
2832+}
2833+
2834+std::string MemoryAllocator::GetAuxiliaryWeight(Tensor *src) {
2835+  auto iter = auxiliary_weights_.find(src);
2836+  if (iter != auxiliary_weights_.end()) {
2837+    return iter->second.second;
2838+  }
2839+  return {};
2840+}
2841 }  // namespace mindspore::lite::micro
2842diff --git a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h
2843index 8a1331fb..f5bacf6f 100644
2844--- a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h
2845+++ b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h
2846@@ -56,7 +56,8 @@ class MemoryAllocator {
2847   /*
2848    * assign model's input, original weights and all tensors memory addr
2849    */
2850-  int Assign(const std::vector<Tensor *> &inputs, const std::vector<std::unique_ptr<OperatorCoder>> &nodes);
2851+  int Assign(const std::vector<Tensor *> &inputs, const std::vector<std::unique_ptr<OperatorCoder>> &nodes,
2852+             const std::vector<Tensor *> &all_tensors, const std::string &changeable_weights_name = {});
2853
2854   // allocator holds the space malloced by opcoders, will free before session coder destroy
2855   void Free();
2856@@ -141,14 +142,18 @@ class MemoryAllocator {
2857   void *MallocWeightTensor(TypeId type_id, size_t size, MallocType type, const std::string &tensor_name = "");
2858   void MarkSharedWeight(const Tensor *src, void *pack_weight);
2859   void *GetSharedWeightAddr(const Tensor *src);
2860+  std::string GetAuxiliaryWeight(Tensor *src);
2861+  std::vector<Tensor *> origin_weights() const { return origin_weights_; }
2862+  std::map<Tensor *, std::pair<Tensor *, std::string>> auxiliary_weights() const { return auxiliary_weights_; }
2863
2864  private:
2865   int AssignTensors(const std::vector<std::unique_ptr<OperatorCoder>> &nodes);
2866   void AssignGraphInputs(const std::vector<Tensor *> &inputs);
2867   void AssignWorkspaces(void *addr, size_t size);
2868-  void RecordOriginWeightsAddr(const std::vector<std::unique_ptr<OperatorCoder>> &nodes);
2869+  int RecordOriginWeightsAddr(const std::vector<Tensor *> &all_tensors,
2870+                              const std::string &changeable_weights_name = {});
2871   void RecordTensorsAddr(const std::map<Tensor *, size_t> &offsets);
2872-
2873+  int RecordChangeableWeights(Tensor *src);
2874   MemoryAllocator() = default;
2875   ~MemoryAllocator() = default;
2876
2877@@ -160,11 +165,13 @@ class MemoryAllocator {
2878   bool is_next_{false};
2879   size_t offset_{0};
2880   std::vector<void *> allocated_;
2881+  std::vector<Tensor *> origin_weights_;
2882   std::map<std::string, Tensor *> saved_weights_addr_;
2883   std::map<Tensor *, std::string> origin_weights_addr_;
2884   std::map<Tensor *, std::string> malloc_weights_addr_;
2885   std::map<Tensor *, std::string> tensors_addr_;
2886   std::map<const Tensor *, void *> shared_pack_weights_;
2887+  std::map<Tensor *, std::pair<Tensor *, std::string>> auxiliary_weights_;
2888 };
2889 }  // namespace mindspore::lite::micro
2890 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_ALLOCATOR_ALLOCATOR_H_
2891diff --git a/mindspore/lite/tools/converter/micro/coder/coder.cc b/mindspore/lite/tools/converter/micro/coder/coder.cc
2892index cca4687e..a94ac91b 100644
2893--- a/mindspore/lite/tools/converter/micro/coder/coder.cc
2894+++ b/mindspore/lite/tools/converter/micro/coder/coder.cc
2895@@ -93,25 +93,48 @@ bool Coder::InitPath(const std::string &output_path) {
2896 }
2897
2898 int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path,
2899-                                     const std::string &codegen_mode, const std::string &device, bool support_parallel,
2900-                                     bool debug_mode, bool enableFp16) {
2901+                                     const MicroParam &param, bool enable_fp16) {
2902   flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
2903   auto offset = schema::MetaGraph::Pack(builder, &graph);
2904   builder.Finish(offset);
2905   schema::FinishMetaGraphBuffer(builder, offset);
2906   size_t size = builder.GetSize();
2907+  if (ExecuteMicroGeneration(builder.GetBufferPointer(), size, output_path, param, enable_fp16) != RET_OK) {
2908+    MS_LOG(ERROR) << "Micro source code generation failed.";
2909+    return RET_ERROR;
2910+  }
2911+  return RET_OK;
2912+}
2913+
2914+int Coder::MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path,
2915+                                     const MicroParam &param, bool enable_fp16) {
2916+  size_t buffer_size;
2917+  auto model_buf = lite::ReadFile(model_file.c_str(), &buffer_size);
2918+  if (model_buf == nullptr) {
2919+    MS_LOG(ERROR) << "Failed to read the model file.";
2920+    return RET_NULL_PTR;
2921+  }
2922+  auto ret = ExecuteMicroGeneration(model_buf, buffer_size, output_path, param, enable_fp16);
2923+  if (ret != RET_OK) {
2924+    MS_LOG(ERROR) << "Micro source code generation failed.";
2925+  }
2926+  delete[] model_buf;
2927+  return ret;
2928+}
2929+int Coder::ExecuteMicroGeneration(const void *model_buf, size_t size, const std::string &output_path,
2930+                                  const MicroParam &param, bool enable_fp16) {
2931   micro::Coder code_gen;
2932   if (!code_gen.InitPath(output_path)) {
2933     MS_LOG(ERROR) << "Init path failed";
2934     return RET_ERROR;
2935   }
2936   // codegeneration for micro
2937-  STATUS status = code_gen.Init(codegen_mode, device, support_parallel, debug_mode);
2938+  STATUS status = code_gen.Init(param);
2939   if (status != RET_OK) {
2940     MS_LOG(ERROR) << "Codegen init Error";
2941     return RET_ERROR;
2942   }
2943-  status = code_gen.Run(builder.GetBufferPointer(), size, enableFp16);
2944+  status = code_gen.Run(model_buf, size, enable_fp16);
2945   if (status != RET_OK) {
2946     MS_LOG(ERROR) << "Codegen Run Error";
2947     return RET_ERROR;
2948@@ -120,29 +143,30 @@ int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std:
2949   return RET_OK;
2950 }
2951
2952-int Coder::Init(const std::string code_mode, const std::string target, bool support_parallel, bool debug_mode) const {
2953+int Coder::Init(const MicroParam &param) const {
2954   static const std::map<std::string, Target> kTargetMap = {
2955     {"x86", kX86}, {"Cortex-M", kCortex_M}, {"ARM32", kARM32}, {"ARM64", kARM64}, {"All", kAllTargets}};
2956   static const std::map<std::string, CodeMode> kCodeModeMap = {{"Inference", Inference}, {"Train", Train}};
2957   Configurator *config = Configurator::GetInstance();
2958
2959-  auto target_item = kTargetMap.find(target);
2960-  MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + target);
2961+  auto target_item = kTargetMap.find(param.target);
2962+  MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + param.target);
2963   config->set_target(target_item->second);
2964
2965-  auto code_item = kCodeModeMap.find(code_mode);
2966-  MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + code_mode);
2967+  auto code_item = kCodeModeMap.find(param.codegen_mode);
2968+  MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + param.codegen_mode);
2969   config->set_code_mode(code_item->second);
2970
2971-  if (support_parallel && config->target() == kCortex_M) {
2972+  if (param.support_parallel && config->target() == kCortex_M) {
2973     MS_LOG(ERROR) << "Cortex-M cannot support parallel.";
2974     return RET_ERROR;
2975   }
2976-  config->set_support_parallel(support_parallel);
2977-  config->set_debug_mode(debug_mode);
2978+  config->set_support_parallel(param.support_parallel);
2979+  config->set_debug_mode(param.debug_mode);
2980
2981   config->set_proj_dir(model_name_);
2982-
2983+  config->set_keep_original_weight(param.keep_original_weight);
2984+  config->set_changeable_weights_name(param.changeable_weights_name);
2985   const std::string slash = std::string(kSlash);
2986   if (!save_path_.empty() && !DirExists(save_path_)) {
2987     MS_LOG(ERROR) << "code_gen code path " << save_path_ << " is not valid";
2988@@ -170,6 +194,7 @@ int Coder::Init(const std::string code_mode, const std::string target, bool supp
2989   print_parameter("codePath", config->code_path());
2990   print_parameter("codeMode", config->code_mode());
2991   print_parameter("debugMode", config->debug_mode());
2992+  print_parameter("keepOriginalWeight", config->keep_original_weight());
2993   return RET_OK;
2994 }
2995 }  // namespace mindspore::lite::micro
2996diff --git a/mindspore/lite/tools/converter/micro/coder/coder.h b/mindspore/lite/tools/converter/micro/coder/coder.h
2997index 96531e6f..0753156a 100644
2998--- a/mindspore/lite/tools/converter/micro/coder/coder.h
2999+++ b/mindspore/lite/tools/converter/micro/coder/coder.h
3000@@ -31,11 +31,14 @@ class Coder final {
3001
3002   ~Coder() = default;
3003   static int MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path,
3004-                                       const std::string &codegen_mode, const std::string &device,
3005-                                       bool support_parallel, bool debug_mode, bool enableFp16);
3006+                                       const MicroParam &param, bool enable_fp16);
3007+  static int MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path,
3008+                                       const MicroParam &param, bool enable_fp16);
3009
3010  private:
3011-  int Init(const std::string code_mode, const std::string target, bool support_parallel, bool debug_mode_) const;
3012+  static int ExecuteMicroGeneration(const void *model_buf, size_t size, const std::string &output_path,
3013+                                    const MicroParam &param, bool enable_fp16);
3014+  int Init(const MicroParam &param) const;
3015   int Run(const void *model_buff, size_t size, bool enableFp16);
3016   bool InitPath(const std::string &output_path);
3017   std::shared_ptr<CoderSession> session_{nullptr};
3018diff --git a/mindspore/lite/tools/converter/micro/coder/config.h b/mindspore/lite/tools/converter/micro/coder/config.h
3019index 84285932..42e0f50e 100644
3020--- a/mindspore/lite/tools/converter/micro/coder/config.h
3021+++ b/mindspore/lite/tools/converter/micro/coder/config.h
3022@@ -26,9 +26,11 @@ enum CodeMode { Inference = 0, Train = 1, Code_Unknown = 99 };
3023 struct MicroParam {
3024   std::string codegen_mode = "Inference";
3025   std::string target;
3026+  std::string changeable_weights_name;
3027   bool enable_micro{false};
3028   bool support_parallel{false};
3029   bool debug_mode{false};
3030+  bool keep_original_weight{false};
3031 };
3032
3033 class Configurator {
3034@@ -56,6 +58,12 @@ class Configurator {
3035   void set_proj_dir(std::string dir) { proj_dir_ = dir; }
3036   std::string proj_dir() const { return proj_dir_; }
3037
3038+  void set_keep_original_weight(bool keep_weight) { keep_original_weight_ = keep_weight; }
3039+  bool keep_original_weight() const { return keep_original_weight_; }
3040+
3041+  void set_changeable_weights_name(const std::string &weights_name) { changeable_weights_name_ = weights_name; }
3042+  const std::string &changeable_weights_name() const { return changeable_weights_name_; }
3043+
3044  private:
3045   Configurator() = default;
3046   ~Configurator() = default;
3047@@ -64,7 +72,9 @@ class Configurator {
3048   CodeMode code_mode_{Code_Unknown};
3049   bool support_parallel_{false};
3050   bool debug_mode_{false};
3051+  bool keep_original_weight_{false};
3052   std::string proj_dir_;
3053+  std::string changeable_weights_name_;
3054 };
3055 }  // namespace mindspore::lite::micro
3056 #endif  // MICRO_CODER_CONFIG_H
3057diff --git a/mindspore/lite/tools/converter/micro/coder/context.h b/mindspore/lite/tools/converter/micro/coder/context.h
3058index 724475fe..cec385bb 100644
3059--- a/mindspore/lite/tools/converter/micro/coder/context.h
3060+++ b/mindspore/lite/tools/converter/micro/coder/context.h
3061@@ -69,6 +69,25 @@ class CoderContext {
3062   void set_saved_weights(const std::map<std::string, Tensor *> &saved_weights) { saved_weights_ = saved_weights; }
3063   std::map<std::string, Tensor *> saved_weights() const { return saved_weights_; }
3064
3065+  void set_origin_weights(const std::vector<Tensor *> &origin_weights) { origin_weights_ = origin_weights; }
3066+  const std::vector<Tensor *> &origin_weights() const { return origin_weights_; }
3067+
3068+  void set_auxiliary_weights(const std::map<Tensor *, std::pair<Tensor *, std::string>> &auxiliary_weights) {
3069+    auxiliary_weights_ = auxiliary_weights;
3070+  }
3071+  const std::map<Tensor *, std::pair<Tensor *, std::string>> &auxiliary_weights() const { return auxiliary_weights_; }
3072+
3073+  bool JudgeIsValid(bool keep_origin_weight) {
3074+    if (!keep_origin_weight) {
3075+      return true;
3076+    }
3077+    return std::all_of(saved_weights_.begin(), saved_weights_.end(),
3078+                       [this](const std::pair<std::string, Tensor *> &item) {
3079+                         return std::find(this->origin_weights_.begin(), this->origin_weights_.end(), item.second) !=
3080+                                this->origin_weights_.end();
3081+                       });
3082+  }
3083+
3084   void set_total_buffer_size(size_t size) { total_buffer_size_ = size; }
3085   size_t total_buffer_size() const { return total_buffer_size_; }
3086
3087@@ -107,7 +126,11 @@ class CoderContext {
3088  private:
3089   std::vector<Tensor *> graph_inputs_;
3090   std::vector<Tensor *> graph_outputs_;
3091-  // primitive const tensors, parsed from model, without packed.
3092+  // primitive const tensors parsed from the model, unpacked; some of them may be unused.
3093+  std::vector<Tensor *> origin_weights_;
3094+  // auxiliary content (e.g. shape tensors) for the origin weights, if needed.
3095+  std::map<Tensor *, std::pair<Tensor *, std::string>> auxiliary_weights_;
3096+  // primitive const tensors parsed from the model, packed; all of these tensors are actually used.
3097   std::map<std::string, Tensor *> saved_weights_;
3098   // all tensors, include parsed from model and packed tensors.
3099   std::map<Tensor *, std::string> tensors_map_;
3100diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc
3101index 058f0ba0..d30e0133 100644
3102--- a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc
3103+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc
3104@@ -141,6 +141,7 @@ void CodeMSModelDestory(std::ofstream &ofs, const Configurator *config) {
3105   }
3106   ofs << "    MSTensorHandleArrayDestroy(micro_model->inputs);\n"
3107          "    MSTensorHandleArrayDestroy(micro_model->outputs);\n"
3108+         "    FreeResource();\n"
3109          "    free(*model);\n"
3110          "    *model = NULL;\n"
3111          "  }\n";
3112@@ -331,10 +332,12 @@ void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderCo
3113     }
3114     ofs << "  void **allocated[] = {\n";
3115     size_t num = 0;
3116+    auto &w_auxiliary = ctx->auxiliary_weights();
3117     for (const auto &item : ctx->tensors_map()) {
3118       Tensor *tensor = item.first;
3119       std::string name = item.second;
3120-      if (tensor->data() != nullptr && !(CheckConstantTensor(tensor))) {
3121+      if (tensor->data() != nullptr &&
3122+          (!(CheckConstantTensor(tensor)) || w_auxiliary.find(tensor) != w_auxiliary.end())) {
3123         ofs << "    (void**)&" << name << ",\n";
3124         num++;
3125       }
3126diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc
3127index 4d102391..d0824ecb 100644
3128--- a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc
3129+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc
3130@@ -22,6 +22,32 @@
3131 #include "coder/opcoders/parallel.h"
3132
3133 namespace mindspore::lite::micro {
3134+namespace {
3135+struct camp {
3136+  bool operator()(const std::string &a, const std::string &b) const { return a.size() != b.size() ? a.size() < b.size() : a < b; }
3137+};
3138+
3139+std::string GenerateArrayContent(const std::vector<size_t> &contents, const std::string &prefix) {
3140+  std::string lines;
3141+  std::string line = prefix;
3142+  for (auto content : contents) {
3143+    std::string append = std::to_string(content) + ", ";
3144+    if (line == prefix) {
3145+      line += append;
3146+      continue;
3147+    }
3148+    if (line.size() + append.size() > 120) {
3149+      lines += line + "\n";
3150+      line = prefix + append;
3151+    } else {
3152+      line += append;
3153+    }
3154+  }
3155+  lines += line + "\n";
3156+  return lines;
3157+}
3158+}  // namespace
3159+
3160 void CodeWeightFileHeader(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) {
3161   ofs << g_hwLicense;
3162   // include all operator header
3163@@ -71,10 +97,11 @@ void CodeModelParamsData(std::ofstream &ofs, const std::map<std::string, Tensor
3164 void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::unique_ptr<CoderContext> &ctx,
3165                            const Configurator &config) {
3166   // reverse key and value of tensors_map
3167-  std::map<std::string, Tensor *> address_map;
3168+  std::map<std::string, Tensor *, camp> address_map;
3169   for (const auto &item : ctx->tensors_map()) {
3170     address_map.insert(std::make_pair(item.second, item.first));
3171   }
3172+  auto &w_auxiliary = ctx->auxiliary_weights();
3173   for (auto &item : address_map) {
3174     std::string name = item.first;
3175     Tensor *tensor = item.second;
3176@@ -83,13 +110,22 @@ void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::
3177     }
3178     if (CheckConstantTensor(tensor)) {
3179       if (config.target() != kCortex_M) {
3180-        hofs << "extern " << GetTensorDataType(tensor->data_type()) << name << "[];\n";
3181-        cofs << GetTensorDataType(tensor->data_type()) << name << "[" << tensor->ElementsNum() << "];\n";
3182+        if (w_auxiliary.find(tensor) == w_auxiliary.end()) {
3183+          hofs << "extern " << GetTensorDataType(tensor->data_type()) << name << "[];  // " << tensor->tensor_name()
3184+               << std::endl;
3185+          cofs << GetTensorDataType(tensor->data_type()) << name << "[" << tensor->ElementsNum() << "];\n";
3186+        } else {
3187+          hofs << "extern " << GetTensorDataType(tensor->data_type()) << "*" << name << ";  // "
3188+               << tensor->tensor_name() << std::endl;
3189+          cofs << GetTensorDataType(tensor->data_type()) << "*" << name << " = NULL;\n";
3190+        }
3191       } else {
3192-        hofs << "extern const " << GetTensorDataType(tensor->data_type()) << name << "[];\n";
3193+        hofs << "extern const " << GetTensorDataType(tensor->data_type()) << name << "[];  // " << tensor->tensor_name()
3194+             << std::endl;
3195       }
3196     } else if (tensor->category() == lite::Category::VAR) {
3197-      hofs << "extern " << GetTensorDataType(tensor->data_type()) << "*" << name << ";\n";
3198+      hofs << "extern " << GetTensorDataType(tensor->data_type()) << "*" << name << ";  // " << tensor->tensor_name()
3199+           << std::endl;
3200       cofs << GetTensorDataType(tensor->data_type()) << "*" << name << " = NULL;\n";
3201     }
3202   }
3203@@ -104,6 +140,186 @@ void CodeInitWeightState(std::ofstream &ofs) {
3204       << "int Init(void *weight_buffer, int weight_size);\n\n";
3205 }
3206
3207+void CodeWeightContentInit(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx,
3208+                           const std::map<Tensor *, int> &tensors_index) {
3209+  auto &w_auxiliary = ctx->auxiliary_weights();
3210+  std::map<std::string, Tensor *, camp> real_need_tensors;
3211+  auto record_saved_tensors = ctx->saved_weights();
3212+  for (auto &item : record_saved_tensors) {
3213+    real_need_tensors.insert(std::make_pair(item.first, item.second));
3214+  }
3215+  std::string non_copy;
3216+  std::string copy_static;
3217+  std::string copy_dynamic;
3218+  int copy_static_num = 0;
3219+  int copy_dynamic_num = 0;
3220+  auto tensors_map = ctx->tensors_map();
3221+  for (const auto &item : real_need_tensors) {
3222+    if (!CheckConstantTensor(item.second) || item.second->data() == nullptr) {
3223+      continue;
3224+    }
3225+    auto iter = tensors_map.find(item.second);
3226+    if (iter == tensors_map.end()) {
3227+      TypeId data_type = item.second->data_type();
3228+      non_copy += "  " + GetTensorDataType(data_type) + "*" + item.first + " = (weight_buffer + offsets[" +
3229+                  std::to_string(tensors_index.at(item.second)) + "]);\n";
3230+      continue;
3231+    }
3232+    if (w_auxiliary.find(item.second) == w_auxiliary.end()) {
3233+      copy_static += "    {" + item.first + ", " + std::to_string(tensors_index.at(item.second)) + "},\n";
3234+      ++copy_static_num;
3235+    } else {
3236+      copy_dynamic += "    {&" + item.first + ", " + std::to_string(tensors_index.at(item.second)) + "},\n";
3237+      ++copy_dynamic_num;
3238+    }
3239+  }
3240+  for (const auto &item : w_auxiliary) {
3241+    copy_static += "    {" + item.second.second + ", " + std::to_string(tensors_index.at(item.second.first)) + "},\n";
3242+    ++copy_static_num;
3243+  }
3244+  ofs << non_copy << "\n";
3245+  if (copy_static_num > 0) {
3246+    ofs << "  {\n    struct ModelParameter static_copy[] = {\n" << copy_static << "    };\n";
3247+    ofs << "    for(int i = 0; i < " << copy_static_num << "; ++i) {\n"
3248+        << "      int index = static_copy[i].index;\n"
3249+        << "      if (offsets[index] + tensors_size[index] > weight_size) {\n"
3250+           "        return RET_ERROR;\n"
3251+           "      }\n"
3252+        << "      memcpy(static_copy[i].addr, (weight_buffer + offsets[index]), tensors_size[index]);\n"
3253+        << "    }\n  }\n\n";
3254+  }
3255+  ofs << "  size_t " << ctx->weight_size_name() << " = PackWeightSize() + dynamic_memory;\n";
3256+  ofs << "  if (" << ctx->weight_size_name() << " > 0) {\n";
3257+  ofs << "    " << ctx->weight_name() << " = malloc(" << ctx->weight_size_name() << ");\n";
3258+  ofs << "    if (" << ctx->weight_name() << " == NULL) {\n      return RET_ERROR;\n    }\n";
3259+  ofs << "    memset(" << ctx->weight_name() << ", 0, " << ctx->weight_size_name() << ");\n";
3260+  ofs << "  }\n";
3261+  ofs << "  size_t " << ctx->weight_offset_name() << " = 0;\n";
3262+  if (copy_dynamic_num > 0) {
3263+    ofs << "  {\n    struct ModelParameter dynamic_copy[] = {\n" << copy_dynamic << "    };\n";
3264+    ofs << "    for(int i = 0; i < " << copy_dynamic_num << "; ++i) {\n"
3265+        << "      int index = dynamic_copy[i].index;\n"
3266+        << "      memcpy(" << ctx->weight_name() << " + " << ctx->weight_offset_name()
3267+        << ", (weight_buffer + offsets[index]), tensors_size[index]);\n"
3268+        << "      *((void **)dynamic_copy[i].addr) = " << ctx->weight_name() << " + " << ctx->weight_offset_name()
3269+        << ";\n"
3270+        << "      " << ctx->weight_offset_name() << " += tensors_size[index];\n"
3271+        << "    }\n  }\n\n";
3272+  }
3273+}
3274+
3275+void CodeWeightInitIfKeepWeight(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
3276+  auto &w_origin = ctx->origin_weights();
3277+  auto &w_auxiliary = ctx->auxiliary_weights();
3278+  std::vector<size_t> tensors_size;
3279+  std::vector<size_t> online_compute_index;
3280+  std::map<Tensor *, int> tensors_index;
3281+  for (auto tensor : w_origin) {
3282+    if (!(CheckConstantTensor(tensor)) || tensor->data() == nullptr) {
3283+      continue;
3284+    }
3285+    auto iter = w_auxiliary.find(tensor);
3286+    if (iter == w_auxiliary.end()) {
3287+      tensors_index[tensor] = tensors_size.size();
3288+      tensors_size.push_back(tensor->Size());
3289+    } else {
3290+      tensors_index[iter->second.first] = tensors_size.size();
3291+      tensors_size.push_back(iter->second.first->Size());
3292+      tensors_index[tensor] = tensors_size.size();
3293+      online_compute_index.push_back(tensors_size.size());
3294+      tensors_size.push_back(DataTypeSize(tensor->data_type()));
3295+    }
3296+  }
3297+  std::vector<size_t> offsets{0};
3298+  int last = online_compute_index.empty() ? tensors_size.size() - 1 : online_compute_index.front();
3299+  for (int i = 1; i <= last; ++i) {
3300+    offsets.push_back(offsets[i - 1] + tensors_size[i - 1]);
3301+  }
3302+  ofs << "int Init(void *weight_buffer, int weight_size) {\n"
3303+      << "  if (weight_buffer == NULL) {\n"
3304+      << "    return RET_ERROR;\n"
3305+      << "  }\n";
3306+  ofs << "  struct ModelParameter {\n"
3307+      << "    void *addr;\n"
3308+      << "    int index;\n"
3309+      << "  };\n";
3310+  ofs << "  int offsets[" << std::to_string(tensors_size.size()) << "] = {\n"
3311+      << GenerateArrayContent(offsets, "      ") << "  };\n";
3312+  ofs << "  size_t tensors_size[" << std::to_string(tensors_size.size()) << "] = {\n"
3313+      << GenerateArrayContent(tensors_size, "         ") << "  };\n";
3314+  ofs << "  size_t dynamic_memory = 0;\n";
3315+  offsets.insert(offsets.end(), tensors_size.size() - offsets.size(), 0);
3316+  if (!online_compute_index.empty()) {
3317+    ofs << "  int online_compute_index[] = {\n" << GenerateArrayContent(online_compute_index, "      ") << "  };\n";
3318+    ofs << "  for (size_t i = 0; i < " << std::to_string(online_compute_index.size()) + "; ++i) {\n";
3319+    ofs << "    int *shape = (int *)(weight_buffer + offsets[online_compute_index[i] - 1]);\n";
3320+    ofs << "    int dim_num = tensors_size[online_compute_index[i] - 1] / 4;\n";
3321+    ofs << "    size_t tensor_size = tensors_size[online_compute_index[i]];\n";
3322+    ofs << "    for (int j = 0; j < dim_num; ++j) {\n";
3323+    ofs << "      tensor_size *= shape[j];\n";
3324+    ofs << "    }\n";
3325+    ofs << "    tensors_size[online_compute_index[i]] = tensor_size;\n";
3326+    ofs << "    dynamic_memory += tensor_size;\n";
3327+    ofs << "    int next_index = (i + 1) < " << std::to_string(online_compute_index.size())
3328+        << " ? online_compute_index[i + 1] : " << std::to_string(tensors_size.size()) << " - 1;\n";
3329+    ofs << "    for (int j = online_compute_index[i] + 1; j <= next_index; ++j) {\n";
3330+    ofs << "      offsets[j] = offsets[j - 1] + tensors_size[j - 1];\n";
3331+    ofs << "    }\n  }\n";
3332+  }
3333+  CodeWeightContentInit(ofs, ctx, tensors_index);
3334+}
3335+
3336+void CodeWeightInitIfNonKeepWeight(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
3337+  ofs << "int Init(void *weight_buffer, int weight_size) {\n"
3338+      << "  if (weight_buffer == NULL) {\n"
3339+      << "    return RET_ERROR;\n"
3340+      << "  }\n";
3341+  ofs << "  struct ModelParameter {\n"
3342+      << "    void *addr;\n"
3343+      << "    size_t size;\n"
3344+      << "    size_t offset;\n"
3345+      << "  };\n";
3346+
3347+  ofs << "  size_t " << ctx->weight_size_name() << " = PackWeightSize();\n";
3348+  size_t params_num = 0;
3349+  size_t offset = 0;
3350+  std::string params;
3351+  std::string origins;
3352+  for (const auto &item : ctx->saved_weights()) {
3353+    std::string name = item.first;
3354+    Tensor *tensor = item.second;
3355+    if (!CheckConstantTensor(tensor)) {
3356+      continue;
3357+    }
3358+    std::map<Tensor *, std::string> ctx_tensor_map = ctx->tensors_map();
3359+    auto iter = ctx_tensor_map.find(tensor);
3360+    if (iter != ctx_tensor_map.end()) {
3361+      origins += "    {" + name + ", " + std::to_string(tensor->Size()) + ", " + std::to_string(offset) + "},\n";
3362+      params_num++;
3363+    } else {
3364+      TypeId data_type = tensor->data_type();
3365+      params +=
3366+        "  " + GetTensorDataType(data_type) + "*" + name + " = (weight_buffer + " + std::to_string(offset) + ");\n";
3367+    }
3368+    offset += tensor->Size();
3369+  }
3370+  ofs << params << "\n";
3371+  ofs << "  struct ModelParameter model_params[] = {\n" << origins << "  };\n";
3372+  ofs << "\n";
3373+  ofs << "  for(int i = 0; i < " << params_num << "; ++i) {\n"
3374+      << "    if (model_params[i].offset + model_params[i].size > weight_size) {\n"
3375+         "      return RET_ERROR;\n"
3376+         "    }\n"
3377+      << "    memcpy(model_params[i].addr, (weight_buffer + model_params[i].offset), model_params[i].size);\n"
3378+      << "  }\n";
3379+  ofs << "  if (" << ctx->weight_size_name() << " > 0) {\n";
3380+  ofs << "    " << ctx->weight_name() << " = malloc(" << ctx->weight_size_name() << ");\n";
3381+  ofs << "    if (" << ctx->weight_name() << " == NULL) {\n      return RET_ERROR;\n    }\n";
3382+  ofs << "    memset(" << ctx->weight_name() << ", 0, " << ctx->weight_size_name() << ");\n";
3383+  ofs << "  }\n";
3384+  ofs << "  size_t " << ctx->weight_offset_name() << " = 0;\n";
3385+}
3386+
3387 void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) {
3388   if (config.target() != kCortex_M) {
3389     ofs << "static size_t PackWeightSize() {\n";
3390@@ -114,58 +330,16 @@ void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext>
3391     ofs << "  return w_size;\n";
3392     ofs << "}\n\n";
3393
3394-    ofs << "int Init(void *weight_buffer, int weight_size) {\n"
3395-        << "  if (weight_buffer == NULL) {\n"
3396-        << "    return RET_ERROR;\n"
3397-        << "  }\n";
3398-    ofs << "  struct ModelParameter {\n"
3399-        << "    void *addr;\n"
3400-        << "    size_t size;\n"
3401-        << "    size_t offset;\n"
3402-        << "  };\n";
3403-
3404-    ofs << "  size_t " << ctx->weight_size_name() << " = PackWeightSize();\n";
3405-    size_t params_num = 0;
3406-    size_t offset = 0;
3407-    std::string params;
3408-    std::string origins;
3409-    for (const auto &item : ctx->saved_weights()) {
3410-      std::string name = item.first;
3411-      Tensor *tensor = item.second;
3412-      if (!CheckConstantTensor(tensor)) {
3413-        continue;
3414-      }
3415-      std::map<Tensor *, std::string> ctx_tensor_map = ctx->tensors_map();
3416-      auto iter = ctx_tensor_map.find(tensor);
3417-      if (iter != ctx_tensor_map.end()) {
3418-        origins += "    {" + name + ", " + std::to_string(tensor->Size()) + ", " + std::to_string(offset) + "},\n";
3419-        params_num++;
3420-      } else {
3421-        TypeId data_type = tensor->data_type();
3422-        params +=
3423-          "  " + GetTensorDataType(data_type) + "*" + name + " = (weight_buffer + " + std::to_string(offset) + ");\n";
3424-      }
3425-      offset += tensor->Size();
3426-    }
3427-    ofs << params << "\n";
3428-    ofs << "  struct ModelParameter model_params[] = {\n" << origins << "  };\n";
3429-    ofs << "\n";
3430-    ofs << "  for(int i = 0; i < " << params_num << "; ++i) {\n"
3431-        << "    if (model_params[i].offset + model_params[i].size > weight_size) {\n"
3432-           "      return RET_ERROR;\n"
3433-           "    }\n"
3434-        << "    memcpy(model_params[i].addr, (weight_buffer + model_params[i].offset), model_params[i].size);\n"
3435-        << "  }\n";
3436-    ofs << "  if (" << ctx->weight_size_name() << " > 0) {\n";
3437-    ofs << "    " << ctx->weight_name() << " = malloc(" << ctx->weight_size_name() << ");\n";
3438-    ofs << "    if (" << ctx->weight_name() << " == NULL) {\n      return RET_ERROR;\n    }\n";
3439-    ofs << "    memset(" << ctx->weight_name() << ", 0, " << ctx->weight_size_name() << ");\n";
3440-    ofs << "  }\n";
3441+    if (config.keep_original_weight()) {
3442+      CodeWeightInitIfKeepWeight(ofs, ctx);
3443+    } else {
3444+      CodeWeightInitIfNonKeepWeight(ofs, ctx);
3445+    }
3446   } else {
3447     ofs << "int Init(void *weight_buffer, int weight_size) {\n";
3448     ofs << "  const size_t w_size = " << ctx->weight_buffer_size() << ";\n";
3449+    ofs << "  size_t " << ctx->weight_offset_name() << " = 0;\n";
3450   }
3451-  ofs << "  size_t " << ctx->weight_offset_name() << " = 0;\n";
3452   for (const auto &block : ctx->init_contents()) {
3453     ofs << "{\n" << block << "}\n";
3454   }
3455@@ -175,11 +349,26 @@ void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext>
3456   ofs << "}\n\n";
3457 }
3458
3459-void SaveDataToNet(const std::map<std::string, Tensor *> &saved_weights, const std::string &net_file) {
3460+void SaveDataToNet(const std::unique_ptr<CoderContext> &ctx, const std::string &net_file, bool keep_weight) {
3461   std::ofstream net(net_file, std::ios::out | std::ios::trunc | std::ios::binary);
3462   MS_CHECK_TRUE_WITHOUT_RET(net.is_open(), "net file open failed!");
3463-  for (auto &item : saved_weights) {
3464-    Tensor *tensor = item.second;
3465+  std::vector<Tensor *> save_tensors;
3466+  if (keep_weight) {
3467+    auto &w_origin = ctx->origin_weights();
3468+    auto &w_auxiliary = ctx->auxiliary_weights();
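+    // An origin weight's auxiliary tensor (if any) is written immediately before it.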
3469+    (void)std::for_each(w_origin.begin(), w_origin.end(), [&save_tensors, &w_auxiliary](Tensor *tensor) {
3470+      auto iter = w_auxiliary.find(tensor);
3471+      if (iter != w_auxiliary.end()) {
3472+        save_tensors.push_back(iter->second.first);
3473+      }
3474+      save_tensors.push_back(tensor);
3475+    });
3476+  } else {
3477+    auto recorded_saved_tensors = ctx->saved_weights();
3478+    (void)std::transform(recorded_saved_tensors.begin(), recorded_saved_tensors.end(), std::back_inserter(save_tensors),
3479+                         [](const std::pair<std::string, Tensor *> &item) { return item.second; });
3480+  }
3481+  for (auto tensor : save_tensors) {
3482     if ((CheckConstantTensor(tensor)) && tensor->data() != nullptr) {
3483       net.write(reinterpret_cast<const char *>(tensor->data()), tensor->Size());
3484     }
3485diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h
3486index 3a68a540..98c56afd 100644
3487--- a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h
3488+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h
3489@@ -31,7 +31,7 @@ void CodeWeightFileHeader(std::ofstream &ofs, const std::unique_ptr<CoderContext
3490 void CodeModelParamsState(std::ofstream &ofs, const std::map<std::string, Tensor *> &weights);
3491 void CodeModelParamsData(std::ofstream &ofs, const std::map<std::string, Tensor *> &weights);
3492
3493-void SaveDataToNet(const std::map<std::string, Tensor *> &saved_weights, const std::string &net_file);
3494+void SaveDataToNet(const std::unique_ptr<CoderContext> &ctx, const std::string &net_file, bool keep_weight);
3495 void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::unique_ptr<CoderContext> &ctx,
3496                            const Configurator &config);
3497
3498diff --git a/mindspore/lite/tools/converter/micro/coder/generator/generator.cc b/mindspore/lite/tools/converter/micro/coder/generator/generator.cc
3499index 5b29978f..8add577f 100644
3500--- a/mindspore/lite/tools/converter/micro/coder/generator/generator.cc
3501+++ b/mindspore/lite/tools/converter/micro/coder/generator/generator.cc
3502@@ -259,7 +259,7 @@ int Generator::CodeWeightFile() {
3503     cofs << "unsigned char * " << ctx_->buffer_name() << " = 0; \n";
3504     cofs << "unsigned char * " << ctx_->weight_name() << " = 0; \n";
3505     std::string net_file = net_src_file_path_ + "net.bin";
3506-    SaveDataToNet(ctx_->saved_weights(), net_file);
3507+    SaveDataToNet(ctx_, net_file, config_->keep_original_weight());
3508   } else {
3509     if (!ctx_->weight_buffer_size_code_blocks().empty()) {
3510       MS_LOG(ERROR) << "Weight init code generation error ";
3511diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc
3512index d25b3e6b..56b22333 100644
3513--- a/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc
3514+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc
3515@@ -41,13 +41,18 @@ int ReshapeBaseCoder::DoCode(CoderContext *const context) {
3516 }
3517
3518 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Reshape, CPUOpCoderCreator<ReshapeBaseCoder>)
3519+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Reshape, CPUOpCoderCreator<ReshapeBaseCoder>)
3520 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Reshape, CPUOpCoderCreator<ReshapeBaseCoder>)
3521 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Flatten, CPUOpCoderCreator<ReshapeBaseCoder>)
3522+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Flatten, CPUOpCoderCreator<ReshapeBaseCoder>)
3523 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Flatten, CPUOpCoderCreator<ReshapeBaseCoder>)
3524 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_ExpandDims, CPUOpCoderCreator<ReshapeBaseCoder>)
3525+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_ExpandDims, CPUOpCoderCreator<ReshapeBaseCoder>)
3526 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_ExpandDims, CPUOpCoderCreator<ReshapeBaseCoder>)
3527 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Squeeze, CPUOpCoderCreator<ReshapeBaseCoder>)
3528+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Squeeze, CPUOpCoderCreator<ReshapeBaseCoder>)
3529 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Squeeze, CPUOpCoderCreator<ReshapeBaseCoder>)
3530 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Unsqueeze, CPUOpCoderCreator<ReshapeBaseCoder>)
3531+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Unsqueeze, CPUOpCoderCreator<ReshapeBaseCoder>)
3532 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Unsqueeze, CPUOpCoderCreator<ReshapeBaseCoder>)
3533 }  // namespace mindspore::lite::micro
3534diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc
3535new file mode 100644
3536index 00000000..ee887342
3537--- /dev/null
3538+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc
3539@@ -0,0 +1,85 @@
3540+/**
3541+ * Copyright 2023 Huawei Technologies Co., Ltd
3542+ *
3543+ * Licensed under the Apache License, Version 2.0 (the "License");
3544+ * you may not use this file except in compliance with the License.
3545+ * You may obtain a copy of the License at
3546+ *
3547+ * http://www.apache.org/licenses/LICENSE-2.0
3548+ *
3549+ * Unless required by applicable law or agreed to in writing, software
3550+ * distributed under the License is distributed on an "AS IS" BASIS,
3551+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3552+ * See the License for the specific language governing permissions and
3553+ * limitations under the License.
3554+ */
3555+#include "coder/opcoders/base/stack_base_coder.h"
3556+#include <string>
3557+#include <vector>
3558+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
3559+#include "coder/opcoders/file_collector.h"
3560+#include "coder/opcoders/parallel.h"
3561+
3562+using mindspore::schema::PrimitiveType_Stack;
3563+
3564+namespace mindspore::lite::micro::nnacl {
3565+int StackFP32Coder::Prepare(CoderContext *const context) {
3566+  stack_param_ = reinterpret_cast<StackParameter *>(parameter_);
3567+  return ReSize();
3568+}
3569+
3570+int StackFP32Coder::ReSize() {
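+  // Stack inserts a new axis, so a negative axis is offset by rank + 1 of the inputs.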
3571+  axis_ = stack_param_->axis_ >= 0 ? stack_param_->axis_
3572+                                   : static_cast<int>(input_tensor_->shape().size()) + stack_param_->axis_ + 1;
3573+  if (axis_ < 0 || axis_ > static_cast<int>(input_tensor_->shape().size())) {
3574+    return RET_ERROR;
3575+  }
3576+  return RET_OK;
3577+}
3578+
3579+int StackFP32Coder::DoCode(CoderContext *const context) {
3580+  Collect(context,
3581+          {
3582+            "nnacl/base/stack_base.h",
3583+          },
3584+          {
3585+            "stack_base.c",
3586+          });
3587+
3588+  size_t input_num = input_tensors_.size();
3589+
3590+  NNaclFp32Serializer code;
3591+  code << "\t\tvoid *inputs_addr[] = {";
3592+  for (size_t i = 0; i < input_num; ++i) {
3593+    code << allocator_->GetRuntimeAddr(input_tensors_.at(i)) << ", ";
3594+  }
3595+  code << "};\n";
3596+
3597+  size_t copy_size = 0;
3598+  int outer_size = 1;
3599+  auto shape = input_tensor_->shape();
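+  // copy_size: contiguous elements each input contributes per slice (dims from the stack
+  // axis onward); outer_size: number of such slices (product of dims before the axis).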
3600+  if (input_tensors_.empty()) {
3601+    copy_size = 0;
3602+    outer_size = 0;
3603+  } else if (input_tensors_.size() == 1) {
3604+    copy_size = input_tensor_->ElementsNum();
3605+    outer_size = 1;
3606+  } else {
3607+    copy_size = 1;
3608+    for (int i = axis_; i < static_cast<int>(shape.size()); ++i) {
3609+      copy_size *= shape[i];
3610+    }
3611+    for (int i = 0; i < axis_; ++i) {
3612+      outer_size *= shape[i];
3613+    }
3614+  }
3615+  copy_size *= DataTypeSize(input_tensor_->data_type());
3616+  code.CodeFunction("Stack", "inputs_addr", output_tensor_, input_num, copy_size, 0, outer_size);
3617+  context->AppendCode(code.str());
3618+  return RET_OK;
3619+}
3620+
3621+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Stack, CPUOpCoderCreator<StackFP32Coder>)
3622+REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Stack, CPUOpCoderCreator<StackFP32Coder>)
3623+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Stack, CPUOpCoderCreator<StackFP32Coder>)
3624+}  // namespace mindspore::lite::micro::nnacl
3625diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h
3626new file mode 100644
3627index 00000000..08074332
3628--- /dev/null
3629+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h
3630@@ -0,0 +1,42 @@
3631+/**
3632+ * Copyright 2023 Huawei Technologies Co., Ltd
3633+ *
3634+ * Licensed under the Apache License, Version 2.0 (the "License");
3635+ * you may not use this file except in compliance with the License.
3636+ * You may obtain a copy of the License at
3637+ *
3638+ * http://www.apache.org/licenses/LICENSE-2.0
3639+ *
3640+ * Unless required by applicable law or agreed to in writing, software
3641+ * distributed under the License is distributed on an "AS IS" BASIS,
3642+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3643+ * See the License for the specific language governing permissions and
3644+ * limitations under the License.
3645+ */
3646+
3647+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STACK_FP32_CODER_H_
3648+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STACK_FP32_CODER_H_
3649+
3650+#include <vector>
3651+#include "coder/opcoders/op_coder.h"
3652+#include "nnacl/stack_parameter.h"
3653+
3654+namespace mindspore::lite::micro::nnacl {
3655+class StackFP32Coder final : public OperatorCoder {
3656+ public:
3657+  StackFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
3658+                 const LiteGraph::Node *node, size_t node_index, Target target)
3659+      : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
3660+  ~StackFP32Coder() override = default;
3661+
3662+  int Prepare(CoderContext *const context) override;
3663+  int DoCode(CoderContext *const context) override;
3664+
3665+ private:
3666+  int ReSize();
3667+
3668+  int axis_{0};
3669+  StackParameter *stack_param_{nullptr};
3670+};
3671+}  // namespace mindspore::lite::micro::nnacl
3672+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STACK_FP32_CODER_H_
3673diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc
3674index ba9fbaa1..ffc70e1c 100644
3675--- a/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc
3676+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc
3677@@ -33,6 +33,8 @@ size_t GetInnerSize(TypeId type_id, int inner_elements) {
3678       return inner_elements * sizeof(float);
3679     case kNumberTypeInt32:
3680       return inner_elements * sizeof(int32_t);
3681+    case kNumberTypeFloat16:
3682+      return inner_elements * sizeof(uint16_t);
3683     default:
3684       MS_LOG(ERROR) << "Not supported data type: " << type_id;
3685       return 0;
3686@@ -142,6 +144,23 @@ int StridedSliceBaseCoder::DoFastCode(CoderContext *ctx) {
3687 }
3688
3689 int StridedSliceBaseCoder::DoNormalCode(CoderContext *ctx) {
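+  // DoStridedSlice reads the element type from the parameter, so mirror the input
+  // tensor's data type before emitting the call.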
3690+  switch (input_tensor_->data_type()) {
3691+    case kNumberTypeInt8:
3692+      strided_slice_parameter_->data_type = ::kNumberTypeInt8;
3693+      break;
3694+    case kNumberTypeFloat32:
3695+      strided_slice_parameter_->data_type = ::kNumberTypeFloat32;
3696+      break;
3697+    case kNumberTypeInt32:
3698+      strided_slice_parameter_->data_type = ::kNumberTypeInt32;
3699+      break;
3700+    case kNumberTypeFloat16:
3701+      strided_slice_parameter_->data_type = ::kNumberTypeFloat16;
3702+      break;
3703+    default:
3704+      MS_LOG(ERROR) << "Not supported data type: " << input_tensor_->data_type();
3705+      return RET_ERROR;
3706+  }
3707   nnacl::NNaclFp32Serializer code;
3708   code.CodeStruct("strided_slice_parameter", *strided_slice_parameter_);
3709   code.CodeFunction("DoStridedSlice", input_tensor_, output_tensor_,
3710@@ -166,6 +185,8 @@ int StridedSliceBaseCoder::DoCode(CoderContext *ctx) {
3711 }
3712 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_StridedSlice,
3713                    CPUOpCoderCreator<StridedSliceBaseCoder>)
3714+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_StridedSlice,
3715+                   CPUOpCoderCreator<StridedSliceBaseCoder>)
3716 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_StridedSlice, CPUOpCoderCreator<StridedSliceBaseCoder>)
3717 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_StridedSlice, CPUOpCoderCreator<StridedSliceBaseCoder>)
3718 }  // namespace mindspore::lite::micro
3719diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
3720new file mode 100644
3721index 00000000..5470b56a
3722--- /dev/null
3723+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
3724@@ -0,0 +1,34 @@
3725+/**
3726+ * Copyright 2023 Huawei Technologies Co., Ltd
3727+ *
3728+ * Licensed under the Apache License, Version 2.0 (the "License");
3729+ * you may not use this file except in compliance with the License.
3730+ * You may obtain a copy of the License at
3731+ *
3732+ * http://www.apache.org/licenses/LICENSE-2.0
3733+ *
3734+ * Unless required by applicable law or agreed to in writing, software
3735+ * distributed under the License is distributed on an "AS IS" BASIS,
3736+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3737+ * See the License for the specific language governing permissions and
3738+ * limitations under the License.
3739+ */
3740+#include "coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h"
3741+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
3742+#include "coder/opcoders/file_collector.h"
3743+
3744+using mindspore::schema::PrimitiveType_Custom;
3745+
3746+namespace mindspore::lite::micro::nnacl {
3747+void CustomGruFP16Coder::InitNnaclFile(CoderContext *const context) {
3748+  Collect(context, {"nnacl/fp16/custom_gru_fp16.h"},
3749+          {"custom_gru_fp16.c", "pack_fp16.c", "matmul_fp16.c", "arithmetic_fp16.c", "activation_fp16.c"});
3750+}
3751+
3752+void CustomGruFP16Coder::InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst,
3753+                                         int row, int col) {
3754+  init_code->CodeFunction("RowMajor2Col8MajorFp16", src, dst, row, col, false);
3755+}
3756+
3757+REG_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Custom, CPUOpCoderCreator<CustomGruFP16Coder>)
3758+}  // namespace mindspore::lite::micro::nnacl
3759diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h
3760new file mode 100644
3761index 00000000..eb76faf6
3762--- /dev/null
3763+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h
3764@@ -0,0 +1,44 @@
3765+/**
3766+ * Copyright 2023 Huawei Technologies Co., Ltd
3767+ *
3768+ * Licensed under the Apache License, Version 2.0 (the "License");
3769+ * you may not use this file except in compliance with the License.
3770+ * You may obtain a copy of the License at
3771+ *
3772+ * http://www.apache.org/licenses/LICENSE-2.0
3773+ *
3774+ * Unless required by applicable law or agreed to in writing, software
3775+ * distributed under the License is distributed on an "AS IS" BASIS,
3776+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3777+ * See the License for the specific language governing permissions and
3778+ * limitations under the License.
3779+ */
3780+
3781+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CUSTOM_GRU_FP16_CODER_H_
3782+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CUSTOM_GRU_FP16_CODER_H_
3783+
3784+#include <string>
3785+#include <vector>
3786+#include "coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h"
3787+#include "nnacl/custom_gru_parameter.h"
3788+
3789+namespace mindspore::lite::micro::nnacl {
3790+class CustomGruFP16Coder : public CustomGruFP32Coder {
3791+ public:
3792+  CustomGruFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
3793+                     const LiteGraph::Node *node, size_t node_index, Target target)
3794+      : CustomGruFP32Coder(in_tensors, out_tensors, node, node_index, target) {
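+    // Override the FP32 defaults: the fp16 kernels tile rows by 4 and columns by 8.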
3795+    data_type_ = kNumberTypeFloat16;
3796+    row_tile_ = C4NUM;
3797+    col_tile_ = C8NUM;
3798+    op_func_ = "CustomGruFp16";
3799+  }
3800+  ~CustomGruFP16Coder() override = default;
3801+
3802+ protected:
3803+  void InitNnaclFile(CoderContext *const context) override;
3804+  void InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst, int row,
3805+                       int col) override;
3806+};
3807+}  // namespace mindspore::lite::micro::nnacl
3808+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CUSTOM_GRU_FP16_CODER_H_
3809diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc
3810index f2aec9d2..37b90b65 100644
3811--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc
3812+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc
3813@@ -30,13 +30,12 @@ int MatMulFP16BaseCoder::InitBiasData() {
3814   if (bias_ptr_) {
3815     return RET_OK;
3816   }
3817-  bias_pack_ptr_size_ = static_cast<size_t>(params_->col_align_ * data_type_size_);
3818+  bias_pack_ptr_size_ = static_cast<size_t>(params_->col_align_ * DataTypeSize(data_type_));
3819   if (input_tensors_.size() == C3NUM) {
3820-    bias_ptr_ = allocator_->Malloc(kNumberTypeUInt8, kOnlineSize, kOnlinePackWeight,
3821-                                   bias_tensor_->tensor_name() + "_online_pack");
3822-  } else {
3823     bias_ptr_ =
3824-      allocator_->Malloc(kNumberTypeUInt8, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack");
3825+      allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack");
3826+  } else {
3827+    bias_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack");
3828   }
3829   return RET_OK;
3830 }
3831@@ -45,18 +44,19 @@ int MatMulFP16BaseCoder::InitBufferA() {
3832   if (a_pack_ptr_ != nullptr || vec_matmul_) {
3833     return RET_OK;
3834   }
3835-  a_pack_ptr_size_ = static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * sizeof(uint16_t));
3836+  a_pack_ptr_size_ =
3837+    static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * DataTypeSize(data_type_));
3838   if (params_->a_const_) {
3839     a_pack_ptr_ = allocator_->GetSharedWeightAddr(input_tensors_.at(0));
3840     if (a_pack_ptr_ == nullptr) {
3841-      a_pack_ptr_ = allocator_->Malloc(kNumberTypeFloat16, kOnlineSize, kOnlinePackWeight,
3842+      a_pack_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight,
3843                                        input_tensors_.at(0)->tensor_name() + "_online_pack");
3844       allocator_->MarkSharedWeight(input_tensors_.at(0), a_pack_ptr_);
3845     } else {
3846       a_packed_ = true;
3847     }
3848   } else {
3849-    a_pack_ptr_ = allocator_->Malloc(kNumberTypeFloat16, a_pack_ptr_size_, kWorkspace);
3850+    a_pack_ptr_ = allocator_->Malloc(data_type_, a_pack_ptr_size_, kWorkspace);
3851   }
3852   MS_CHECK_PTR(a_pack_ptr_);
3853   return RET_OK;
3854@@ -77,7 +77,7 @@ std::string MatMulFP16BaseCoder::InitMatrixA(NNaclFp32Serializer *const code, NN
3855     return allocator_->GetRuntimeAddr(input_tensor_, input_tensor_->IsConst());
3856   }
3857   std::string input_a_str = allocator_->GetRuntimeAddr(input_tensor_);
3858-  std::string input_a_pack_str = "(float16_t *)" + allocator_->GetRuntimeAddr(a_pack_ptr_);
3859+  std::string input_a_pack_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(a_pack_ptr_));
3860   if (params_->a_const_) {
3861     init_code->CodeBufferOffsetExpression(a_pack_ptr_, context->weight_name(), context->weight_offset_name(),
3862                                           context->weight_size_name(), a_pack_ptr_size_);
3863@@ -132,7 +132,7 @@ std::string MatMulFP16BaseCoder::InitMatrixB(NNaclFp32Serializer *const code, NN
3864     return allocator_->GetRuntimeAddr(filter_tensor_, filter_tensor_->IsConst());
3865   }
3866   std::string input_b_str = allocator_->GetRuntimeAddr(filter_tensor_);
3867-  std::string input_b_pack_str = "(float16_t *)" + allocator_->GetRuntimeAddr(b_pack_ptr_);
3868+  std::string input_b_pack_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(b_pack_ptr_));
3869   if (params_->b_const_) {
3870     init_code->CodeBufferOffsetExpression(b_pack_ptr_, context->weight_name(), context->weight_offset_name(),
3871                                           context->weight_size_name(), b_pack_ptr_size_);
3872@@ -248,7 +248,7 @@ int MatMulFP16BaseCoder::DoCode(CoderContext *const context) {
3873   init_code.CodeBufferOffsetExpression(bias_ptr_, context->weight_name(), context->weight_offset_name(),
3874                                        context->weight_size_name(), bias_pack_ptr_size_);
3875   w_buf_size += bias_pack_ptr_size_;
3876-  std::string bias_str = "(float16_t *)" + allocator_->GetRuntimeAddr(bias_ptr_);
3877+  std::string bias_str = allocator_->GetRuntimeAddr(bias_ptr_);
3878   if (input_tensors_.size() == DIMENSION_3D) {
3879     auto origin_bias_str = allocator_->GetRuntimeAddr(bias_tensor_);
3880     init_code.CodeFunction("memcpy", bias_str, origin_bias_str, bias_tensor_->Size());
3881diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h
3882index 864f54ae..38270456 100644
3883--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h
3884+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h
3885@@ -27,7 +27,9 @@ class MatMulFP16BaseCoder : public MatMulFP32BaseCoder {
3886  public:
3887   MatMulFP16BaseCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
3888                       const LiteGraph::Node *node, size_t node_index, Target target)
3889-      : MatMulFP32BaseCoder(in_tensors, out_tensors, node, node_index, target) {}
3890+      : MatMulFP32BaseCoder(in_tensors, out_tensors, node, node_index, target) {
3891+    data_type_ = kNumberTypeFloat16;
3892+  }
3893
3894   ~MatMulFP16BaseCoder() override = default;
3895
3896diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h
3897index 3a1cb66a..c5ea36cd 100644
3898--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h
3899+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h
3900@@ -26,9 +26,7 @@ class MatMulFP16Coder final : public MatMulFP16BaseCoder {
3901  public:
3902   MatMulFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
3903                   const LiteGraph::Node *node, size_t node_index, Target target)
3904-      : MatMulFP16BaseCoder(in_tensors, out_tensors, node, node_index, target) {
3905-    data_type_size_ = sizeof(uint16_t);
3906-  }
3907+      : MatMulFP16BaseCoder(in_tensors, out_tensors, node, node_index, target) {}
3908
3909   ~MatMulFP16Coder() override = default;
3910
3911diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
3912index f46005c6..e53472ca 100644
3913--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
3914+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
3915@@ -20,42 +20,88 @@
3916 #include "coder/opcoders/parallel.h"
3917 #include "coder/opcoders/file_collector.h"
3918 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
3919+#include "src/common/utils.h"
3920
3921 namespace mindspore::lite::micro::nnacl {
3922 int ConvolutionDepthwiseFP32Coder::Prepare(CoderContext *const context) {
3923   MS_CHECK_RET_CODE(Conv2DBaseCoder::Init(), "Conv2DBaseCoder::Init() failed!");
3924-  MS_CHECK_RET_CODE(InitWeightBias(), "dwconvolution do init weightbais failed");
3925+  MS_CHECK_RET_CODE(InitParameter(), "dwconvolution do InitParameter failed");
3926+  if (Configurator::GetInstance()->keep_original_weight()) {
3927+    MS_CHECK_RET_CODE(InitWeightBiasOnline(), "dwconvolution do InitWeightBiasOnline failed");
3928+  } else {
3929+    MS_CHECK_RET_CODE(InitWeightBiasOffline(), "dwconvolution do InitWeightBiasOffline failed");
3930+  }
3931   conv_param_->thread_num_ = MSMIN(thread_num_, conv_param_->output_h_);
3932   return RET_OK;
3933 }
3934
3935-int ConvolutionDepthwiseFP32Coder::InitWeightBias() {
3936+int ConvolutionDepthwiseFP32Coder::InitParameter() {
3937+  auto shape = filter_tensor_->shape();
3938+  MS_CHECK_TRUE_MSG(shape.size() == C4NUM, RET_ERROR, "Conv: filter-weight's shape must be 4D.");
3939+  packed_weight_size_ =
3940+    filter_tensor_->Batch() * filter_tensor_->Height() * filter_tensor_->Width() * DataTypeSize(data_type_);
3941+  packed_bias_size_ = filter_tensor_->Batch() * DataTypeSize(data_type_);
3942+  return RET_OK;
3943+}
3944+
3945+int ConvolutionDepthwiseFP32Coder::InitWeightBiasOffline() {
3946   auto *origin_weight = reinterpret_cast<float *>(filter_tensor_->data());
3947   MS_CHECK_PTR(origin_weight);
3948   int channel = filter_tensor_->Batch();
3949-  size_t pack_weight_size = filter_tensor_->Batch() * filter_tensor_->Height() * filter_tensor_->Width();
3950-  size_t packed_weight_data_size = pack_weight_size * sizeof(float);
3951-  packed_weight_ =
3952-    reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, packed_weight_data_size, kOfflinePackWeight));
3953+  packed_weight_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, packed_weight_size_, kOfflinePackWeight));
3954   MS_CHECK_PTR(packed_weight_);
3955-  MS_CHECK_RET_CODE(memset_s(packed_weight_, packed_weight_data_size, 0, packed_weight_data_size),
3956+  MS_CHECK_RET_CODE(memset_s(packed_weight_, packed_weight_size_, 0, packed_weight_size_),
3957                     "memset packed weight failed!");
3958   PackNCHWToNHWCFp32(origin_weight, packed_weight_, 1, filter_tensor_->Height() * filter_tensor_->Width(), channel,
3959                      kDefaultTaskId, 0);
3960
3961-  auto bias_size = static_cast<size_t>(channel * sizeof(float));
3962-  bias_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, bias_size, kOfflinePackWeight));
3963+  bias_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, packed_bias_size_, kOfflinePackWeight));
3964   MS_CHECK_PTR(bias_);
3965-  MS_CHECK_RET_CODE(memset_s(bias_, bias_size, 0, bias_size), "memset bias failed!");
3966+  MS_CHECK_RET_CODE(memset_s(bias_, packed_bias_size_, 0, packed_bias_size_), "memset bias failed!");
3967   // init bias
3968   if (input_tensors_.size() == kInputSize2) {
3969     auto *ori_bias = reinterpret_cast<float *>(bias_tensor_->data());
3970     MS_CHECK_TRUE(bias_tensor_->ElementsNum() > 0, "invalid bias length");
3971-    MS_CHECK_RET_CODE(memcpy_s(bias_, bias_size, ori_bias, bias_tensor_->Size()), "memcpy_s bias failed!");
3972+    MS_CHECK_RET_CODE(memcpy_s(bias_, packed_bias_size_, ori_bias, bias_tensor_->Size()), "memcpy_s bias failed!");
3973   }
3974   return RET_OK;
3975 }
3976
3977+int ConvolutionDepthwiseFP32Coder::InitWeightBiasOnline() {
3978+  packed_weight_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight));
3979+  MS_CHECK_PTR(packed_weight_);
3980+  bias_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight));
3981+  MS_CHECK_PTR(bias_);
3982+  return RET_OK;
3983+}
3984+
3985+void ConvolutionDepthwiseFP32Coder::InitCodeOnline(CoderContext *const context) {
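+  // Online path: weights stay in their original layout in the model file; the generated
+  // Init() packs them at startup.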
3986+  if (!Configurator::GetInstance()->keep_original_weight()) {
3987+    return;
3988+  }
3989+  Collect(context,
3990+          {
3991+            "nnacl/fp32/pack_fp32.h",
3992+          },
3993+          {"pack_fp32.c"});
3994+  NNaclFp32Serializer init_code;
3995+  init_code.CodeBufferOffsetExpression(packed_weight_, context->weight_name(), context->weight_offset_name(),
3996+                                       context->weight_size_name(), packed_weight_size_);
3997+  auto filter_str = allocator_->GetRuntimeAddr(filter_tensor_);
3998+  init_code.CodeFunction("PackNCHWToNHWCFp32", filter_str, packed_weight_, 1,
3999+                         filter_tensor_->Height() * filter_tensor_->Width(), filter_tensor_->Batch(), 0, 0);
4000+  init_code.CodeBufferOffsetExpression(bias_, context->weight_name(), context->weight_offset_name(),
4001+                                       context->weight_size_name(), packed_bias_size_);
4002+  if (input_tensors_.size() == kInputSize2) {
4003+    auto bias_str = allocator_->GetRuntimeAddr(bias_tensor_);
4004+    init_code.CodeFunction("memcpy", bias_, bias_str, bias_tensor_->Size());
4005+  } else {
4006+    init_code.CodeFunction("memcpy", bias_, 0, packed_bias_size_);
4007+  }
4008+  context->AppendInitWeightSizeCode(packed_weight_size_ + packed_bias_size_);
4009+  context->AppendInitCode(init_code.str());
4010+}
4011+
4012 int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) {
4013   MS_CHECK_TRUE(conv_param_->input_channel_ == conv_param_->output_channel_,
4014                 "Only support input channel equals output channel.");
4015@@ -78,6 +124,7 @@ int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) {
4016             "activation_fp32.c",
4017           },
4018           {});
4019+  InitCodeOnline(context);
4020   nnacl::NNaclFp32Serializer code;
4021   // call the op function
4022   std::string param_name = "conv_parameter";
4023diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h
4024index a5827f4f..39757871 100644
4025--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h
4026+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h
4027@@ -34,9 +34,15 @@ class ConvolutionDepthwiseFP32Coder final : public Conv2DBaseCoder {
4028   int DoCode(CoderContext *const context) override;
4029
4030  private:
4031-  int InitWeightBias();
4032+  int InitParameter();
4033+  int InitWeightBiasOffline();
4034+  int InitWeightBiasOnline();
4035+  void InitCodeOnline(CoderContext *const context);
4036+  size_t packed_weight_size_{0};
4037   float *packed_weight_{nullptr};
4038+  size_t packed_bias_size_{0};
4039   float *bias_{nullptr};
4040+  TypeId data_type_{kNumberTypeFloat32};
4041 };
4042 }  // namespace mindspore::lite::micro::nnacl
4043 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CONVOLUTION_DEPTHWISE_FP32_CODER_H_
4044diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc
4045index 556f851a..466db21a 100644
4046--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc
4047+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc
4048@@ -77,13 +77,6 @@ const std::array<std::string, kEight> OutputTransFuncRelu6List8 = {"",
4049                                                                    "OutputTransform8x6Relu6Unit",
4050                                                                    "OutputTransform8x7Relu6Unit"};
4051
4052-int ConvolutionWinogradFP32Coder::WinogradFilterTransform(const float *weight_data, float *matrix_g,
4053-                                                          const float *matrix_gt, int oc_block) {
4054-  MS_CHECK_TRUE(oc_block, "Divide by zero!");
4055-  return WinogradWeightTransform(weight_data, trans_weight_, matrix_g, matrix_gt, oc_block, input_unit_, kernel_unit_,
4056-                                 conv_param_->input_channel_, conv_param_->output_channel_, true);
4057-}
4058-
4059 int ConvolutionWinogradFP32Coder::InitTmpBuffer() {
4060   int channel_out = conv_param_->output_channel_;
4061   int oc8 = UP_DIV(channel_out, C8NUM);
4062@@ -115,12 +108,16 @@ int ConvolutionWinogradFP32Coder::Prepare(CoderContext *const context) {
4063   input_unit_ = output_unit_ + kernel_unit_ - 1;
4064   conv_param_->input_unit_ = input_unit_;
4065   conv_param_->output_unit_ = output_unit_;
4066-  ret = InitWeightBias();
4067-  MS_CHECK_RET_CODE(ret, "Init weight bias failed.");
4068+  MS_CHECK_RET_CODE(InitParameter(), "Winograd convolution do InitParameter failed");
4069+  if (Configurator::GetInstance()->keep_original_weight()) {
4070+    MS_CHECK_RET_CODE(InitWeightBiasOnline(), "Winograd convolution do InitWeightBiasOnline failed");
4071+  } else {
4072+    MS_CHECK_RET_CODE(InitWeightBiasOffline(), "Winograd convolution do InitWeightBiasOffline failed");
4073+  }
4074   return ReSize();
4075 }  // namespace micro
4076
4077-int ConvolutionWinogradFP32Coder::InitWeightBias() {
4078+int ConvolutionWinogradFP32Coder::InitParameter() {
4079   int in_channel = filter_tensor_->Channel();
4080   int out_channel = filter_tensor_->Batch();
4081   MS_CHECK_TRUE(in_channel > 0, "invalid in channel size");
4082@@ -132,14 +129,10 @@ int ConvolutionWinogradFP32Coder::InitWeightBias() {
4083   const int oc_block = C8NUM;
4084   int oc_block_num = UP_DIV(out_channel, C8NUM);
4085   // init weight
4086-  int trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block;
4087-  trans_weight_ = reinterpret_cast<float *>(
4088-    allocator_->Malloc(kNumberTypeFloat32, trans_matrix_data_size * sizeof(float), kOfflinePackWeight));
4089-  MS_CHECK_PTR(trans_weight_);
4090-  int ret = memset_s(trans_weight_, trans_matrix_data_size * sizeof(float), 0, trans_matrix_data_size * sizeof(float));
4091-  MS_CHECK_RET_CODE(ret, "memset_s failed!");
4092-  float matrix_g[k64];
4093-  float matrix_gt[k64];
4094+  trans_weight_size_ = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * DataTypeSize(data_type_);
4095+  packed_bias_size_ = oc4 * C4NUM * DataTypeSize(data_type_);
4096+  matrix_g_.resize(k64);
4097+  matrix_gt_.resize(k64);
4098   float matrix_a[k64];
4099   float matrix_at[k64];
4100   float matrix_b[k64];
4101@@ -148,31 +141,41 @@ int ConvolutionWinogradFP32Coder::InitWeightBias() {
4102   if (input_unit_ == DIMENSION_8D) {
4103     coef = 0.5f;
4104   }
4105-  ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, coef, output_unit_, kernel_unit_);
4106+  auto ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g_.data(), matrix_gt_.data(), coef,
4107+                            output_unit_, kernel_unit_);
4108   MS_CHECK_RET_CODE(ret, "CookToomFilter failed!");
4109-  auto out_channel_size = static_cast<size_t>(out_channel);
4110+  return RET_OK;
4111+}
4112+
4113+int ConvolutionWinogradFP32Coder::InitWeightBiasOffline() {
4114+  trans_weight_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, trans_weight_size_, kOfflinePackWeight));
4115+  MS_CHECK_PTR(trans_weight_);
4116+  int ret = memset_s(trans_weight_, trans_weight_size_, 0, trans_weight_size_);
+  MS_CHECK_RET_CODE(ret, "memset_s failed!");
4117   auto weight_data = reinterpret_cast<float *>(filter_tensor_->data());
4118   MS_CHECK_PTR(weight_data);
4119-  ret = WinogradFilterTransform(weight_data, matrix_g, matrix_gt, oc_block);
4120+  ret = WinogradWeightTransform(weight_data, trans_weight_, matrix_g_.data(), matrix_gt_.data(), C8NUM, input_unit_,
4121+                                kernel_unit_, conv_param_->input_channel_, conv_param_->output_channel_, true);
4122   MS_CHECK_RET_CODE(ret, "winograd filter transform failed!");
4123-  // init bias
4124-  int new_bias_ele_num = oc4 * C4NUM;
4125-  auto new_bias_ele_size = static_cast<size_t>(new_bias_ele_num * sizeof(float));
4126-  new_bias_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, new_bias_ele_size, kOfflinePackWeight));
4127+  new_bias_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, packed_bias_size_, kOfflinePackWeight));
4128   MS_CHECK_PTR(new_bias_);
4129-  ret = memset_s(new_bias_, new_bias_ele_size, 0, new_bias_ele_size);
4130+  ret = memset_s(new_bias_, packed_bias_size_, 0, packed_bias_size_);
4131   MS_CHECK_RET_CODE(ret, "memset_s failed!");
4132   if (input_tensors_.size() == kInputSize2) {
4133     auto ori_bias_addr = reinterpret_cast<float *>(bias_tensor_->data());
4134     MS_CHECK_PTR(ori_bias_addr);
4135-    MS_CHECK_RET_CODE(memcpy_s(new_bias_, new_bias_ele_size, ori_bias_addr, out_channel_size * sizeof(float)),
4136-                      "memcpy_s failed!");
4137-  } else {
4138-    MS_CHECK_RET_CODE(memset_s(new_bias_, new_bias_ele_size, 0, new_bias_ele_size), "memset_s failed!");
4139+    MS_CHECK_RET_CODE(memcpy_s(new_bias_, packed_bias_size_, ori_bias_addr, bias_tensor_->Size()), "memcpy_s failed!");
4140   }
4141   return RET_OK;
4142 }
4143
4144+int ConvolutionWinogradFP32Coder::InitWeightBiasOnline() {
4145+  trans_weight_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight));
4146+  MS_CHECK_PTR(trans_weight_);
4147+  new_bias_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight));
4148+  MS_CHECK_PTR(new_bias_);
4149+  return RET_OK;
4150+}
4151+
4152 int ConvolutionWinogradFP32Coder::ConfigInputOutput() {
4153   trans_func_str_.in_func_ = GetInputTransFunc(input_unit_);
4154   MS_CHECK_TRUE(!trans_func_str_.in_func_.empty(), "Get input_trans_func failed.");
4155@@ -217,6 +220,36 @@ std::string ConvolutionWinogradFP32Coder::GetOutputTransFunc(int input_unit, int
4156   }
4157 }
4158
4159+void ConvolutionWinogradFP32Coder::InitCodeOnline(CoderContext *const context) {
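+  // Online path: emit WinogradWeightTransform into Init() so the original filter is
+  // transformed at startup instead of at convert time.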
4160+  if (!Configurator::GetInstance()->keep_original_weight()) {
4161+    return;
4162+  }
4163+  Collect(context,
4164+          {
4165+            "nnacl/base/minimal_filtering_generator.h",
4166+            "nnacl/fp32/pack_fp32.h",
4167+          },
4168+          {"minimal_filtering_generator.c", "nnacl/fp32/pack_fp32.h"});
4169+  NNaclFp32Serializer init_code;
4170+  init_code.CodeBufferOffsetExpression(trans_weight_, context->weight_name(), context->weight_offset_name(),
4171+                                       context->weight_size_name(), trans_weight_size_);
4172+  auto filter_str = allocator_->GetRuntimeAddr(filter_tensor_);
4173+  init_code.CodeArray("matrix_g", matrix_g_.data(), k64);
4174+  init_code.CodeArray("matrix_gt", matrix_gt_.data(), k64);
4175+  init_code.CodeFunction("WinogradWeightTransform", filter_str, trans_weight_, "matrix_g", "matrix_gt", C8NUM,
4176+                         input_unit_, kernel_unit_, conv_param_->input_channel_, conv_param_->output_channel_, true);
4177+  init_code.CodeBufferOffsetExpression(new_bias_, context->weight_name(), context->weight_offset_name(),
4178+                                       context->weight_size_name(), packed_bias_size_);
4179+  if (input_tensors_.size() == kInputSize2) {
4180+    auto bias_str = allocator_->GetRuntimeAddr(bias_tensor_);
4181+    init_code.CodeFunction("memcpy", new_bias_, bias_str, bias_tensor_->Size());
4182+  } else {
4183+    init_code.CodeFunction("memcpy", new_bias_, 0, packed_bias_size_);
4184+  }
4185+  context->AppendInitWeightSizeCode(trans_weight_size_ + packed_bias_size_);
4186+  context->AppendInitCode(init_code.str());
4187+}
4188+
4189 int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) {
4190   Collect(context,
4191           {
4192@@ -253,6 +286,7 @@ int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) {
4193   } else if (target_ == kARM64) {
4194     Collect(context, {}, {},
4195             {
4196+              "BigMatmulFp32Opt.S",
4197               "MatmulFp32.S",
4198               "MatmulFp32Opt.S",
4199               "PreSum4x16Int8Peroc.S",
4200@@ -263,14 +297,14 @@ int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) {
4201               "MatmulInt8.S",
4202             });
4203   }
4204-
4205+  InitCodeOnline(context);
4206   NNaclFp32Serializer code;
4207   // call the op function
4208   code.CodeFunction("memset", trans_input_, "0", tile_buffer_size_);
4209   code.CodeFunction("memset", gemm_out_, "0", gemm_out_size_);
4210   code.CodeFunction("memset", tmp_data_, "0", tmp_data_size_);
4211   code.CodeFunction("memset", col_buffer_, "0", col_buffer_size_);
4212-  code << "\t\tfloat *tmp_buffer_address_list[4] = {" << allocator_->GetRuntimeAddr(trans_input_) << ", "
4213+  code << "    float *tmp_buffer_address_list[4] = {" << allocator_->GetRuntimeAddr(trans_input_) << ", "
4214        << allocator_->GetRuntimeAddr(gemm_out_) << ", " << allocator_->GetRuntimeAddr(tmp_data_) << ", "
4215        << allocator_->GetRuntimeAddr(col_buffer_) << "};\n";
4216   code.CodeStruct("conv_parameter", *conv_param_);
4217diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h
4218index d583312a..a4a0438f 100644
4219--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h
4220+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h
4221@@ -38,7 +38,13 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder {
4222   ~ConvolutionWinogradFP32Coder() override = default;
4223
4224  private:
4225-  int InitWeightBias();
4226+  int InitParameter();
4227+
4228+  int InitWeightBiasOffline();
4229+
4230+  int InitWeightBiasOnline();
4231+
4232+  void InitCodeOnline(CoderContext *const context);
4233
4234   int ConfigInputOutput();
4235
4236@@ -46,13 +52,13 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder {
4237
4238   int ReSize();
4239
4240-  int WinogradFilterTransform(const float *weight_data, float *matrix_g, const float *matrix_gt, int oc_block);
4241-
4242   std::string GetInputTransFunc(int input_unit);
4243
4244   std::string GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);
4245
4246+  size_t trans_weight_size_{0};
4247   float *trans_weight_{nullptr};
4248+  size_t packed_bias_size_{0};
4249   float *new_bias_{nullptr};
4250
4251   int kernel_unit_{0};
4252@@ -70,6 +76,9 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder {
4253   float *col_buffer_{nullptr};
4254
4255   TransFuncStr trans_func_str_;
4256+  TypeId data_type_{kNumberTypeFloat32};
4257+  std::vector<float> matrix_g_;
4258+  std::vector<float> matrix_gt_;
4259 };
4260 }  // namespace mindspore::lite::micro::nnacl
4261 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CONVOLUTION_WINOGRAD_FP32_CODER_H_
4262diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
4263new file mode 100644
4264index 00000000..50146e72
4265--- /dev/null
4266+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
4267@@ -0,0 +1,214 @@
4268+/**
4269+ * Copyright 2023 Huawei Technologies Co., Ltd
4270+ *
4271+ * Licensed under the Apache License, Version 2.0 (the "License");
4272+ * you may not use this file except in compliance with the License.
4273+ * You may obtain a copy of the License at
4274+ *
4275+ * http://www.apache.org/licenses/LICENSE-2.0
4276+ *
4277+ * Unless required by applicable law or agreed to in writing, software
4278+ * distributed under the License is distributed on an "AS IS" BASIS,
4279+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4280+ * See the License for the specific language governing permissions and
4281+ * limitations under the License.
4282+ */
4283+#include "coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h"
4284+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
4285+#include "coder/opcoders/file_collector.h"
4286+#include "nnacl/custom_gru_parameter.h"
4287+
4288+using mindspore::schema::PrimitiveType_Custom;
4289+
4290+namespace mindspore::lite::micro::nnacl {
4291+namespace {
4292+constexpr size_t kOutputNum = 3;
4293+constexpr size_t kInputDims = 3;
4294+constexpr size_t kWeightDims = 2;
4295+constexpr size_t kInputSize = 6;
4296+}  // namespace
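+// Six inputs: x, weight_in, weight_hidden, bias_in, bias_hidden and the initial hidden
+// state; the middle four must be const, only the first and last may be variable.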
4297+int CustomGruFP32Coder::Prepare(CoderContext *const context) {
4298+  if (input_tensors_.size() != kInputSize) {
4299+    MS_LOG(ERROR) << "built-in CustomGru must have 6 input." << node_->name_;
4300+    return RET_ERROR;
4301+  }
4302+  for (size_t i = 1; i < kInputSize - 1; ++i) {
4303+    if (!input_tensors_[i]->IsConst()) {
4304+      MS_LOG(ERROR) << "built-in CustomGru only support first-input and last-input is variable." << node_->name_;
4305+      return RET_NOT_SUPPORT;
4306+    }
4307+  }
4308+  if (InitParamter() != RET_OK) {
4309+    MS_LOG(ERROR) << "Init built-in CustomGru Parameter failed." << node_->name_;
4310+    return RET_ERROR;
4311+  }
4312+  return ReSize();
4313+}
4314+
4315+int CustomGruFP32Coder::InitParamter() {
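+  // Shape contract: weight_in [3*hidden, input_size], weight_hidden [3*hidden, hidden],
+  // bias_in and bias_hidden [3*hidden].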
4316+  param_ = reinterpret_cast<CustomGruParameter *>(parameter_);
4317+  param_->op_parameter_.thread_num_ = 1;
4318+  auto weight_in_shape = input_tensors_[1]->shape();
4319+  auto weight_hidden_shape = input_tensors_[C2NUM]->shape();
4320+  if (weight_in_shape.size() != kWeightDims || weight_hidden_shape.size() != kWeightDims) {
4321+    MS_LOG(ERROR) << "built-in CustomGru's weight must be 2D." << node_->name_;
4322+    return RET_ERROR;
4323+  }
4324+  if (weight_in_shape[0] != weight_hidden_shape[0]) {
4325+    MS_LOG(ERROR) << "Built-in CustomGru's weight-in and weight-hidden first-dim must be same." << node_->name_;
4326+    return RET_ERROR;
4327+  }
4328+  if (weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) {
4329+    MS_LOG(ERROR) << "Built-in CustomGru's weight-hidden first-dim must be 3 * second-dim." << node_->name_;
4330+    return RET_ERROR;
4331+  }
4332+  auto bias_in_shape = input_tensors_[C3NUM]->shape();
4333+  auto bias_hidden_shape = input_tensors_[C4NUM]->shape();
4334+  if (bias_in_shape.size() != 1) {
4335+    MS_LOG(ERROR) << "built-in CustomGru's bias must be 1D." << node_->name_;
4336+    return RET_ERROR;
4337+  }
4338+  if (bias_in_shape != bias_hidden_shape) {
4339+    MS_LOG(ERROR) << "built-in CustomGru's bias-in and bias-hidden must have same shape." << node_->name_;
4340+    return RET_ERROR;
4341+  }
4342+  if (bias_in_shape.back() != weight_in_shape.front()) {
4343+    MS_LOG(ERROR) << "built-in CustomGru's bias-in shape don't match with the first-dim of weight." << node_->name_;
4344+    return RET_ERROR;
4345+  }
4346+  if (bias_in_shape.front() % C3NUM != 0) {
4347+    MS_LOG(ERROR) << "The first-dim of CustomGru's weight must be 3 * hidden.";
4348+    return RET_ERROR;
4349+  }
4350+  param_->input_size = weight_in_shape.back();
4351+  param_->hidden_size = bias_in_shape.front() / C3NUM;
4352+  return RET_OK;
4353+}
4354+
4355+int CustomGruFP32Coder::ReSize() {
4356+  auto in_shape = input_tensor_->shape();
4357+  if (in_shape.size() != kInputDims) {
4358+    MS_LOG(ERROR) << "built-in CustomGru's first-input must be 3D." << node_->name_;
4359+    return RET_ERROR;
4360+  }
4361+  param_->num_step = in_shape[0];
4362+  param_->batch_size = in_shape[1];
4363+  if (in_shape.back() != param_->input_size) {
4364+    MS_LOG(ERROR) << "built-in CustomGru's fisrt-input don't match its weight." << node_->name_;
4365+    return RET_ERROR;
4366+  }
4367+  return InitWeightAndBias();
4368+}
4369+
4370+int CustomGruFP32Coder::InitWeightAndBias() {
4371+  auto col_align = UP_ROUND(param_->hidden_size, col_tile_);
4372+  auto data_type_size = DataTypeSize(data_type_);
4373+  bias_pack_size_ = col_align * data_type_size;
4374+  weight_in_pack_size_ = static_cast<size_t>(col_align * param_->input_size) * data_type_size;
4375+  weight_input_ = allocator_->Malloc(data_type_, weight_in_pack_size_ * C3NUM, kOnlinePackWeight,
4376+                                     input_tensors_.at(1)->tensor_name() + "_online_pack");
4377+  MS_CHECK_TRUE_MSG(weight_input_ != nullptr, RET_NULL_PTR, "Init weight-in failed.");
4378+  weight_hidden_pack_size_ = static_cast<size_t>(col_align * param_->hidden_size) * data_type_size;
4379+  weight_hidden_ = allocator_->Malloc(data_type_, weight_hidden_pack_size_ * C3NUM, kOnlinePackWeight,
4380+                                      input_tensors_.at(C2NUM)->tensor_name() + "_online_pack");
4381+  MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, RET_NULL_PTR, "Init weight-hidden failed.");
4382+  bias_input_ = allocator_->Malloc(data_type_, bias_pack_size_ * C3NUM, kOnlinePackWeight,
4383+                                   input_tensors_.at(C3NUM)->tensor_name() + "_online_pack");
4384+  MS_CHECK_TRUE_MSG(bias_input_ != nullptr, RET_NULL_PTR, "Init bias-in failed.");
4385+  bias_hidden_ = allocator_->Malloc(data_type_, bias_pack_size_ * C3NUM, kOnlinePackWeight,
4386+                                    input_tensors_.at(C4NUM)->tensor_name() + "_online_pack");
4387+  MS_CHECK_TRUE_MSG(bias_hidden_ != nullptr, RET_NULL_PTR, "Init bias-hidden failed.");
4388+  auto row_align = UP_ROUND(param_->batch_size, row_tile_);
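+  // Workspace: packed x (row_align * input_size) + packed hidden (row_align * hidden)
+  // plus two 3-gate matmul outputs (batch * hidden * 6 in total).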
4389+  auto work_space =
4390+    (row_align * (param_->input_size + param_->hidden_size) + param_->batch_size * param_->hidden_size * C6NUM) *
4391+    data_type_size;
4392+  run_buffer_ = allocator_->Malloc(data_type_, work_space, kWorkspace);
4393+  MS_CHECK_TRUE_MSG(run_buffer_ != nullptr, RET_NULL_PTR, "Init run_buffer failed.");
4394+  return RET_OK;
4395+}
4396+
4397+void CustomGruFP32Coder::InitNnaclFile(CoderContext *const context) {
4398+  Collect(context, {"nnacl/fp32/custom_gru_fp32.h"},
4399+          {"custom_gru_fp32.c", "pack_fp32.c", "matmul_fp32.c", "arithmetic_fp32.c", "activation_fp32.c"});
4400+}
4401+
4402+void CustomGruFP32Coder::InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst,
4403+                                         int row, int col) {
4404+  init_code->CodeFunction("RowMajor2Col8Major", src, dst, row, col);
4405+}
4406+
4407+void CustomGruFP32Coder::InitBiasCode(CoderContext *const context, NNaclFp32Serializer *init_code) {
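+  // Copy each of the three gate biases into a col_align-strided slot so they line up
+  // with the packed weight layout.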
4408+  auto data_type_size = DataTypeSize(data_type_);
4409+  init_code->CodeBufferOffsetExpression(bias_input_, context->weight_name(), context->weight_offset_name(),
4410+                                        context->weight_size_name(), bias_pack_size_ * C3NUM);
4411+  auto bias_in_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_input_);
4412+  auto bias_in_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[C3NUM]);
4413+  for (int i = 0; i < C3NUM; ++i) {
4414+    auto dst_bias_in = bias_in_str + " + " + std::to_string(i * bias_pack_size_ / data_type_size);
4415+    auto src_bias_in = bias_in_tensor + " + " + std::to_string(i * param_->hidden_size);
4416+    init_code->CodeFunction("memcpy", dst_bias_in, src_bias_in, param_->hidden_size * data_type_size);
4417+  }
4418+  init_code->CodeBufferOffsetExpression(bias_hidden_, context->weight_name(), context->weight_offset_name(),
4419+                                        context->weight_size_name(), bias_pack_size_ * C3NUM);
4420+  auto bias_hidden_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_hidden_);
4421+  auto bias_hidden_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[C4NUM]);
4422+  for (int i = 0; i < C3NUM; ++i) {
4423+    auto dst_bias_hidden = bias_hidden_str + " + " + std::to_string(i * bias_pack_size_ / data_type_size);
4424+    auto src_bias_hidden = bias_hidden_tensor + " + " + std::to_string(i * param_->hidden_size);
4425+    init_code->CodeFunction("memcpy", dst_bias_hidden, src_bias_hidden, param_->hidden_size * data_type_size);
4426+  }
4427+}
4428+
4429+void CustomGruFP32Coder::InitWeightCode(CoderContext *const context, NNaclFp32Serializer *init_code) {
4430+  auto data_type_size = DataTypeSize(data_type_);
4431+  init_code->CodeBufferOffsetExpression(weight_input_, context->weight_name(), context->weight_offset_name(),
4432+                                        context->weight_size_name(), weight_in_pack_size_ * C3NUM);
4433+  auto weight_input_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_input_);
4434+  auto weight_in_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[1]);
4435+  for (int i = 0; i < C3NUM; ++i) {
4436+    auto dst_weight_in = weight_input_str + " + " + std::to_string(i * weight_in_pack_size_ / data_type_size);
4437+    auto src_weight_in = weight_in_tensor + " + " + std::to_string(i * param_->hidden_size * param_->input_size);
4438+    InitPackMatrixB(init_code, src_weight_in, dst_weight_in, param_->hidden_size, param_->input_size);
4439+  }
4440+
4441+  init_code->CodeBufferOffsetExpression(weight_hidden_, context->weight_name(), context->weight_offset_name(),
4442+                                        context->weight_size_name(), weight_hidden_pack_size_ * C3NUM);
4443+  auto weight_hidden_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_hidden_);
4444+  auto weight_hidden_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[C2NUM]);
4445+  for (int i = 0; i < C3NUM; ++i) {
4446+    auto dst_weight_hidden = weight_hidden_str + " + " + std::to_string(i * weight_hidden_pack_size_ / data_type_size);
4447+    auto src_weight_hidden =
4448+      weight_hidden_tensor + " + " + std::to_string(i * param_->hidden_size * param_->hidden_size);
4449+    InitPackMatrixB(init_code, src_weight_hidden, dst_weight_hidden, param_->hidden_size, param_->hidden_size);
4450+  }
4451+}
4452+
4453+int CustomGruFP32Coder::DoCode(CoderContext *const context) {
4454+  NNaclFp32Serializer code, init_code;
4455+  code.CodeStruct("custom_gru_parm", *param_);
4456+  InitNnaclFile(context);
4457+  InitWeightCode(context, &init_code);
4458+  InitBiasCode(context, &init_code);
4459+  auto row_align = UP_ROUND(param_->batch_size, row_tile_);
4460+  auto data_type_str = GetTensorDataType(data_type_);
4461+  auto buffer_name = "( " + data_type_str + "*)" + MemoryAllocator::GetInstance()->GetRuntimeAddr(run_buffer_);
4462+  int offset1 = row_align * param_->input_size;
4463+  int offset2 = offset1 + param_->batch_size * param_->hidden_size * C3NUM;
4464+  int offset3 = offset2 + row_align * param_->hidden_size;
4465+  code << data_type_str + "*buffer[4] = {" << buffer_name << ", " << buffer_name + " + " << offset1 << ", "
4466+       << buffer_name + " + " << offset2 << ", " << buffer_name + " + " << offset3 << "};\n";
4467+  auto weight_input_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_input_);
4468+  auto weight_hidden_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_hidden_);
4469+  auto bias_input_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_input_);
4470+  auto bias_hidden_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_hidden_);
4471+  code.CodeFunction(op_func_, output_tensor_, input_tensor_, weight_input_str, weight_hidden_str, bias_input_str,
4472+                    bias_hidden_str, input_tensors_[C5NUM], "buffer", "&custom_gru_parm");
4473+  context->AppendInitWeightSizeCode((weight_in_pack_size_ + weight_hidden_pack_size_) * C3NUM +
4474+                                    bias_pack_size_ * C6NUM);
4475+  context->AppendInitCode(init_code.str());
4476+  context->AppendCode(code.str());
4477+  return RET_OK;
4478+}
4479+
4480+REG_BUILIN_CUSTOM_CODER(kARM64, kNumberTypeFloat32, "CustomGRU", CPUOpCoderCreator<CustomGruFP32Coder>)
4481+}  // namespace mindspore::lite::micro::nnacl
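
The four regions of the run buffer wired up in DoCode line up exactly with the workspace reserved in InitWeightAndBias: packed input rows, the three input-gate matmul outputs, packed hidden rows, and the three hidden-gate matmul outputs. A minimal sketch of that arithmetic (plain C++; GruDims, UpRound and BufferOffsets are illustrative names, not part of the patch, and the C*NUM constants are spelled out as literals):

    // Element offsets of the four buffer[] pointers emitted by CustomGruFP32Coder::DoCode.
    struct GruDims {
      int batch_size;
      int input_size;
      int hidden_size;
    };

    static int UpRound(int x, int y) { return ((x + y - 1) / y) * y; }  // mirrors nnacl's UP_ROUND

    static void BufferOffsets(const GruDims &d, int row_tile, int offsets[4]) {
      const int row_align = UpRound(d.batch_size, row_tile);       // row_tile_ is C12NUM (12) for fp32
      offsets[0] = 0;                                              // packed input rows
      offsets[1] = row_align * d.input_size;                       // 3 input-gate matmul outputs
      offsets[2] = offsets[1] + d.batch_size * d.hidden_size * 3;  // packed hidden rows
      offsets[3] = offsets[2] + row_align * d.hidden_size;         // 3 hidden-gate matmul outputs
    }

Adding the final region (batch_size * hidden_size * 3 elements) to offsets[3] reproduces the row_align * (input_size + hidden_size) + batch_size * hidden_size * C6NUM total that InitWeightAndBias passes to the kWorkspace allocation.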
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
new file mode 100644
index 00000000..27db0f94
--- /dev/null
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CUSTOM_GRU_FP32_CODER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CUSTOM_GRU_FP32_CODER_H_
+
+#include <string>
+#include <vector>
+#include "coder/opcoders/op_coder.h"
+#include "nnacl/custom_gru_parameter.h"
+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
+
+namespace mindspore::lite::micro::nnacl {
+class CustomGruFP32Coder : public OperatorCoder {
+ public:
+  CustomGruFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
+                     const LiteGraph::Node *node, size_t node_index, Target target)
+      : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
+  ~CustomGruFP32Coder() override = default;
+
+  int Prepare(CoderContext *const context) override;
+
+  int DoCode(CoderContext *const context) override;
+
+ protected:
+  virtual void InitNnaclFile(CoderContext *const context);
+  virtual void InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst, int row,
+                               int col);
+  TypeId data_type_{kNumberTypeFloat32};
+  int row_tile_{C12NUM};
+  int col_tile_{C8NUM};
+  void *weight_input_{nullptr};
+  void *weight_hidden_{nullptr};
+  void *bias_input_{nullptr};
+  void *bias_hidden_{nullptr};
+  size_t weight_in_pack_size_{0};
+  size_t weight_hidden_pack_size_{0};
+  size_t bias_pack_size_{0};
+  std::string op_func_{"CustomGru"};
+  CustomGruParameter *param_{nullptr};
+
+ private:
+  int InitParamter();
+  int InitWeightAndBias();
+  int ReSize();
+  void InitWeightCode(CoderContext *const context, NNaclFp32Serializer *init_code);
+  void InitBiasCode(CoderContext *const context, NNaclFp32Serializer *init_code);
+  void *run_buffer_{nullptr};
+};
+}  // namespace mindspore::lite::micro::nnacl
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CUSTOM_GRU_FP32_CODER_H_
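
Since InitNnaclFile and InitPackMatrixB are the only FP32-specific members and both are virtual, a half-precision coder would mostly be a thin subclass. A hypothetical sketch under that assumption (no such class ships in this patch; the fp16 file list and the RowMajor2Col8MajorFp16 pack routine are assumed names):

    class CustomGruFP16Coder : public CustomGruFP32Coder {
     public:
      CustomGruFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                         const LiteGraph::Node *node, size_t node_index, Target target)
          : CustomGruFP32Coder(in_tensors, out_tensors, node, node_index, target) {
        data_type_ = kNumberTypeFloat16;  // drives DataTypeSize() and GetTensorDataType()
        op_func_ = "CustomGruFp16";       // illustrative kernel name
      }

     protected:
      void InitNnaclFile(CoderContext *const context) override {
        Collect(context, {"nnacl/fp16/custom_gru_fp16.h"}, {"custom_gru_fp16.c", "pack_fp16.c", "matmul_fp16.c"});
      }
      void InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst, int row,
                           int col) override {
        init_code->CodeFunction("RowMajor2Col8MajorFp16", src, dst, row, col);  // assumed fp16 pack routine
      }
    };

The tiling fields (row_tile_, col_tile_) would also need whatever values the fp16 matmul kernels expect.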
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
index 3c31479c..c6b93abf 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
@@ -27,11 +27,30 @@ using mindspore::schema::PrimitiveType_Gather;
 namespace mindspore::lite::micro::nnacl {
 int GatherFP32Coder::Prepare(CoderContext *const context) { return RET_OK; }

+void GatherFP32Coder::InitCodeInChange(CoderContext *const context, std::string *auxiliary_variable) {
+  auto input0_shape_str = allocator_->GetAuxiliaryWeight(input_tensor_);
+  if (input0_shape_str.empty()) {
+    return;
+  }
+  *auxiliary_variable = input0_shape_str;
+  NNaclFp32Serializer init_code;
+  auto in_shape = input_tensor_->shape();
+  int in_rank = static_cast<int>(in_shape.size());
+  init_code.CodeArray("shape", in_shape.data(), in_rank);
+  init_code << "  for (int i = 0; i < " << in_rank << "; ++i) {\n";
+  init_code << "    if (i != " << axis_ << " && " << input0_shape_str << "[i] != shape[i]) {\n";
+  init_code << "      return RET_ERROR;\n";
+  init_code << "    }\n";
+  init_code << "  }\n";
+  context->AppendInitCode(init_code.str());
+}
+
 int GatherFP32Coder::DoCode(CoderContext *context) {
   Tensor *input0 = input_tensors_.at(0);
   Tensor *input1 = input_tensors_.at(1);
   MS_CHECK_PTR(input0);
   MS_CHECK_PTR(input1);
+  MS_CHECK_PTR(parameter_);
   MS_CHECK_TRUE_MSG(input1->data_type() == kNumberTypeInt32 || input1->data_type() == kNumberTypeInt, RET_ERROR,
                     "index's data-type is not int32");
   // generate code .h .c
@@ -44,18 +63,16 @@ int GatherFP32Coder::DoCode(CoderContext *context) {
           });

   NNaclFp32Serializer code;
-  std::vector<int> in_shape = input0->shape();
+  auto in_shape = input0->shape();
   int in_rank = static_cast<int>(in_shape.size());
-  MS_CHECK_PTR(parameter_);
-  int axis = (reinterpret_cast<GatherParameter *>(parameter_))->axis_;
-  MS_CHECK_TRUE(static_cast<int>(in_shape.size()) >= axis, "invalid axis in gather parameter");
-  const int limit = in_shape.at(axis);
-
-  int outer_size = 1, inner_size = 1;
-  for (int i = 0; i < axis; ++i) {
+  axis_ = *(reinterpret_cast<int *>(input_tensors_.at(THIRD_INPUT)->data()));
+  MS_CHECK_TRUE(static_cast<int>(in_shape.size()) >= axis_, "invalid axis in gather parameter");
+  int outer_size = 1;
+  for (int i = 0; i < axis_; ++i) {
     outer_size *= in_shape.at(i);
   }
-  for (int i = axis + 1; i < in_rank; ++i) {
+  int inner_size = 1;
+  for (int i = axis_ + 1; i < in_rank; ++i) {
     inner_size *= in_shape.at(i);
   }
   auto data_size = static_cast<int>(lite::DataTypeSize(input0->data_type()));
@@ -67,22 +84,22 @@ int GatherFP32Coder::DoCode(CoderContext *context) {
   int start = stride * kDefaultTaskId;
   int count = MSMIN(stride, outer_size - stride * kDefaultTaskId);
   std::string input0_data = MemoryAllocator::GetInstance()->GetRuntimeAddr(input0, true);
-  if (input0_data.empty()) {
-    MS_LOG(ERROR) << "pointer is not allocated by the allocator";
-    return RET_ERROR;
-  }
+  MS_CHECK_TRUE_MSG(!input0_data.empty(), RET_ERROR, "pointer is not allocated by the allocator");
   std::string input1_data = MemoryAllocator::GetInstance()->GetRuntimeAddr(input1, true);
-  if (input1_data.empty()) {
-    MS_LOG(ERROR) << "pointer is not allocated by the allocator";
-    return RET_ERROR;
-  }
+  MS_CHECK_TRUE_MSG(!input1_data.empty(), RET_ERROR, "pointer is not allocated by the allocator");
   std::string output_data = MemoryAllocator::GetInstance()->GetRuntimeAddr(output_tensor_, true);
-  if (output_data.empty()) {
-    MS_LOG(ERROR) << "pointer is not allocated by the allocator";
-    return RET_ERROR;
+  MS_CHECK_TRUE_MSG(!output_data.empty(), RET_ERROR, "pointer is not allocated by the allocator");
+
+  std::string limit = std::to_string(in_shape[axis_]);
+  std::string in_offset = std::to_string(start * in_shape[axis_] * byte_inner_size);
+  std::string auxiliary_variable;
+  InitCodeInChange(context, &auxiliary_variable);
+  if (!auxiliary_variable.empty()) {
+    limit = auxiliary_variable + "[" + std::to_string(axis_) + "]";
+    in_offset = std::to_string(start) + " * " + limit + " * " + std::to_string(byte_inner_size);
   }
   code << "\t\tconst int8_t *int8_in = (const int8_t *)" << input0_data << ";\n";
-  code << "\t\tint8_in += " << std::to_string(start * limit * byte_inner_size) << ";\n";
+  code << "\t\tint8_in += " << in_offset << ";\n";
   code << "\t\tconst int *index_data = (const int *)" << input1_data << ";\n";
   code << "\t\tint8_t *int8_out = (int8_t *)" << output_data << ";\n";
   code << "\t\tint8_out += " << std::to_string(start * byte_out_stride) << ";\n";
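
The guard emitted by InitCodeInChange lands in the generated init routine rather than in the kernel body. For a rank-3 input of shape {2, 4, 8} gathered along axis 1, the fragment would read roughly as follows, where g_input0_shape stands in for whatever auxiliary variable name GetAuxiliaryWeight returns (the exact CodeArray rendering may differ):

    const int shape[3] = {2, 4, 8};
    for (int i = 0; i < 3; ++i) {
      if (i != 1 && g_input0_shape[i] != shape[i]) {
        return RET_ERROR; /* only the gathered axis may change; every other dim must match */
      }
    }

DoCode then reads the limit and the input offset through g_input0_shape[1] at run time instead of folding them into constants, which is what makes the axis dimension resizable.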
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h
index a14d9c3c..a175d694 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h
@@ -35,7 +35,9 @@ class GatherFP32Coder final : public OperatorCoder {
   int DoCode(CoderContext *const context) override;

  private:
+  void InitCodeInChange(CoderContext *const context, std::string *auxiliary_variable);
   int32_t *indices_{nullptr};
+  int axis_{0};
 };
 }  // namespace mindspore::lite::micro::nnacl
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GATHER_FP32_CODER_H_
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
index 790a142e..6115edb5 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc
@@ -50,13 +50,13 @@ int MatMulFP32BaseCoder::ReSize() {
 int MatMulFP32BaseCoder::InitBiasData() {
   if (input_tensors_.size() == DIMENSION_3D) {
     int max_bias_data = params_->col_align_;
-    bias_pack_ptr_size_ = static_cast<size_t>(max_bias_data * sizeof(float));
+    bias_pack_ptr_size_ = static_cast<size_t>(max_bias_data * DataTypeSize(data_type_));
     if (bias_tensor_->ElementsNum() == 1) {
       is_bias_broadcast_ = true;
     }
-    ori_bias_pack_ptr_size_ = bias_tensor_->ElementsNum() * sizeof(float);
-    bias_ptr_ = allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight,
-                                   bias_tensor_->tensor_name() + "_online_pack");
+    ori_bias_pack_ptr_size_ = bias_tensor_->ElementsNum() * DataTypeSize(data_type_);
+    bias_ptr_ =
+      allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack");
     MS_CHECK_PTR(bias_ptr_);
   }
   return RET_OK;
@@ -83,18 +83,19 @@ int MatMulFP32BaseCoder::InitBufferA() {
   if (a_pack_ptr_ != nullptr) {
     return RET_OK;
   }
-  a_pack_ptr_size_ = static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * sizeof(float));
+  a_pack_ptr_size_ =
+    static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * DataTypeSize(data_type_));
   if (params_->a_const_) {
-    a_pack_ptr_ = reinterpret_cast<float *>(allocator_->GetSharedWeightAddr(input_tensors_.at(0)));
+    a_pack_ptr_ = allocator_->GetSharedWeightAddr(input_tensors_.at(0));
     if (a_pack_ptr_ == nullptr) {
-      a_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight,
-                                                                 input_tensors_.at(0)->tensor_name() + "_online_pack"));
+      a_pack_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight,
+                                       input_tensors_.at(0)->tensor_name() + "_online_pack");
       allocator_->MarkSharedWeight(input_tensors_.at(0), a_pack_ptr_);
     } else {
       a_packed_ = true;
     }
   } else {
-    a_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, a_pack_ptr_size_, kWorkspace));
+    a_pack_ptr_ = allocator_->Malloc(data_type_, a_pack_ptr_size_, kWorkspace);
   }
   MS_CHECK_PTR(a_pack_ptr_);
   return RET_OK;
@@ -104,18 +105,19 @@ int MatMulFP32BaseCoder::InitBufferB() {
   if (b_pack_ptr_ != nullptr) {
     return RET_OK;
   }
-  b_pack_ptr_size_ = static_cast<size_t>(params_->batch * params_->col_align_ * params_->deep_ * data_type_size_);
+  b_pack_ptr_size_ =
+    static_cast<size_t>(params_->batch * params_->col_align_ * params_->deep_ * DataTypeSize(data_type_));
   if (params_->b_const_) {
-    b_pack_ptr_ = reinterpret_cast<float *>(allocator_->GetSharedWeightAddr(input_tensors_.at(1)));
+    b_pack_ptr_ = allocator_->GetSharedWeightAddr(input_tensors_.at(1));
     if (b_pack_ptr_ == nullptr) {
-      b_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeUInt8, b_pack_ptr_size_, kOnlinePackWeight,
-                                                                 input_tensors_.at(1)->tensor_name() + "_online_pack"));
+      b_pack_ptr_ = allocator_->Malloc(data_type_, b_pack_ptr_size_, kOnlinePackWeight,
+                                       input_tensors_.at(1)->tensor_name() + "_online_pack");
       allocator_->MarkSharedWeight(input_tensors_.at(1), b_pack_ptr_);
     } else {
       b_packed_ = true;
     }
   } else {
-    b_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeUInt8, b_pack_ptr_size_, kWorkspace));
+    b_pack_ptr_ = allocator_->Malloc(data_type_, b_pack_ptr_size_, kWorkspace);
   }
   MS_CHECK_PTR(b_pack_ptr_);
   return RET_OK;
@@ -194,7 +196,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
   NNaclFp32Serializer code, init_code;
   size_t w_buf_size = 0;
   std::string param_name = "mat_mul_parameter";
-  std::string bias_ptr_str = "((float *)(" + allocator_->GetRuntimeAddr(bias_ptr_) + "))";
+
   code.CodeStruct(param_name, *params_);
   if (support_parallel_) {
     code << "    " << param_name << ".op_parameter_.thread_num_ = 1;\n";
@@ -207,6 +209,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
     int max_bias_data = params_->col_align_;
     if (is_bias_broadcast_) {
       float broad_cast_data = (reinterpret_cast<float *>(bias_tensor_->data()))[0];
+      std::string bias_ptr_str = allocator_->GetRuntimeAddr(bias_ptr_);
       init_code << "\t    for (int i = 0; i < " << max_bias_data << "; ++i) {\n";
       init_code << "\t\t    " << bias_ptr_str << "[i] = " << broad_cast_data << ";\n";
       init_code << "   }\n";
@@ -219,8 +222,8 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
   std::string a_str = allocator_->GetRuntimeAddr(input_tensor_);
   std::string b_str = allocator_->GetRuntimeAddr(filter_tensor_);
   std::string c_str = allocator_->GetRuntimeAddr(output_tensor_);
-  std::string a_pack_str = allocator_->GetRuntimeAddr(a_pack_ptr_);
-  std::string b_pack_str = allocator_->GetRuntimeAddr(b_pack_ptr_);
+  std::string a_pack_str = allocator_->GetRuntimeAddr(static_cast<float *>(a_pack_ptr_));
+  std::string b_pack_str = allocator_->GetRuntimeAddr(static_cast<float *>(b_pack_ptr_));
   // do const value packing to init
   if ((params_->a_const_ && !a_packed_) || (params_->b_const_ && !b_packed_)) {
     init_code.CodeStruct("mat_mul_parameter", *params_);
@@ -271,7 +274,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
     code << "      const float *batch_b_ptr = " << b_pack_str << " + i * " << params_->deep_ * params_->col_ << ";\n";
     code << "      float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n  ";

-    code.CodeFunction("MatVecMulFp32", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_str, params_->act_type_,
+    code.CodeFunction("MatVecMulFp32", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_, params_->act_type_,
                       params_->deep_, cur_oc);
   } else {
     code << "      const float *batch_a_ptr = " << a_pack_str << " + i * " << params_->row_align_ * params_->deep_
@@ -280,7 +283,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
          << ";\n";
     code << "      float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n  ";

-    code.CodeFunction("MatMulOpt", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_str, params_->act_type_,
+    code.CodeFunction("MatMulOpt", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_, params_->act_type_,
                       params_->deep_, params_->row_, cur_oc, params_->col_, "OutType_Nhwc");
   }
   code << "    }\n";
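
Taken together, these edits derive every pack-buffer size and online-pack allocation from data_type_ instead of a hard-coded sizeof(float) (or the kNumberTypeUInt8 placeholder previously used for b_pack_ptr_), so a subclass can retarget the whole coder by flipping one field. A condensed sketch of the sizing rule, assuming DataTypeSize returns 4 for kNumberTypeFloat32 and 2 for kNumberTypeFloat16:

    // Bytes reserved for a packed matrix after this patch (PackBytes is an illustrative helper).
    static size_t PackBytes(int batch, int major_align, int deep, size_t elem_size) {
      return static_cast<size_t>(batch) * major_align * deep * elem_size;
    }
    // fp32: elem_size == 4; fp16: elem_size == 2 -- the same call sites yield half the footprint.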
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h
index 68b2658a..a5ef9277 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h
@@ -69,7 +69,7 @@ class MatMulFP32BaseCoder : public OperatorCoder {
   size_t a_pack_ptr_size_{0};
   size_t b_pack_ptr_size_{0};
   bool is_bias_broadcast_{false};
-  size_t data_type_size_{C4NUM};
+  TypeId data_type_{kNumberTypeFloat32};
 };
 }  // namespace mindspore::lite::micro::nnacl
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_MATMUL_FP32_BASE_CODER_H_
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc
index a107c3cf..45b2e37f 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc
@@ -27,6 +27,14 @@ std::unique_ptr<OperatorCoder> OpCoderBuilder::build(int schema_version) {
   MS_CHECK_PTR_RET_NULL(node_->primitive_);
   int primitive_type = GetPrimitiveType(node_->primitive_, schema_version);
   CoderKey coder_key(target_, data_type_, primitive_type);
+  if (builtin_custom_) {
+    auto custom_type = reinterpret_cast<const schema::Primitive *>(node_->primitive_)->value_as_Custom()->type();
+    if (custom_type == nullptr || custom_type->str().empty()) {
+      MS_LOG(ERROR) << "Builtin custom-op has no type.";
+      return nullptr;
+    }
+    coder_key = CoderKey(target_, data_type_, schema::PrimitiveType_Custom, custom_type->str());
+  }
   CoderCreatorFunc creator_func = OpCoderFactory::GetInstance()->FindOpCoder(coder_key);
   if (creator_func == nullptr) {
     MS_LOG(ERROR) << "caught unsupported layer: " << node_->name_;
@@ -112,5 +120,10 @@ OpCoderBuilder &OpCoderBuilder::support_parallel(bool parallel) {
   return *this;
 }

+OpCoderBuilder &OpCoderBuilder::is_builtin_custom(bool builtin_custom) {
+  builtin_custom_ = builtin_custom;
+  return *this;
+}
+
 void OpCoderBuilder::Reset() {}
 }  // namespace mindspore::lite::micro
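
With the flag set, build() swaps the registry key before the lookup: instead of the node's own primitive type it asks for the Custom primitive qualified by the type string read from the flatbuffer. For the fused GRU op, the effective lookup key is:

    // Key passed to FindOpCoder for a builtin "CustomGRU" node on ARM64/fp32.
    CoderKey coder_key(kARM64, kNumberTypeFloat32, schema::PrimitiveType_Custom, "CustomGRU");

Non-builtin custom nodes never reach this branch; CreateOpCoders routes them to the kernel registry instead.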
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h
index adce6c73..d85f1c32 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h
@@ -46,6 +46,8 @@ class OpCoderBuilder {

   OpCoderBuilder &support_parallel(bool parallel);

+  OpCoderBuilder &is_builtin_custom(bool builtin_custom);
+
   void Reset();

  private:
@@ -70,6 +72,8 @@ class OpCoderBuilder {
   std::vector<uint32_t> output_indices_;

   bool support_parallel_{false};
+
+  bool builtin_custom_{false};
 };
 }  // namespace mindspore::lite::micro
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_BUILDER_H_
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc
index 031df2e7..cf26d51d 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc
@@ -37,9 +37,9 @@ OpCoderFactory *OpCoderFactory::GetInstance() {
 }

 int OpCoderFactory::RegistOpCoder(Target target, TypeId data_type, schema::PrimitiveType operator_type,
-                                  const CoderCreatorFunc &creator_func) {
+                                  const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func) {
   // check key
-  CoderKey key(target, data_type, operator_type);
+  CoderKey key(target, data_type, operator_type, builtin_custom_type);
   // insert pair to registry
   if (this->opcoder_sets_.find(key) != this->opcoder_sets_.end()) {
     MS_LOG(ERROR) << "coder already exist: " << key.ToString();
@@ -63,7 +63,7 @@ CoderCreatorFunc OpCoderFactory::FindOpCoder(const CoderKey &key) {
 }

 OpCoderRegister::OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type,
-                                 const CoderCreatorFunc &creatorFunc) {
-  OpCoderFactory::GetInstance()->RegistOpCoder(target, data_type, operator_type, creatorFunc);
+                                 const std::string &builtin_custom_type, const CoderCreatorFunc &creatorFunc) {
+  OpCoderFactory::GetInstance()->RegistOpCoder(target, data_type, operator_type, builtin_custom_type, creatorFunc);
 }
 }  // namespace mindspore::lite::micro
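
FindOpCoder itself is untouched by this hunk; given the AllKey helper, it presumably tries the exact key first and then falls back to the kAllTargets wildcard, along these lines (a sketch of the assumed lookup, not code from this patch):

    CoderCreatorFunc OpCoderFactory::FindOpCoder(const CoderKey &key) {
      auto it = opcoder_sets_.find(key);
      if (it == opcoder_sets_.end()) {
        it = opcoder_sets_.find(key.AllKey());  // target-independent registration
      }
      return it == opcoder_sets_.end() ? nullptr : it->second;
    }

Whatever ordering or equality opcoder_sets_ uses for CoderKey must also take the new builtin_custom_type_ field into account, otherwise "CustomGRU" would collide with a plain PrimitiveType_Custom registration.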
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h
index 9a1aed63..acbd3a22 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h
@@ -18,6 +18,7 @@
 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_REGISTER_H_

 #include <map>
+#include <utility>
 #include <vector>
 #include <memory>
 #include <string>
@@ -34,10 +35,14 @@ class CoderKey {
  public:
   CoderKey() = delete;

-  CoderKey(Target target, TypeId data_type, int op_type) : target_(target), data_type_(data_type), op_type_(op_type) {}
+  CoderKey(Target target, TypeId data_type, int op_type, std::string builtin_custom_type = "")
+      : target_(target),
+        data_type_(data_type),
+        op_type_(op_type),
+        builtin_custom_type_(std::move(builtin_custom_type)) {}

   CoderKey AllKey() const {
-    CoderKey key(kAllTargets, data_type_, op_type_);
+    CoderKey key(kAllTargets, data_type_, op_type_, builtin_custom_type_);
     return key;
   }

@@ -50,6 +55,7 @@ class CoderKey {
   Target target_ = kTargetUnknown;
   TypeId data_type_ = kTypeUnknown;
   int op_type_ = schema::PrimitiveType_NONE;
+  std::string builtin_custom_type_;
 };

 class OpCoderFactory {
@@ -59,7 +65,7 @@ class OpCoderFactory {
   static OpCoderFactory *GetInstance();

   int RegistOpCoder(Target target, TypeId data_type, schema::PrimitiveType operator_type,
-                    const CoderCreatorFunc &creator_func);
+                    const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func);

   CoderCreatorFunc FindOpCoder(const CoderKey &key);

@@ -75,11 +81,16 @@ class OpCoderRegister {
   OpCoderRegister() = delete;

   OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type,
-                  const CoderCreatorFunc &creator_func);
+                  const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func);

   ~OpCoderRegister() = default;
 };
-#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func) \
-  static OpCoderRegister g_##target##data_type##operator_type##Creator(target, data_type, operator_type, creator_func);
+#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func)                                   \
+  static OpCoderRegister g_##target##data_type##operator_type##Creator(target, data_type, operator_type, "", \
+                                                                       creator_func);
+
+#define REG_BUILIN_CUSTOM_CODER(target, data_type, custom_type, creator_func) \
+  static OpCoderRegister g_##target##data_type##operator_type##Creator(       \
+    target, data_type, schema::PrimitiveType_Custom, custom_type, creator_func);
 }  // namespace mindspore::lite::micro
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_REGISTER_H_
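
Expanded by hand, the registration in custom_gru_fp32_coder.cc becomes a single static registrar object. Note that the pasted variable name still uses the literal token operator_type (the macro has no parameter of that name), so two builtin custom registrations with the same target and data type in one translation unit would produce a name clash:

    // REG_BUILIN_CUSTOM_CODER(kARM64, kNumberTypeFloat32, "CustomGRU",
    //                         CPUOpCoderCreator<CustomGruFP32Coder>) expands to:
    static OpCoderRegister g_kARM64kNumberTypeFloat32operator_typeCreator(
      kARM64, kNumberTypeFloat32, schema::PrimitiveType_Custom, "CustomGRU",
      CPUOpCoderCreator<CustomGruFP32Coder>);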
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc
index c333b621..cde08fd8 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc
@@ -196,6 +196,11 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const BroadcastSha
                         ToString(op_param.output_shape_), op_param.output_shape_size_);
 }

+void NNaclFp32Serializer::CodeStruct(const std::string &name, const CustomGruParameter &op_param) {
+  CodeBaseStruct<false>("CustomGruParameter", name, op_param.op_parameter_, op_param.num_step, op_param.batch_size,
+                        op_param.input_size, op_param.hidden_size);
+}
+
 void NNaclFp32Serializer::CodeArrayStruct(const std::string &name, TensorC *tensorC, std::vector<Tensor *> tensor) {
   std::vector<std::string> tensor_names;
   int size = tensor.size();
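
CodeStruct writes the parameter block straight into the generated source. For a GRU with 16 steps, batch 1, input size 64 and hidden size 128, the emitted initializer would look roughly like this (the rendering of the embedded op_parameter_ depends on how CodeBaseStruct serializes OpParameter):

    CustomGruParameter custom_gru_parm = {
      {/* op_parameter_ fields serialized first */},
      16,    /* num_step */
      1,     /* batch_size */
      64,    /* input_size */
      128};  /* hidden_size */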
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h
index f52ced20..797a9574 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h
@@ -44,6 +44,7 @@
 #include "nnacl/layer_norm_parameter.h"
 #include "nnacl/broadcast_to_parameter.h"
 #include "nnacl/split_parameter.h"
+#include "nnacl/custom_gru_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class NNaclFp32Serializer : public Serializer {
@@ -74,6 +75,7 @@ class NNaclFp32Serializer : public Serializer {
   void CodeStruct(const std::string &name, const SplitParameter &split_parameter);
   void CodeStruct(const std::string &name, const LayerNormParameter &param);
   void CodeStruct(const std::string &name, const BroadcastShapeInfo &param);
+  void CodeStruct(const std::string &name, const CustomGruParameter &param);
   void CodeArrayStruct(const std::string &name, TensorC *tensorC, std::vector<Tensor *> tensor);

  private:
diff --git a/mindspore/lite/tools/converter/micro/coder/session.cc b/mindspore/lite/tools/converter/micro/coder/session.cc
index 471f1491..756b7222 100644
--- a/mindspore/lite/tools/converter/micro/coder/session.cc
+++ b/mindspore/lite/tools/converter/micro/coder/session.cc
@@ -40,11 +40,38 @@
 #include "coder/opcoders/nnacl/dequant/de_quant.h"

 namespace mindspore::lite::micro {
+namespace {
+bool IsBuiltInCustomNode(const void *primitive, int schema_version) {
+  if (!IsCustomNode(primitive, schema_version)) {
+    return false;
+  }
+  const auto &custom = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_Custom();
+  if (custom == nullptr) {
+    return false;
+  }
+  const auto &attrs = custom->attr();
+  if (attrs == nullptr) {
+    return false;
+  }
+  for (size_t i = 0; i < attrs->size(); ++i) {
+    if (attrs->Get(i) == nullptr || attrs->Get(i)->name() == nullptr) {
+      continue;
+    }
+    if (attrs->Get(i)->name()->str() == "builtin") {
+      return true;
+    }
+  }
+  return false;
+}
+}  // namespace
+
 CoderSession::CoderSession() { allocator_ = MemoryAllocator::GetInstance(); }

-void CoderSession::EndCode() {
+int CoderSession::PassArgsToContext() {
   context_->set_tensor_map(allocator_->tensors_map());
   context_->set_saved_weights(allocator_->saved_weights());
+  context_->set_origin_weights(allocator_->origin_weights());
+  context_->set_auxiliary_weights(allocator_->auxiliary_weights());
   size_t de_quant_max_workspace_size = nnacl::Dequant::GetInstance()->de_quant_max_workspace();
   size_t final_total_size = allocator_->total_buffer_size() > de_quant_max_workspace_size
                               ? allocator_->total_buffer_size()
@@ -61,13 +88,20 @@ void CoderSession::EndCode() {
   if (config->code_mode() == Train) {
     Train::TransformGraphForTrain(context_.get(), op_coders_, schema_version_);
   }
+  if (!context_->JudgeIsValid(Configurator::GetInstance()->keep_original_weight())) {
+    MS_LOG(ERROR) << "The current model cannot keep original weights because generated tensor data exists; please "
+                     "set 'keep_original_weight' to false.";
+    return RET_NOT_SUPPORT;
+  }
+  return RET_OK;
 }

 int CoderSession::Run() {
   MS_LOG(INFO) << "start run opcoders";
   // 1. assign memory
   std::vector<lite::Tensor *> inputs = coder_graph_->input_tensors();
-  int ret = allocator_->Assign(inputs, op_coders_);
+  int ret = allocator_->Assign(inputs, op_coders_, coder_graph_->all_tensors(),
+                               Configurator::GetInstance()->changeable_weights_name());
   MS_CHECK_RET_CODE(ret, "assign memory failed");
   // 2. prepare, init model parameters
   for (const auto &op_coder : op_coders_) {
@@ -84,10 +118,10 @@ int CoderSession::Run() {
     ret = op_coder->DoCode(this->context_.get());
     MS_CHECK_RET_CODE(ret, "do coder " << op_coder->name() << " failed");
   }
-
-  this->EndCode();
+  ret = PassArgsToContext();
+  MS_CHECK_RET_CODE(ret, "PassArgsToContext failed");
   MS_LOG(INFO) << "run opcoders success";
-  return RET_OK;
+  return ret;
 }

 int CoderSession::GenerateCode() {
@@ -269,7 +303,9 @@ int CoderSession::CreateOpCoders() {
     }

     OpParameter *parameter = nullptr;
-    if (IsCustomNode(node->primitive_, schema_version_)) {
+    bool is_custom_op = IsCustomNode(node->primitive_, schema_version_);
+    bool is_built_in_custom_op = IsBuiltInCustomNode(node->primitive_, schema_version_);
+    if (is_custom_op && !is_built_in_custom_op) {
       KernelRegistry::GetInstance()->RegisterKernel(schema::PrimitiveType_Custom);
     } else {
       parameter = GenParameterAndInfer(node, inputs, &outputs);  // built-in ops infer
@@ -287,6 +323,7 @@ int CoderSession::CreateOpCoders() {
                                                 .mode(code_mode)
                                                 .input_indices(input_indices)
                                                 .output_indices(output_indices)
+                                                .is_builtin_custom(is_built_in_custom_op)
                                                 .build(schema_version_);
     if (op_coder == nullptr) {
       coder_graph_->DumpUnSupportLayer(code_target);
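
IsBuiltInCustomNode keys off a single marker: a Custom primitive whose attribute list contains an entry literally named "builtin". A converter-side pass that produces such a node would tag it roughly as follows (a sketch using the flatbuffer object API; the schema::CustomT and schema::AttributeT type names are assumed):

    // Sketch: marking a generated Custom node so the micro coder treats it as builtin.
    auto custom = std::make_unique<schema::CustomT>();
    custom->type = "CustomGRU";  // matched against the REG_BUILIN_CUSTOM_CODER type string
    auto marker = std::make_unique<schema::AttributeT>();
    marker->name = "builtin";    // the attribute name IsBuiltInCustomNode scans for
    custom->attr.emplace_back(std::move(marker));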
diff --git a/mindspore/lite/tools/converter/micro/coder/session.h b/mindspore/lite/tools/converter/micro/coder/session.h
index 3a8f7290..20f6b2b5 100644
--- a/mindspore/lite/tools/converter/micro/coder/session.h
+++ b/mindspore/lite/tools/converter/micro/coder/session.h
@@ -50,7 +50,7 @@ class CoderSession {
   int CreateOpCoders();
   int InitCodeGraph();
   int CompileGraph();
-  void EndCode();
+  int PassArgsToContext();

   std::unique_ptr<CoderGraph> coder_graph_{nullptr};
   std::unique_ptr<CoderContext> context_{nullptr};
diff --git a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc
index a552da05..3b868b41 100644
--- a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc
+++ b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc
@@ -103,6 +103,8 @@ std::string GetTensorDataType(TypeId type) {
       return "uint32_t ";
     case kNumberTypeInt64:
       return "int64_t ";
+    case kNumberTypeFloat16:
+      return "float16_t ";
     default:
       MS_LOG(ERROR) << "unsupported data type: " << EnumNameDataType(type);
       return "";
@@ -152,7 +154,6 @@ std::string EnumMicroTensorDataType(TypeId type) {
     case kNumberTypeUInt16:
       return "DataType_DT_UINT16";
     case kNumberTypeFloat16:
-      MS_LOG(WARNING) << "unsupported data type: kNumberTypeFloat16";
       return "DataType_DT_FLOAT16";
     default:
       MS_LOG(WARNING) << "unsupported data type: " << type << ", reference: " << kNumberTypeInt;
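
With the float16 branch in place, a coder can build half-precision casts exactly the way the GRU coder builds float ones. Note that the returned names keep a trailing space, which is what makes the concatenated cast read naturally:

    std::string addr = "g_Buffer";  // illustrative runtime address string
    std::string type_str = GetTensorDataType(kNumberTypeFloat16);  // returns "float16_t "
    std::string cast = "(" + type_str + "*)" + addr;               // "(float16_t *)g_Buffer"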
diff --git a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h
index 7753e123..61c7c923 100644
--- a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h
+++ b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h
@@ -27,6 +27,7 @@
 #include "src/common/log_adapter.h"
 #include "nnacl/op_base.h"
 #include "tools/converter/micro/coder/config.h"
+#include "base/float16.h"

 namespace mindspore::lite::micro {
 std::string EnumNameDataType(TypeId type);
@@ -63,7 +64,8 @@ std::string GetVariableTypeName() {
                                                        {std::type_index(typeid(int16_t *)), "int16_t *"},
                                                        {std::type_index(typeid(int8_t *)), "int8_t *"},
                                                        {std::type_index(typeid(uint8_t *)), "uint8_t *"},
-                                                       {std::type_index(typeid(float *)), "float *"}};
+                                                       {std::type_index(typeid(float *)), "float *"},
+                                                       {std::type_index(typeid(float16 *)), "float16_t *"}};
   auto item = types_name.find(std::type_index(typeid(T)));
   if (item != types_name.end()) {
     return item->second;
--
2.17.1
