From 9ec126625b9b723925b1b0c488355e041888f8bc Mon Sep 17 00:00:00 2001
From: albert-yan <albert.liyan@huawei.com>
Date: Thu, 13 Jul 2023 19:55:07 +0800
Subject: [PATCH] new dynamic quant algorithm and init packed

---
 .../plugin/device/cpu/kernel/nnacl/BUILD.gn   |   1 +
 .../opt/DynamicMatmulSdot4x4x16AIWI.S         | 240 +++++-
 .../opt/DynamicMatmulSdot4x4x16AIWIForFp16.S  | 789 ++++++++++++++++++
 .../kernel/nnacl/dynamic_quant_parameter.h    |   3 +
 .../kernel/nnacl/int8/dynamic_matmul_int8.c   |  80 +-
 .../kernel/nnacl/int8/dynamic_matmul_int8.h   |  48 +-
 .../kernel/nnacl/int8/dynamic_quant_int8.c    |  24 +-
 .../kernel/nnacl/int8/dynamic_quant_int8.h    |   2 +
 .../kernel/nnacl/int8/quant_dtype_cast_int8.c | 134 +++
 .../kernel/nnacl/int8/quant_dtype_cast_int8.h |   6 +
 .../cpu/kernel/nnacl/matmul_parameter.h       |  17 +-
 mindspore/core/ops/dynamic_quant.cc           |  19 +
 mindspore/core/ops/dynamic_quant.h            |  30 +
 mindspore/core/ops/op_name.h                  |   3 +
 mindspore/lite/BUILD.gn                       |   2 +
 mindspore/lite/schema/inner/ops_generated.h   |  64 +-
 mindspore/lite/schema/ops.fbs                 |   3 +
 mindspore/lite/schema/ops_generated.h         |  34 +-
 mindspore/lite/src/CMakeLists.txt             |   2 +
 mindspore/lite/src/common/mmap_utils.cc       |  63 ++
 mindspore/lite/src/common/mmap_utils.h        |  27 +
 mindspore/lite/src/common/ops/ops_def.cc      |   3 +
 .../ops/populate/dynamic_quant_populate.cc    |   3 +
 .../lite/src/common/primitive_t_utils.cc      |  14 +-
 mindspore/lite/src/common/primitive_t_utils.h |   3 +-
 mindspore/lite/src/runtime/inner_context.h    |   9 +
 .../runtime/kernel/cpu/int8/dynamic_quant.cc  | 166 +++-
 .../runtime/kernel/cpu/int8/dynamic_quant.h   |  23 +-
 .../kernel/cpu/int8/matmul_base_int8.h        |   1 +
 .../cpu/int8/matmul_dynamic_base_int8.cc      | 237 ++++--
 .../cpu/int8/matmul_dynamic_base_int8.h       |  33 +-
 .../kernel/cpu/int8/matmul_dynamic_int8.cc    |  35 +-
 .../kernel/cpu/int8/matmul_dynamic_int8.h     |   4 +-
 .../cpu/int8/matmul_dynamic_sdot_int8.cc      | 132 ++-
 .../cpu/int8/matmul_dynamic_sdot_int8.h       |  23 +-
 .../runtime/kernel/cpu/int8/matmul_int8.cc    |   2 +-
 .../src/runtime/kernel/cpu/int8/matmul_int8.h |   4 +-
 mindspore/lite/src/runtime/kernel_registry.h  |   4 +-
 mindspore/lite/src/runtime/lite_kernel.h      |   2 +
 mindspore/lite/src/runtime/lite_model.cc      |   7 +-
 mindspore/lite/src/runtime/lite_model.h       |   1 +
 mindspore/lite/src/runtime/lite_session.cc    |  45 +-
 mindspore/lite/src/runtime/lite_session.h     |   3 +-
 .../src/runtime/runtime_packed_node_pass.cc   | 358 ++++++++
 .../src/runtime/runtime_packed_node_pass.h    |  83 ++
 mindspore/lite/tools/common/graph_util.cc     | 103 +++
 mindspore/lite/tools/common/graph_util.h      |   6 +
 mindspore/lite/tools/converter/CMakeLists.txt |   4 +
 .../lite/tools/converter/anf_transform.cc     |   8 +
 .../lite/tools/converter/anf_transform.h      |   1 +
 .../config_parser/config_file_parser.cc       |  19 +
 .../config_parser/config_file_parser.h        |   9 +
 .../config_parser/cpu_option_param_parser.cc  |  41 +
 .../config_parser/cpu_option_param_parser.h   |  32 +
 .../config_parser/quant_param_parser.cc       |  20 +
 .../config_parser/quant_param_parser.h        |   1 +
 mindspore/lite/tools/converter/converter.cc   |  19 +
 .../tools/converter/converter_packed_node.cc  | 179 ++++
 .../tools/converter/converter_packed_node.h   |  29 +
 .../tools/converter/cxx_api/converter_para.h  |   6 +
 .../converter/offline_packing_optimizer.cc    | 307 +++++++
 .../converter/offline_packing_optimizer.h     |  87 ++
 .../converter/quantizer/dynamic_quantizer.cc  |  13 +-
 .../converter/quantizer/dynamic_quantizer.h   |   2 +
 .../quantizer/insert_quant_node_manager.cc    |  60 +-
 .../quantizer/insert_quant_node_manager.h     |   9 +-
 .../tools/converter/quantizer/quant_params.h  |   7 +
 67 files changed, 3497 insertions(+), 251 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S
 create mode 100644 mindspore/lite/src/common/mmap_utils.cc
 create mode 100644 mindspore/lite/src/common/mmap_utils.h
 create mode 100644 mindspore/lite/src/runtime/runtime_packed_node_pass.cc
 create mode 100644 mindspore/lite/src/runtime/runtime_packed_node_pass.h
 create mode 100644 mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.cc
 create mode 100644 mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.h
 create mode 100644 mindspore/lite/tools/converter/converter_packed_node.cc
 create mode 100644 mindspore/lite/tools/converter/converter_packed_node.h
 create mode 100644 mindspore/lite/tools/converter/offline_packing_optimizer.cc
 create mode 100644 mindspore/lite/tools/converter/offline_packing_optimizer.h

diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn
index 3427a8a4..64188a68 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn
@@ -619,6 +619,7 @@ arm64_fp16_assembly_sources = [

 optimizing_assembly_sources = [
   "assembly/opt/DynamicMatmulSdot4x4x16AIWI.S",
+  "assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S",
   "assembly/opt/MatmulDpInt8Opt.S",
   "assembly/opt/MatmulDpInt8.S",
   "assembly/opt/MatmulOptR4Int8.S",
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S
index efacd61b..bf646f32 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S
@@ -1,5 +1,5 @@
 /**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2023 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@

 // void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales,
 //                                  float *bias, size_t row, size_t col, size_t stride, const int *a_sums,
-//                                  const int *b_sums, int64_t a_zp, int64_t b_zp_sum);
+//                                  const int *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode);
 // x0: a(left matrix ptr)
 // x1: b(right matrix ptr)
 // x2: out ptr
@@ -34,18 +34,23 @@
 // x10: b_sums
 // x19/w19: a_zp
 // x20/w20: b_zp_sum
+// x21: act_type -> 0: none, 1: Relu, 3: Relu6
+// x22: mode -> 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel

 asm_function DynamicMatmulSdot4x4x16AIWI
-    sub sp, sp, #144
+    sub sp, sp, #160
     st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16

     ldr x8, [sp]
     ldr x9, [sp, #8]
     ldr x10, [sp, #16]
     ldr x19, [sp, #24]
     ldr x20, [sp, #32]
+    ldr x21, [sp, #40]
+    ldr x22, [sp, #48]

     dup v16.4s, wzr // dup:Duplicate general-purpose register to vector.
     dup v17.4s, wzr
@@ -64,7 +69,7 @@ asm_function DynamicMatmulSdot4x4x16AIWI
     dup v30.4s, wzr
     dup v31.4s, wzr

-    mov x18, x1 // reload rhs ptr
+    mov x11, x1 // reload rhs ptr
     mov x17, x0 // reload lhs ptr
     mov x16, x3 // reload depth

@@ -75,7 +80,7 @@ asm_function DynamicMatmulSdot4x4x16AIWI

 LoopDepth:
     ld1 {v0.16b}, [x17], #16
-    ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x18], #64
+    ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x11], #64

     sdot v16.4s, v1.16b, v0.4b[0]
     sdot v17.4s, v2.16b, v0.4b[0]
@@ -100,8 +105,8 @@ LoopDepth:

 LoopDepthHalf:
     ld1 {v0.16b}, [x17], #16
-    ld1 {v1.16b, v2.16b}, [x18]
-    add x18, x18, #64
+    ld1 {v1.16b, v2.16b}, [x11]
+    add x11, x11, #64
    sdot v16.4s, v1.16b, v0.4b[0]
    sdot v17.4s, v2.16b, v0.4b[0]
    sdot v20.4s, v1.16b, v0.4b[1]
@@ -117,8 +122,8 @@ LoopDepthHalf:

 LoopDepthQuarter:
     ld1 {v0.16b}, [x17], #16
-    ld1 {v1.16b}, [x18]
-    add x18, x18, #64
+    ld1 {v1.16b}, [x11]
+    add x11, x11, #64
    sdot v16.4s, v1.16b, v0.4b[0]
    sdot v20.4s, v1.16b, v0.4b[1]
    sdot v24.4s, v1.16b, v0.4b[2]
@@ -225,28 +230,108 @@ Convert2Float:

 MultiplyScale:
     // multi_scale * input_matrix
-    ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x4]
-
-    fmul v16.4s,v16.4s,v1.4s
-    fmul v17.4s,v17.4s,v2.4s
-    fmul v18.4s,v18.4s,v3.4s
-    fmul v19.4s,v19.4s,v4.4s
-
-    fmul v20.4s,v20.4s,v1.4s
-    fmul v21.4s,v21.4s,v2.4s
-    fmul v22.4s,v22.4s,v3.4s
-    fmul v23.4s,v23.4s,v4.4s
-
-    fmul v24.4s,v24.4s,v1.4s
-    fmul v25.4s,v25.4s,v2.4s
-    fmul v26.4s,v26.4s,v3.4s
-    fmul v27.4s,v27.4s,v4.4s
-
-    fmul v28.4s,v28.4s,v1.4s
-    fmul v29.4s,v29.4s,v2.4s
-    fmul v30.4s,v30.4s,v3.4s
-    fmul v31.4s,v31.4s,v4.4s
-
+    cbz x22, TensorXTensor
+    cmp x22, #1
+    beq TensorXChannel
+    cmp x22, #2
+    beq ChannelXTensor
+    ChannelXChannel:
+        ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4], #64
+        ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64
+
+        fmul v16.4s,v16.4s,v0.4s
+        fmul v17.4s,v17.4s,v1.4s
+        fmul v18.4s,v18.4s,v2.4s
+        fmul v19.4s,v19.4s,v3.4s
+
+        ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x4], #64
+        ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x4]
+
+        fmul v20.4s,v20.4s,v4.4s
+        fmul v21.4s,v21.4s,v5.4s
+        fmul v22.4s,v22.4s,v6.4s
+        fmul v23.4s,v23.4s,v7.4s
+
+        fmul v24.4s,v24.4s,v8.4s
+        fmul v25.4s,v25.4s,v9.4s
+        fmul v26.4s,v26.4s,v10.4s
+        fmul v27.4s,v27.4s,v11.4s
+
+        fmul v28.4s,v28.4s,v12.4s
+        fmul v29.4s,v29.4s,v13.4s
+        fmul v30.4s,v30.4s,v14.4s
+        fmul v31.4s,v31.4s,v15.4s
+        b AddBias
+
+    TensorXTensor:
+        ld1 {v0.s}[0], [x4]
+
+        fmul v16.4s,v16.4s,v0.s[0]
+        fmul v17.4s,v17.4s,v0.s[0]
+        fmul v18.4s,v18.4s,v0.s[0]
+        fmul v19.4s,v19.4s,v0.s[0]
+
+        fmul v20.4s,v20.4s,v0.s[0]
+        fmul v21.4s,v21.4s,v0.s[0]
+        fmul v22.4s,v22.4s,v0.s[0]
+        fmul v23.4s,v23.4s,v0.s[0]
+
+        fmul v24.4s,v24.4s,v0.s[0]
+        fmul v25.4s,v25.4s,v0.s[0]
+        fmul v26.4s,v26.4s,v0.s[0]
+        fmul v27.4s,v27.4s,v0.s[0]
+
+        fmul v28.4s,v28.4s,v0.s[0]
+        fmul v29.4s,v29.4s,v0.s[0]
+        fmul v30.4s,v30.4s,v0.s[0]
+        fmul v31.4s,v31.4s,v0.s[0]
+        b AddBias
+
+    TensorXChannel:
+        ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4]
+
+        fmul v16.4s,v16.4s,v0.4s
+        fmul v17.4s,v17.4s,v1.4s
+        fmul v18.4s,v18.4s,v2.4s
+        fmul v19.4s,v19.4s,v3.4s
+
+        fmul v20.4s,v20.4s,v0.4s
+        fmul v21.4s,v21.4s,v1.4s
+        fmul v22.4s,v22.4s,v2.4s
+        fmul v23.4s,v23.4s,v3.4s
+
+        fmul v24.4s,v24.4s,v0.4s
+        fmul v25.4s,v25.4s,v1.4s
+        fmul v26.4s,v26.4s,v2.4s
+        fmul v27.4s,v27.4s,v3.4s
+
+        fmul v28.4s,v28.4s,v0.4s
+        fmul v29.4s,v29.4s,v1.4s
+        fmul v30.4s,v30.4s,v2.4s
+        fmul v31.4s,v31.4s,v3.4s
+        b AddBias
+
+    ChannelXTensor:
+        ld1 {v0.4s}, [x4]
+        fmul v16.4s,v16.4s,v0.s[0]
+        fmul v17.4s,v17.4s,v0.s[0]
+        fmul v18.4s,v18.4s,v0.s[0]
+        fmul v19.4s,v19.4s,v0.s[0]
+
+        fmul v20.4s,v20.4s,v0.s[1]
+        fmul v21.4s,v21.4s,v0.s[1]
+        fmul v22.4s,v22.4s,v0.s[1]
+        fmul v23.4s,v23.4s,v0.s[1]
+
+        fmul v24.4s,v24.4s,v0.s[2]
+        fmul v25.4s,v25.4s,v0.s[2]
+        fmul v26.4s,v26.4s,v0.s[2]
+        fmul v27.4s,v27.4s,v0.s[2]
+
+        fmul v28.4s,v28.4s,v0.s[3]
+        fmul v29.4s,v29.4s,v0.s[3]
+        fmul v30.4s,v30.4s,v0.s[3]
+        fmul v31.4s,v31.4s,v0.s[3]
 AddBias:
     // +bias
     cbz x5, StoreData
@@ -272,6 +357,88 @@ AddBias:
     fadd v30.4s,v30.4s,v3.4s
     fadd v31.4s,v31.4s,v4.4s

+Activate:
+    cmp x21, #1
+    beq Relu
+    cmp x21, #3
+    beq Relu6
+    b StoreData
+
+Relu:
+    dup v1.4s, wzr
+
+    smax v16.4s,v16.4s,v1.4s
+    smax v17.4s,v17.4s,v1.4s
+    smax v18.4s,v18.4s,v1.4s
+    smax v19.4s,v19.4s,v1.4s
+
+    smax v20.4s,v20.4s,v1.4s
+    smax v21.4s,v21.4s,v1.4s
+    smax v22.4s,v22.4s,v1.4s
+    smax v23.4s,v23.4s,v1.4s
+
+    smax v24.4s,v24.4s,v1.4s
+    smax v25.4s,v25.4s,v1.4s
+    smax v26.4s,v26.4s,v1.4s
+    smax v27.4s,v27.4s,v1.4s
+
+    smax v28.4s,v28.4s,v1.4s
+    smax v29.4s,v29.4s,v1.4s
+    smax v30.4s,v30.4s,v1.4s
+    smax v31.4s,v31.4s,v1.4s
+
+    b StoreData
+
+Relu6:
+    dup v1.4s, wzr
+    movi v2.4s, #6
+    scvtf v2.4s, v2.4s
+
+    // max (out, 0)
+    smax v16.4s,v16.4s,v1.4s
+    smax v17.4s,v17.4s,v1.4s
+    smax v18.4s,v18.4s,v1.4s
+    smax v19.4s,v19.4s,v1.4s
+
+    smax v20.4s,v20.4s,v1.4s
+    smax v21.4s,v21.4s,v1.4s
+    smax v22.4s,v22.4s,v1.4s
+    smax v23.4s,v23.4s,v1.4s
+
+    smax v24.4s,v24.4s,v1.4s
+    smax v25.4s,v25.4s,v1.4s
+    smax v26.4s,v26.4s,v1.4s
+    smax v27.4s,v27.4s,v1.4s
+
+    smax v28.4s,v28.4s,v1.4s
+    smax v29.4s,v29.4s,v1.4s
+    smax v30.4s,v30.4s,v1.4s
+    smax v31.4s,v31.4s,v1.4s
+
+    // min (out, 6)
+
+    smin v16.4s,v16.4s,v2.4s
+    smin v17.4s,v17.4s,v2.4s
+    smin v18.4s,v18.4s,v2.4s
+    smin v19.4s,v19.4s,v2.4s
+
+    smin v20.4s,v20.4s,v2.4s
+    smin v21.4s,v21.4s,v2.4s
+    smin v22.4s,v22.4s,v2.4s
+    smin v23.4s,v23.4s,v2.4s
+
+    smin v24.4s,v24.4s,v2.4s
+    smin v25.4s,v25.4s,v2.4s
+    smin v26.4s,v26.4s,v2.4s
+    smin v27.4s,v27.4s,v2.4s
+
+    smin v28.4s,v28.4s,v2.4s
+    smin v29.4s,v29.4s,v2.4s
+    smin v30.4s,v30.4s,v2.4s
+    smin v31.4s,v31.4s,v2.4s
+
+    b StoreData
+
 StoreData:
     cmp x7, #16
     beq Write16
@@ -547,19 +714,19 @@ Write4:
     b StoreDataEnd

 Write3:
-    st1 {v16.1d}, [x15]
+    st1 {v16.1d}, [x15], #8
     st1 {v16.s}[2], [x15]
     cmp x6, #1
     beq StoreDataEnd
-    st1 {v20.1d}, [x14]
+    st1 {v20.1d}, [x14], #8
     st1 {v20.s}[2], [x14]
     cmp x6, #2
     beq StoreDataEnd
-    st1 {v24.1d}, [x13]
+    st1 {v24.1d}, [x13], #8
     st1 {v24.s}[2], [x13]
     cmp x6, #3
     beq StoreDataEnd
-    st1 {v28.1d}, [x12]
+    st1 {v28.1d}, [x12], #8
     st1 {v28.s}[2], [x12]
     b StoreDataEnd

@@ -589,9 +756,10 @@ Write1:
     st1 {v28.s}[0], [x12]
     b StoreDataEnd
 StoreDataEnd:
-    sub sp, sp, #144
+    sub sp, sp, #160
     ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
     ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
     ret
 #endif
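
To make the two new trailing arguments concrete, here is a minimal C sketch of a call into the updated fp32 kernel; the buffer names and tile sizes are hypothetical, and the inputs are assumed to be already packed in the kernel's 4x4 (lhs) and 4x16 (rhs) layouts. Note that bias may be NULL: the kernel branches straight to StoreData in that case (cbz x5, StoreData).

#include <stddef.h>
#include <stdint.h>

// Hedged sketch: one 4x16 output tile with per-tensor activation and
// per-channel weight scales (mode 1) and fused Relu (act_type 1).
// For mode 1, multi_scales holds one scale per output column (16 entries).
void ExampleCallFp32(const int8_t *packed_a, const int8_t *packed_b, float *out,
                     const float *multi_scales, const float *bias,
                     const int32_t *a_sums, const int32_t *b_sums) {
  DynamicMatmulSdot4x4x16AIWI(packed_a, packed_b, out, /*deep4=*/16,
                              multi_scales, bias, /*row=*/4, /*col=*/16,
                              /*stride=*/16 * sizeof(float), a_sums, b_sums,
                              /*a_zp=*/0, /*b_zp_sum=*/0, /*act_type=*/1,
                              /*mode=*/1);
}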
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S
new file mode 100644
index 00000000..e22a572a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S
@@ -0,0 +1,789 @@
+/**
+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef ENABLE_ARM64
+#include "nnacl/assembly_global.h"
+.text
+.align 5
+
+// void DynamicMatmulSdot4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4,
+//                                         float *multi_scales, float16_t *bias, size_t row, size_t col, size_t stride,
+//                                         const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum,
+//                                         int64_t act_type, int64_t mode);
+// x0: a(left matrix ptr)
+// x1: b(right matrix ptr)
+// x2: out ptr
+// x3: deep
+// x4: multi_scales
+// x5: bias
+// x6: row
+// x7: col
+// x8: stride
+// x9: a_sums
+// x10: b_sums
+// x19/w19: a_zp
+// x20/w20: b_zp_sum
+// x21: act_type -> 0: none, 1: Relu, 3: Relu6
+// x22: mode -> 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel
+
+asm_function DynamicMatmulSdot4x4x16AIWIForFp16
+    sub sp, sp, #160
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+
+    ldr x8, [sp]
+    ldr x9, [sp, #8]
+    ldr x10, [sp, #16]
+    ldr x19, [sp, #24]
+    ldr x20, [sp, #32]
+    ldr x21, [sp, #40]
+    ldr x22, [sp, #48]
+
+    dup v16.4s, wzr // dup:Duplicate general-purpose register to vector.
+    dup v17.4s, wzr
+    dup v18.4s, wzr
+    dup v19.4s, wzr
+    dup v20.4s, wzr
+    dup v21.4s, wzr
+    dup v22.4s, wzr
+    dup v23.4s, wzr
+    dup v24.4s, wzr
+    dup v25.4s, wzr
+    dup v26.4s, wzr
+    dup v27.4s, wzr
+    dup v28.4s, wzr
+    dup v29.4s, wzr
+    dup v30.4s, wzr
+    dup v31.4s, wzr
+
+    mov x11, x1 // reload rhs ptr
+    mov x17, x0 // reload lhs ptr
+    mov x16, x3 // reload depth
+
+    cmp x7, #4
+    ble LoopDepthQuarter
+    cmp x7, #8
+    ble LoopDepthHalf
+
+LoopDepth:
+    ld1 {v0.16b}, [x17], #16
+    ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x11], #64
+
+    sdot v16.4s, v1.16b, v0.4b[0]
+    sdot v17.4s, v2.16b, v0.4b[0]
+    sdot v18.4s, v3.16b, v0.4b[0]
+    sdot v19.4s, v4.16b, v0.4b[0]
+    sdot v20.4s, v1.16b, v0.4b[1]
+    sdot v21.4s, v2.16b, v0.4b[1]
+    sdot v22.4s, v3.16b, v0.4b[1]
+    sdot v23.4s, v4.16b, v0.4b[1]
+    sdot v24.4s, v1.16b, v0.4b[2]
+    sdot v25.4s, v2.16b, v0.4b[2]
+    sdot v26.4s, v3.16b, v0.4b[2]
+    sdot v27.4s, v4.16b, v0.4b[2]
+    sdot v28.4s, v1.16b, v0.4b[3]
+    sdot v29.4s, v2.16b, v0.4b[3]
+    sdot v30.4s, v3.16b, v0.4b[3]
+    sdot v31.4s, v4.16b, v0.4b[3]
+
+    subs x16, x16, #4
+    bgt LoopDepth
+    b AddInputSum
+
+LoopDepthHalf:
+    ld1 {v0.16b}, [x17], #16
+    ld1 {v1.16b, v2.16b}, [x11]
+    add x11, x11, #64
+    sdot v16.4s, v1.16b, v0.4b[0]
+    sdot v17.4s, v2.16b, v0.4b[0]
+    sdot v20.4s, v1.16b, v0.4b[1]
+    sdot v21.4s, v2.16b, v0.4b[1]
+    sdot v24.4s, v1.16b, v0.4b[2]
+    sdot v25.4s, v2.16b, v0.4b[2]
+    sdot v28.4s, v1.16b, v0.4b[3]
+    sdot v29.4s, v2.16b, v0.4b[3]
+
+    subs x16, x16, #4
+    bgt LoopDepthHalf
+    b AddInputSum
+
+LoopDepthQuarter:
+    ld1 {v0.16b}, [x17], #16
+    ld1 {v1.16b}, [x11]
+    add x11, x11, #64
+    sdot v16.4s, v1.16b, v0.4b[0]
+    sdot v20.4s, v1.16b, v0.4b[1]
+    sdot v24.4s, v1.16b, v0.4b[2]
+    sdot v28.4s, v1.16b, v0.4b[3]
+
+    subs x16, x16, #4
+    bgt LoopDepthQuarter
+    b AddInputSum
+
+AddInputSum:
+    cmp w20, #0
+    beq AddInputSumEnd
+    ld1 {v5.4s}, [x9], #16
+    dup v6.4s, v5.s[0]
+    dup v7.4s, v5.s[1]
+    dup v8.4s, v5.s[2]
+    dup v9.4s, v5.s[3]
+
+    sub v16.4s, v16.4s, v6.4s
+    sub v17.4s, v17.4s, v6.4s
+    sub v18.4s, v18.4s, v6.4s
+    sub v19.4s, v19.4s, v6.4s
+    sub v20.4s, v20.4s, v7.4s
+    sub v21.4s, v21.4s, v7.4s
+    sub v22.4s, v22.4s, v7.4s
+    sub v23.4s, v23.4s, v7.4s
+    sub v24.4s, v24.4s, v8.4s
+    sub v25.4s, v25.4s, v8.4s
+    sub v26.4s, v26.4s, v8.4s
+    sub v27.4s, v27.4s, v8.4s
+    sub v28.4s, v28.4s, v9.4s
+    sub v29.4s, v29.4s, v9.4s
+    sub v30.4s, v30.4s, v9.4s
+    sub v31.4s, v31.4s, v9.4s
+AddInputSumEnd:
+
+AddWeightSum:
+    ld1 {v9.4s},  [x10], #16
+    ld1 {v10.4s}, [x10], #16
+    ld1 {v11.4s}, [x10], #16
+    ld1 {v12.4s}, [x10], #16
+    dup v13.4s, w19
+    mul v9.4s, v9.4s, v13.4s
+    mul v10.4s, v10.4s, v13.4s
+    mul v11.4s, v11.4s, v13.4s
+    mul v12.4s, v12.4s, v13.4s
+    sub v16.4s, v16.4s, v9.4s
+    sub v17.4s, v17.4s, v10.4s
+    sub v18.4s, v18.4s, v11.4s
+    sub v19.4s, v19.4s, v12.4s
+    sub v20.4s, v20.4s, v9.4s
+    sub v21.4s, v21.4s, v10.4s
+    sub v22.4s, v22.4s, v11.4s
+    sub v23.4s, v23.4s, v12.4s
+    sub v24.4s, v24.4s, v9.4s
+    sub v25.4s, v25.4s, v10.4s
+    sub v26.4s, v26.4s, v11.4s
+    sub v27.4s, v27.4s, v12.4s
+    sub v28.4s, v28.4s, v9.4s
+    sub v29.4s, v29.4s, v10.4s
+    sub v30.4s, v30.4s, v11.4s
+    sub v31.4s, v31.4s, v12.4s
+
+AddZpSum:
+    mul w15, w19, w20
+    cmp w15, #0
+    beq AddZpSumEnd
+    dup v14.4s, w15
+    add v16.4s, v16.4s, v14.4s
+    add v17.4s, v17.4s, v14.4s
+    add v18.4s, v18.4s, v14.4s
+    add v19.4s, v19.4s, v14.4s
+    add v20.4s, v20.4s, v14.4s
+    add v21.4s, v21.4s, v14.4s
+    add v22.4s, v22.4s, v14.4s
+    add v23.4s, v23.4s, v14.4s
+    add v24.4s, v24.4s, v14.4s
+    add v25.4s, v25.4s, v14.4s
+    add v26.4s, v26.4s, v14.4s
+    add v27.4s, v27.4s, v14.4s
+    add v28.4s, v28.4s, v14.4s
+    add v29.4s, v29.4s, v14.4s
+    add v30.4s, v30.4s, v14.4s
+    add v31.4s, v31.4s, v14.4s
+AddZpSumEnd:
+
+Convert2Float:
+    scvtf v16.4s, v16.4s
+    scvtf v17.4s, v17.4s
+    scvtf v18.4s, v18.4s
+    scvtf v19.4s, v19.4s
+    scvtf v20.4s, v20.4s
+    scvtf v21.4s, v21.4s
+    scvtf v22.4s, v22.4s
+    scvtf v23.4s, v23.4s
+    scvtf v24.4s, v24.4s
+    scvtf v25.4s, v25.4s
+    scvtf v26.4s, v26.4s
+    scvtf v27.4s, v27.4s
+    scvtf v28.4s, v28.4s
+    scvtf v29.4s, v29.4s
+    scvtf v30.4s, v30.4s
+    scvtf v31.4s, v31.4s
+
+MultiplyScale:
+    // multi_scale * input_matrix
+    cbz x22, TensorXTensor
+    cmp x22, #1
+    beq TensorXChannel
+    cmp x22, #2
+    beq ChannelXTensor
+    ChannelXChannel:
+        ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4], #64
+        ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64
+
+        fmul v16.4s,v16.4s,v0.4s
+        fmul v17.4s,v17.4s,v1.4s
+        fmul v18.4s,v18.4s,v2.4s
+        fmul v19.4s,v19.4s,v3.4s
+
+        ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x4], #64
+        ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x4]
+
+        fmul v20.4s,v20.4s,v4.4s
+        fmul v21.4s,v21.4s,v5.4s
+        fmul v22.4s,v22.4s,v6.4s
+        fmul v23.4s,v23.4s,v7.4s
+
+        fmul v24.4s,v24.4s,v8.4s
+        fmul v25.4s,v25.4s,v9.4s
+        fmul v26.4s,v26.4s,v10.4s
+        fmul v27.4s,v27.4s,v11.4s
+
+        fmul v28.4s,v28.4s,v12.4s
+        fmul v29.4s,v29.4s,v13.4s
+        fmul v30.4s,v30.4s,v14.4s
+        fmul v31.4s,v31.4s,v15.4s
+        b ConvertHalfPrecision
+
+    TensorXTensor:
+        ld1 {v0.s}[0], [x4]
+
+        fmul v16.4s,v16.4s,v0.s[0]
+        fmul v17.4s,v17.4s,v0.s[0]
+        fmul v18.4s,v18.4s,v0.s[0]
+        fmul v19.4s,v19.4s,v0.s[0]
+
+        fmul v20.4s,v20.4s,v0.s[0]
+        fmul v21.4s,v21.4s,v0.s[0]
+        fmul v22.4s,v22.4s,v0.s[0]
+        fmul v23.4s,v23.4s,v0.s[0]
+
+        fmul v24.4s,v24.4s,v0.s[0]
+        fmul v25.4s,v25.4s,v0.s[0]
+        fmul v26.4s,v26.4s,v0.s[0]
+        fmul v27.4s,v27.4s,v0.s[0]
+
+        fmul v28.4s,v28.4s,v0.s[0]
+        fmul v29.4s,v29.4s,v0.s[0]
+        fmul v30.4s,v30.4s,v0.s[0]
+        fmul v31.4s,v31.4s,v0.s[0]
+        b ConvertHalfPrecision
+
+    TensorXChannel:
+        ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4]
+
+        fmul v16.4s,v16.4s,v0.4s
+        fmul v17.4s,v17.4s,v1.4s
+        fmul v18.4s,v18.4s,v2.4s
+        fmul v19.4s,v19.4s,v3.4s
+
+        fmul v20.4s,v20.4s,v0.4s
+        fmul v21.4s,v21.4s,v1.4s
+        fmul v22.4s,v22.4s,v2.4s
+        fmul v23.4s,v23.4s,v3.4s
+
+        fmul v24.4s,v24.4s,v0.4s
+        fmul v25.4s,v25.4s,v1.4s
+        fmul v26.4s,v26.4s,v2.4s
+        fmul v27.4s,v27.4s,v3.4s
+
+        fmul v28.4s,v28.4s,v0.4s
+        fmul v29.4s,v29.4s,v1.4s
+        fmul v30.4s,v30.4s,v2.4s
+        fmul v31.4s,v31.4s,v3.4s
+        b ConvertHalfPrecision
+
+    ChannelXTensor:
+        ld1 {v0.4s}, [x4]
+        fmul v16.4s,v16.4s,v0.s[0]
+        fmul v17.4s,v17.4s,v0.s[0]
+        fmul v18.4s,v18.4s,v0.s[0]
+        fmul v19.4s,v19.4s,v0.s[0]
+
+        fmul v20.4s,v20.4s,v0.s[1]
+        fmul v21.4s,v21.4s,v0.s[1]
+        fmul v22.4s,v22.4s,v0.s[1]
+        fmul v23.4s,v23.4s,v0.s[1]
+
+        fmul v24.4s,v24.4s,v0.s[2]
+        fmul v25.4s,v25.4s,v0.s[2]
+        fmul v26.4s,v26.4s,v0.s[2]
+        fmul v27.4s,v27.4s,v0.s[2]
+
+        fmul v28.4s,v28.4s,v0.s[3]
+        fmul v29.4s,v29.4s,v0.s[3]
+        fmul v30.4s,v30.4s,v0.s[3]
+        fmul v31.4s,v31.4s,v0.s[3]
+
+ConvertHalfPrecision:
+    // convert from single precision to half precision
+    fcvtn v16.4h,v16.4s
+    fcvtn v17.4h,v17.4s
+    fcvtn v18.4h,v18.4s
+    fcvtn v19.4h,v19.4s
+
+    fcvtn v20.4h,v20.4s
+    fcvtn v21.4h,v21.4s
+    fcvtn v22.4h,v22.4s
+    fcvtn v23.4h,v23.4s
+
+    fcvtn v24.4h,v24.4s
+    fcvtn v25.4h,v25.4s
+    fcvtn v26.4h,v26.4s
+    fcvtn v27.4h,v27.4s
+
+    fcvtn v28.4h,v28.4s
+    fcvtn v29.4h,v29.4s
+    fcvtn v30.4h,v30.4s
+    fcvtn v31.4h,v31.4s
+
+AddBias:
+    // +bias
+    cbz x5, StoreData
+    ld1 {v1.4h, v2.4h, v3.4h, v4.4h}, [x5]
+
+    fadd v16.4h,v16.4h,v1.4h
+    fadd v17.4h,v17.4h,v2.4h
+    fadd v18.4h,v18.4h,v3.4h
+    fadd v19.4h,v19.4h,v4.4h
+
+    fadd v20.4h,v20.4h,v1.4h
+    fadd v21.4h,v21.4h,v2.4h
+    fadd v22.4h,v22.4h,v3.4h
+    fadd v23.4h,v23.4h,v4.4h
+
+    fadd v24.4h,v24.4h,v1.4h
+    fadd v25.4h,v25.4h,v2.4h
+    fadd v26.4h,v26.4h,v3.4h
+    fadd v27.4h,v27.4h,v4.4h
+
+    fadd v28.4h,v28.4h,v1.4h
+    fadd v29.4h,v29.4h,v2.4h
+    fadd v30.4h,v30.4h,v3.4h
+    fadd v31.4h,v31.4h,v4.4h
+
+Activate:
+    cmp x21, #1
+    beq Relu
+    cmp x21, #3
+    beq Relu6
+    b StoreData
+
+Relu:
+    dup v1.4h, wzr
+
+    smax v16.4h,v16.4h,v1.4h
+    smax v17.4h,v17.4h,v1.4h
+    smax v18.4h,v18.4h,v1.4h
+    smax v19.4h,v19.4h,v1.4h
+
+    smax v20.4h,v20.4h,v1.4h
+    smax v21.4h,v21.4h,v1.4h
+    smax v22.4h,v22.4h,v1.4h
+    smax v23.4h,v23.4h,v1.4h
+
+    smax v24.4h,v24.4h,v1.4h
+    smax v25.4h,v25.4h,v1.4h
+    smax v26.4h,v26.4h,v1.4h
+    smax v27.4h,v27.4h,v1.4h
+
+    smax v28.4h,v28.4h,v1.4h
+    smax v29.4h,v29.4h,v1.4h
+    smax v30.4h,v30.4h,v1.4h
+    smax v31.4h,v31.4h,v1.4h
+
+    b StoreData
+
+Relu6:
+    dup v1.4h, wzr
+    movi v2.4h, #6
+    scvtf v2.4h, v2.4h
+
+    // max (out, 0)
+    smax v16.4h,v16.4h,v1.4h
+    smax v17.4h,v17.4h,v1.4h
+    smax v18.4h,v18.4h,v1.4h
+    smax v19.4h,v19.4h,v1.4h
+
+    smax v20.4h,v20.4h,v1.4h
+    smax v21.4h,v21.4h,v1.4h
+    smax v22.4h,v22.4h,v1.4h
+    smax v23.4h,v23.4h,v1.4h
+
+    smax v24.4h,v24.4h,v1.4h
+    smax v25.4h,v25.4h,v1.4h
+    smax v26.4h,v26.4h,v1.4h
+    smax v27.4h,v27.4h,v1.4h
+
+    smax v28.4h,v28.4h,v1.4h
+    smax v29.4h,v29.4h,v1.4h
+    smax v30.4h,v30.4h,v1.4h
+    smax v31.4h,v31.4h,v1.4h
+
+    // min (out, 6)
+
+    smin v16.4h,v16.4h,v2.4h
+    smin v17.4h,v17.4h,v2.4h
+    smin v18.4h,v18.4h,v2.4h
+    smin v19.4h,v19.4h,v2.4h
+
+    smin v20.4h,v20.4h,v2.4h
+    smin v21.4h,v21.4h,v2.4h
+    smin v22.4h,v22.4h,v2.4h
+    smin v23.4h,v23.4h,v2.4h
+
+    smin v24.4h,v24.4h,v2.4h
+    smin v25.4h,v25.4h,v2.4h
+    smin v26.4h,v26.4h,v2.4h
+    smin v27.4h,v27.4h,v2.4h
+
+    smin v28.4h,v28.4h,v2.4h
+    smin v29.4h,v29.4h,v2.4h
+    smin v30.4h,v30.4h,v2.4h
+    smin v31.4h,v31.4h,v2.4h
+
+    b StoreData
+
+StoreData:
+    cmp x7, #16
+    beq Write16
+
+    mov x15, x2 // reload out ptr
+    add x14, x15, x8
+    add x13, x14, x8
+    add x12, x13, x8
+
+    cmp x7, #15
+    beq Write15
+    cmp x7, #14
+    beq Write14
+    cmp x7, #13
+    beq Write13
+    cmp x7, #12
+    beq Write12
+    cmp x7, #11
+    beq Write11
+    cmp x7, #10
+    beq Write10
+    cmp x7, #9
+    beq Write9
+    cmp x7, #8
+    beq Write8
+    cmp x7, #7
+    beq Write7
+    cmp x7, #6
+    beq Write6
+    cmp x7, #5
+    beq Write5
+    cmp x7, #4
+    beq Write4
+    cmp x7, #3
+    beq Write3
+    cmp x7, #2
+    beq Write2
+    cmp x7, #1
+    beq Write1
+    b StoreDataEnd
+
+Write16:
+    cmp x6, #4
+    beq Write16Row4
+    cmp x6, #3
+    beq Write16Row3
+    cmp x6, #2
+    beq Write16Row2
+    cmp x6, #1
+    beq Write16Row1
+
+    Write16Row4:
+        st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8
+        st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2], x8
+        st1 {v24.4h,v25.4h,v26.4h,v27.4h}, [x2], x8
+        st1 {v28.4h,v29.4h,v30.4h,v31.4h}, [x2]
+        b StoreDataEnd
+    Write16Row3:
+        st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8
+        st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2], x8
+        st1 {v24.4h,v25.4h,v26.4h,v27.4h}, [x2]
+        b StoreDataEnd
+    Write16Row2:
+        st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8
+        st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2]
+        b StoreDataEnd
+    Write16Row1:
+        st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2]
+        b StoreDataEnd
+
+Write15:
+    st1 {v16.4h,v17.4h,v18.4h}, [x15], #24
+    st1 {v19.s}[0], [x15], #4
+    st1 {v19.h}[2], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h,v21.4h,v22.4h}, [x14], #24
+    st1 {v23.s}[0], [x14], #4
+    st1 {v23.h}[2], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h,v25.4h,v26.4h}, [x13], #24
+    st1 {v27.s}[0], [x13], #4
+    st1 {v27.h}[2], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h,v29.4h,v30.4h}, [x12], #24
+    st1 {v31.s}[0], [x12], #4
+    st1 {v31.h}[2], [x12]
+    b StoreDataEnd
+
+Write14:
+    st1 {v16.4h,v17.4h,v18.4h}, [x15], #24
+    st1 {v19.s}[0], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h,v21.4h,v22.4h}, [x14], #24
+    st1 {v23.s}[0], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h,v25.4h,v26.4h}, [x13], #24
+    st1 {v27.s}[0], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h,v29.4h,v30.4h}, [x12], #24
+    st1 {v31.s}[0], [x12]
+    b StoreDataEnd
+
+Write13:
+    st1 {v16.4h,v17.4h,v18.4h}, [x15], #24
+    st1 {v19.h}[0], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h,v21.4h,v22.4h}, [x14], #24
+    st1 {v23.h}[0], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h,v25.4h,v26.4h}, [x13], #24
+    st1 {v27.h}[0], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h,v29.4h,v30.4h}, [x12], #24
+    st1 {v31.h}[0], [x12]
+    b StoreDataEnd
+
+Write12:
+    st1 {v16.4h,v17.4h,v18.4h}, [x15], #24
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h,v21.4h,v22.4h}, [x14], #24
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h,v25.4h,v26.4h}, [x13], #24
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h,v29.4h,v30.4h}, [x12], #24
+    b StoreDataEnd
+
+Write11:
+    st1 {v16.4h,v17.4h}, [x15], #16
+    st1 {v18.s}[0], [x15], #4
+    st1 {v18.h}[2], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h,v21.4h}, [x14], #16
+    st1 {v22.s}[0], [x14], #4
+    st1 {v22.h}[2], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h,v25.4h}, [x13], #16
+    st1 {v26.s}[0], [x13], #4
+    st1 {v26.h}[2], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h,v29.4h}, [x12], #16
+    st1 {v30.s}[0], [x12], #4
+    st1 {v30.h}[2], [x12]
+    b StoreDataEnd
+
+Write10:
+    st1 {v16.4h,v17.4h}, [x15], #16
+    st1 {v18.s}[0], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h,v21.4h}, [x14], #16
+    st1 {v22.s}[0], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h,v25.4h}, [x13], #16
+    st1 {v26.s}[0], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h,v29.4h}, [x12], #16
+    st1 {v30.s}[0], [x12]
+    b StoreDataEnd
+
+Write9:
+    st1 {v16.4h,v17.4h}, [x15], #16
+    st1 {v18.h}[0], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h,v21.4h}, [x14], #16
+    st1 {v22.h}[0], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h,v25.4h}, [x13], #16
+    st1 {v26.h}[0], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h,v29.4h}, [x12], #16
+    st1 {v30.h}[0], [x12]
+    b StoreDataEnd
+
+Write8:
+    st1 {v16.4h,v17.4h}, [x15], #16
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h,v21.4h}, [x14], #16
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h,v25.4h}, [x13], #16
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h,v29.4h}, [x12], #16
+    b StoreDataEnd
+
+Write7:
+    st1 {v16.4h}, [x15], #8
+    st1 {v17.s}[0], [x15], #4
+    st1 {v17.h}[2], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h}, [x14], #8
+    st1 {v21.s}[0], [x14], #4
+    st1 {v21.h}[2], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h}, [x13], #8
+    st1 {v25.s}[0], [x13], #4
+    st1 {v25.h}[2], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h}, [x12], #8
+    st1 {v29.s}[0], [x12], #4
+    st1 {v29.h}[2], [x12]
+    b StoreDataEnd
+
+Write6:
+    st1 {v16.4h}, [x15], #8
+    st1 {v17.s}[0], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h}, [x14], #8
+    st1 {v21.s}[0], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h}, [x13], #8
+    st1 {v25.s}[0], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h}, [x12], #8
+    st1 {v29.s}[0], [x12]
+    b StoreDataEnd
+
+Write5:
+    st1 {v16.4h}, [x15], #8
+    st1 {v17.h}[0], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h}, [x14], #8
+    st1 {v21.h}[0], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h}, [x13], #8
+    st1 {v25.h}[0], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h}, [x12], #8
+    st1 {v29.h}[0], [x12]
+    b StoreDataEnd
+
+Write4:
+    st1 {v16.4h}, [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.4h}, [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.4h}, [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.4h}, [x12]
+    b StoreDataEnd
+
+Write3:
+    st1 {v16.s}[0], [x15], #4
+    st1 {v16.h}[2], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.s}[0], [x14], #4
+    st1 {v20.h}[2], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.s}[0], [x13], #4
+    st1 {v24.h}[2], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.s}[0], [x12], #4
+    st1 {v28.h}[2], [x12]
+    b StoreDataEnd
+
+Write2:
+    st1 {v16.s}[0], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.s}[0], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.s}[0], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.s}[0], [x12]
+    b StoreDataEnd
+
+Write1:
+    st1 {v16.h}[0], [x15]
+    cmp x6, #1
+    beq StoreDataEnd
+    st1 {v20.h}[0], [x14]
+    cmp x6, #2
+    beq StoreDataEnd
+    st1 {v24.h}[0], [x13]
+    cmp x6, #3
+    beq StoreDataEnd
+    st1 {v28.h}[0], [x12]
+    b StoreDataEnd
+StoreDataEnd:
+    sub sp, sp, #160
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ret
+#endif
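
The fp16 entry point follows the same structure as the fp32 kernel but narrows the accumulators to half precision (fcvtn) before the bias add and activation, so the bias must be fp16 while the scales stay fp32 (per the C declarations later in this patch). A hedged call sketch under the same packing assumptions as the fp32 example above, using the toolchain's float16_t:

#ifdef ENABLE_FP16
// Sketch only; relative to the fp32 call, just the output and bias types change.
void ExampleCallFp16(const int8_t *packed_a, const int8_t *packed_b,
                     float16_t *out, const float *multi_scales,
                     const float16_t *bias, const int32_t *a_sums,
                     const int32_t *b_sums) {
  DynamicMatmulSdot4x4x16AIWIForFp16(packed_a, packed_b, out, /*deep4=*/16,
                                     multi_scales, bias, /*row=*/4, /*col=*/16,
                                     /*stride=*/16 * sizeof(float16_t), a_sums,
                                     b_sums, /*a_zp=*/0, /*b_zp_sum=*/0,
                                     /*act_type=*/0, /*mode=*/1);
}
#endif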
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h
index 627b9ee6..dfc05f28 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h
@@ -22,6 +22,9 @@ typedef struct DynamicQuantParameter {
   OpParameter op_parameter_;
   bool symmetric_;
   int64_t dst_type_;
+  bool activation_perchannel_;
+  int64_t prefer_axis_;
+  bool transpose_;
 } DynamicQuantParameter;

 #endif  // MINDSPORE_NNACL_DYNAMIC_QUANT_PARAMETER_H_
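
The three new fields drive the per-channel dynamic-quant path added later in this patch. A hypothetical initialization sketch (the field values are illustrative, not taken from the patch):

#include <string.h>

DynamicQuantParameter param;
memset(&param, 0, sizeof(param));
param.symmetric_ = true;              // symmetric int8 range
param.activation_perchannel_ = true;  // track min/max per channel, not per tensor
param.prefer_axis_ = 0;               // axis used to bucket elements into channels
param.transpose_ = false;             // set when the tensor layout is transposed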
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.c
index 0bfa6475..a09a4359 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.c
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022 Huawei Technologies Co., Ltd
+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,15 +17,15 @@
 #include "nnacl/int8/dynamic_matmul_int8.h"
 #include "nnacl/int8/fixed_point.h"

-void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales,
-                             float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums,
-                             const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum) {
+void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales,
+                             const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums,
+                             const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode) {
   /* *
    * row4x4-major * row4x16-major => (int8)row-major
    * support activation per-layer symmetric && weight per-layer/per-channel symmetric
    * */
   for (int r = 0; r < row; r++) {
-    int64_t s2 = a_sums[r] * b_zp_sum;
+    int64_t s2 = a_sums[r];
     for (int c = 0; c < col; c++) {
       int r4div = r / C4NUM, r4mod = r % C4NUM;
       int c16div = c / C16NUM, c16mod = c % C16NUM;
@@ -39,18 +39,67 @@ void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_
       int64_t s3 = b_sums[c] * a_zp;
       int64_t s4 = a_zp * b_zp_sum;
       size_t ci = r * stride / sizeof(float) + c;
-      out[ci] = multi_scales[c] * (s1 - s2 - s3 + s4);
+      int scale_offset = mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? r : r * C16NUM + c));
+      out[ci] = multi_scales[scale_offset] * (s1 - s2 - s3 + s4);
       if (bias != NULL) {
         out[ci] += bias[c];
       }
+      if (act_type == ActType_Relu) {
+        out[ci] = MSMAX(0, out[ci]);
+      } else if (act_type == ActType_Relu6) {
+        out[ci] = MSMAX(0, out[ci]);
+        out[ci] = MSMIN(C6NUM, out[ci]);
+      }
     }
   }
   return;
 }

+#ifdef ENABLE_FP16
+void DynamicMatmul4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4,
+                                    const float *multi_scales, const float16_t *bias, size_t row, size_t col,
+                                    size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp,
+                                    int64_t b_zp_sum, int64_t act_type, int64_t mode) {
+  /* *
+   * row4x4-major * row4x16-major => (int8)row-major
+   * support activation per-layer symmetric && weight per-layer/per-channel symmetric
+   * */
+  for (int r = 0; r < row; r++) {
+    int64_t s2 = a_sums[r];
+    for (int c = 0; c < col; c++) {
+      int r4div = r / C4NUM, r4mod = r % C4NUM;
+      int c16div = c / C16NUM, c16mod = c % C16NUM;
+      int32_t s1 = 0;
+      for (int d = 0; d < deep4; d++) {
+        int d4div = d / C4NUM, d4mod = d % C4NUM;
+        size_t ai = r4div * deep4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod;
+        size_t bi = c16div * deep4 * C16NUM + d4div * C4NUM * C16NUM + c16mod * C4NUM + d4mod;
+        s1 += a[ai] * b[bi];
+      }
+      int64_t s3 = b_sums[c] * a_zp;
+      int64_t s4 = a_zp * b_zp_sum;
+      size_t ci = r * stride / sizeof(float16_t) + c;
+      int scale_offset = mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? r : r * C16NUM + c));
+      out[ci] = multi_scales[scale_offset] * (s1 - s2 - s3 + s4);
+      if (bias != NULL) {
+        out[ci] += bias[c];
+      }
+      if (act_type == ActType_Relu) {
+        out[ci] = MSMAX(0, out[ci]);
+      } else if (act_type == ActType_Relu6) {
+        out[ci] = MSMAX(0, out[ci]);
+        out[ci] = MSMIN(C6NUM, out[ci]);
+      }
+    }
+  }
+  return;
+}
+#endif
+
 void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
-                             int deep, int deep16, size_t stride, int input_zp, float input_scale,
-                             const float *filter_scale, const int filter_zp, bool filter_per_channel) {
+                             int deep, int deep16, size_t stride, int input_zp, const float *input_scale,
+                             const float *filter_scale, int filter_zp, bool input_per_channel, bool filter_per_channel,
+                             int64_t act_type) {
   /* *
    * row4x16-major * row16x4-major => (int8)row-major
    * support activation per-layer symmetric && weight per-layer/per-channel symmetric
@@ -74,13 +123,20 @@ void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias
         s3 += input_zp * filter_zp;
       }
       value = s0 - s1 - s2 + s3;
+      int input_quant_index = input_per_channel ? r : 0;
       int filter_quant_index = filter_per_channel ? c : 0;
-      float multi_scale = input_scale * filter_scale[filter_quant_index];
+      float multi_scale = input_scale[input_quant_index] * filter_scale[filter_quant_index];
       size_t ci = r * stride + c;
       dst[ci] = multi_scale * value;
       if (bias != NULL) {
         dst[ci] += bias[c];
       }
+      if (act_type == ActType_Relu) {
+        dst[ci] = MSMAX(0, dst[ci]);
+      } else if (act_type == ActType_Relu6) {
+        dst[ci] = MSMAX(0, dst[ci]);
+        dst[ci] = MSMIN(C6NUM, dst[ci]);
+      }
     }
   }
   return;
@@ -166,8 +222,8 @@ void PackInput4x4Asm(const int8_t *src_ic, int8_t *pack_ic, size_t ic_4div, size
     "6: \n"

     :
-    : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div),
-      [ ic_4res ] "r"(ic_4res)
+    : [src_ic] "r"(src_ic), [pack_ic] "r"(pack_ic), [src_stride] "r"(src_stride), [ic_4div] "r"(ic_4div),
+      [ic_4res] "r"(ic_4res)
     : "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3");
 }
 #endif
@@ -276,7 +332,7 @@ void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, in
       "1:\n"

       :
-      : [ src_ic ] "r"(src_ic), [ packed_ic ] "r"(packed_ic), [ row ] "r"(row_div), [ row_stride ] "r"(row_stride_int64)
+      : [src_ic] "r"(src_ic), [packed_ic] "r"(packed_ic), [row] "r"(row_div), [row_stride] "r"(row_stride_int64)
       : "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12");
     packed_ic += C4NUM * row_div;
     src_ic += row_div * row_stride;
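
Both reference kernels above select a scale with the same mode-dependent offset. An illustrative standalone restatement (using the constants' values, C2NUM == 2 and C16NUM == 16); multi_scales must therefore hold 1, col, row, or row * 16 entries for modes 0 through 3 respectively:

static inline int ScaleOffset(int64_t mode, int r, int c) {
  if (mode == 0) return 0;  // TensorByTensor: a single shared scale
  if (mode == 1) return c;  // TensorByChannel: one scale per output column
  if (mode == 2) return r;  // ChannelByTensor: one scale per output row
  return r * 16 + c;        // ChannelByChannel: row-major row x 16 grid
}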
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.h
index ef835898..77e861bb 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022 Huawei Technologies Co., Ltd
+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,18 +27,46 @@ extern "C" {
 void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride);
 void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size);
 void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
-                             int deep, int deep16, size_t stride, int input_zp, float input_scale,
-                             const float *filter_scale, const int filter_zp, bool filter_per_channel);
+                             int deep, int deep16, size_t stride, int input_zp, const float *input_scale,
+                             const float *filter_scale, int filter_zp, bool input_per_channel, bool filter_per_channel,
+                             int64_t act_type);
 void CalcWeightSums(const int8_t *weight, int row, int col, int32_t *dst, DataOrder order);
 void CalcPartWeightSums(const int8_t *weight, int row, int stride, int cur_col, int32_t *dst, DataOrder order);
-#ifdef ENABLE_ARM64
-void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales,
-                                 float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums,
-                                 const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum);
+#if defined(ENABLE_ARM64) && !defined(USE_AOS_GCC_TOOLCHAIN)
+/*
+ * mode distinguishes the quantization scenario; valid values are 0-3:
+ * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel.
+ */
+void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales,
+                                 const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums,
+                                 const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode);
+#endif
+/*
+ * mode distinguishes the quantization scenario; valid values are 0-3:
+ * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel.
+ */
+void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales,
+                             const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums,
+                             const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode);
+#ifdef ENABLE_FP16
+/*
+ * mode distinguishes the quantization scenario; valid values are 0-3:
+ * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel.
+ */
+void DynamicMatmul4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4,
+                                    const float *multi_scales, const float16_t *bias, size_t row, size_t col,
+                                    size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp,
+                                    int64_t b_zp_sum, int64_t act_type, int64_t mode);
+/*
+ * mode distinguishes the quantization scenario; valid values are 0-3:
+ * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel.
+ */
+void DynamicMatmulSdot4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4,
+                                        const float *multi_scales, const float16_t *bias, size_t row, size_t col,
+                                        size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp,
+                                        int64_t b_zp_sum, int64_t act_type, int64_t mode);
 #endif
-void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales,
-                             float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums,
-                             const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum);
+
 #ifdef __cplusplus
 }
 #endif
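
In the reference kernel, either operand may now carry per-channel scales; the effective scale is the product of a row-indexed activation scale and a column-indexed weight scale. An illustrative restatement of that lookup:

static inline float CombinedScale(const float *input_scale, const float *filter_scale,
                                  bool input_per_channel, bool filter_per_channel,
                                  int r, int c) {
  int input_idx = input_per_channel ? r : 0;    // activation scale per output row
  int filter_idx = filter_per_channel ? c : 0;  // weight scale per output column
  return input_scale[input_idx] * filter_scale[filter_idx];
}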
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.c
index bca1cbca..4ec4ebb8 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.c
@@ -16,6 +16,9 @@

 #include "nnacl/int8/dynamic_quant_int8.h"
 void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max) {
+  if (count == 0) {
+    return;
+  }
 #ifndef ENABLE_ARM64
   for (int i = 0; i < count; ++i) {
     if (data[i] < *real_min) {
@@ -26,7 +29,7 @@ void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *r
     }
   }
 #else
-  // avoid to compile optimize.
+  // avoid compiler optimization.
   volatile int count_4 = DOWN_ROUND(count, C4NUM);
   asm volatile(
     "mov x4, %[data]\n"          // reload data
@@ -63,3 +66,22 @@ void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *r
   }
 #endif
 }
+
+void CalculateAllChannelMinMax(const float *data, int count, float *real_min, float *real_max, int channel_length) {
+  int channel_total = count / channel_length;
+  for (int i = 0; i < channel_total; i++) {
+    CalculateMinMaxFp32(data + i * channel_length, channel_length, real_min + i, real_max + i);
+  }
+}
+
+int GetBucketIndex(int dims[], int dim_size, int prefer_axis, int data_index) {
+  int stride = 1;
+  int bucket_count = dims[prefer_axis];
+  for (int i = prefer_axis + 1; i < dim_size; i++) {
+    stride *= dims[i];
+  }
+  if (stride == 0 || bucket_count == 0) {
+    return 0;
+  }
+  return (data_index / stride) % bucket_count;
+}
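
GetBucketIndex maps a flat element index to its quantization bucket along prefer_axis, and CalculateAllChannelMinMax simply runs the existing min/max scan once per channel_length-sized slice, producing one (min, max) pair per channel. A worked example for the bucket lookup with a made-up shape: for a [2, 3, 4] tensor bucketed along axis 1, the stride past the axis is 4 and bucket_count is 3, so

int dims[] = {2, 3, 4};
// flat indices 0-3 fall in bucket 0, 4-7 in bucket 1, 8-11 in bucket 2,
// then 12-15 wrap back to bucket 0 in the next slice along axis 0.
int bucket = GetBucketIndex(dims, 3, /*prefer_axis=*/1, /*data_index=*/13);  // (13 / 4) % 3 == 0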
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.h
index d4a63518..8fa0a9ed 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.h
@@ -25,6 +25,8 @@
 extern "C" {
 #endif
 void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max);
+void CalculateAllChannelMinMax(const float *data, int count, float *real_min, float *real_max, int channel_length);
+int GetBucketIndex(int dims[], int dim_size, int prefer_axis, int data_index);
 #ifdef __cplusplus
 }
 #endif
1519diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.c
1520index 25050fda..753aa5dd 100644
1521--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.c
1522+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.c
1523@@ -202,6 +202,140 @@ int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float s
1524   return NNACL_OK;
1525 }
1526
1527+#ifdef ENABLE_ARM64
1528+inline void Fp32ToInt8Perchannel_arm64(const float *real_values, int8_t *quant_values, float *scales, int32_t *zps,
1529+																			 int size, int channel_length, int32_t min_value, int32_t max_value) {
1530+	volatile float ivs[size];
1531+	for (int i = 0; i < size; i++) {
1532+		volatile int channel_index = i / channel_length;
1533+		ivs[i] = 1.0f / scales[channel_index];
1534+	}
1535+	volatile int32_t zp = zps[0];
1536+
1537+  asm volatile(
1538+    "mov w8, %w[size]\n"
1539+    "cmp w8, #0\n"
1540+    "beq 2f\n"
1541+
1542+		"mov x4, %[ivs]\n"           // reload ivs
1543+    "dup v13.4s, %w[min_value]\n"
1544+    "dup v14.4s, %w[max_value]\n"
1545+    "cmp w8, #16\n"
1546+    "blt 1f\n"
1547+    "0:\n"
1548+    "subs w8, w8, #16\n"
1549+    "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[real_values]], #64\n"
1550+    "dup v8.4s, %w[zp]\n"
1551+    "dup v9.4s, %w[zp]\n"
1552+    "dup v10.4s, %w[zp]\n"
1553+    "dup v11.4s, %w[zp]\n"
1554+    "scvtf v4.4s, v8.4s\n"
1555+    "scvtf v5.4s, v9.4s\n"
1556+    "scvtf v6.4s, v10.4s\n"
1557+    "scvtf v7.4s, v11.4s\n"
1558+		"ld1 {v12.4s}, [x4], #16\n"
1559+    "fmla v4.4s, v0.4s, v12.4s\n"
1560+		"ld1 {v12.4s}, [x4], #16\n"
1561+    "fmla v5.4s, v1.4s, v12.4s\n"
1562+		"ld1 {v12.4s}, [x4], #16\n"
1563+    "fmla v6.4s, v2.4s, v12.4s\n"
1564+		"ld1 {v12.4s}, [x4], #16\n"
1565+    "fmla v7.4s, v3.4s, v12.4s\n"
1566+
1567+    "fcvtas v0.4s, v4.4s\n"
1568+    "fcvtas v1.4s, v5.4s\n"
1569+    "fcvtas v2.4s, v6.4s\n"
1570+    "fcvtas v3.4s, v7.4s\n"
1571+    "smax v0.4s, v0.4s, v13.4s\n"
1572+    "smax v1.4s, v1.4s, v13.4s\n"
1573+    "smax v2.4s, v2.4s, v13.4s\n"
1574+    "smax v3.4s, v3.4s, v13.4s\n"
1575+    "smin v0.4s, v0.4s, v14.4s\n"
1576+    "smin v1.4s, v1.4s, v14.4s\n"
1577+    "smin v2.4s, v2.4s, v14.4s\n"
1578+    "smin v3.4s, v3.4s, v14.4s\n"
1579+
1580+    "sqxtn v4.4h, v0.4s\n"
1581+    "sqxtn2 v4.8h, v1.4s\n"
1582+    "sqxtn v5.4h, v2.4s\n"
1583+    "sqxtn2 v5.8h, v3.4s\n"
1584+    "sqxtn v6.8b, v4.8h\n"
1585+    "sqxtn2 v6.16b, v5.8h\n"
1586+    "st1 {v6.16b}, [%[quant_values]], #16\n"
1587+
1588+    "beq 2f\n"
1589+    "cmp w8, #16\n"
1590+    "bge 0b\n"
1591+
1592+    "1:\n"
1593+    "scvtf s0, %w[zp]\n"
1594+    "subs w8, w8, #1\n"
1595+    "ldr s4, [%[real_values]], #4\n"
+    "ldr s12, [x4], #4\n"  // per-channel: fetch this element's inverse scale (s12 is stale or unset otherwise)
1596+    "fmul s4, s4, s12\n"
1597+    "fadd s0, s0, s4\n"
1598+    "fcvtas s0, s0\n"
1599+    "smax v0.4s, v0.4s, v13.4s\n"
1600+    "smin v0.4s, v0.4s, v14.4s\n"
1601+    "sqxtn v1.4h, v0.4s\n"
1602+    "sqxtn v0.8b, v1.8h\n"
1603+    "st1 {v0.b}[0], [%[quant_values]], #1\n"
1604+
1605+    "bne 1b\n"
1606+
1607+    "2:\n"
1608+    :
1609+    : [ quant_values ] "r"(quant_values), [ real_values ] "r"(real_values), [ scales ] "r"(scales), [ zp ] "r"(zp),
1610+      [ size ] "r"(size), [ channel_length ] "r"(channel_length), [ ivs ] "r"(ivs), [ min_value ] "r"(min_value),
1611+      [ max_value ] "r"(max_value)
1612+    : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "x4");
1613+}
1614+#endif
1615+
1616+int DoPerchannelQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size,
1617+                                   int channel_length, int32_t min_value, int32_t max_value) {
1618+  if (quant_values == NULL || real_values == NULL || scale == NULL || zp == NULL) {
1619+    return NNACL_PARAM_INVALID;
1620+  }
1621+#ifdef ENABLE_ARM64
1622+  Fp32ToInt8Perchannel_arm64(real_values, quant_values, scale, zp, size, channel_length, min_value, max_value);
1623+#else
1624+  for (int i = 0; i < size; ++i) {
1625+    int channel_index = i / channel_length;
1626+    const float inverse_scale = 1.0f / scale[channel_index];
1627+    if (real_values[i] == INFINITY) {
1628+      quant_values[i] = max_value;
1629+    } else if (real_values[i] == -INFINITY) {
1630+      quant_values[i] = min_value;
1631+    } else {
1632+      int temp = round(real_values[i] * inverse_scale + zp[channel_index]);
1633+      temp = temp < max_value ? temp : max_value;
1634+      temp = temp > min_value ? temp : min_value;
1635+      quant_values[i] = (int8_t)temp;
1636+    }
1637+  }
1638+#endif
1639+  return NNACL_OK;
1640+}
1641+
1642+int QuantizeDataFp32ToInt8(const float real_value, int8_t *quant_value, float scale, int32_t zp, int32_t min_value,
1643+                           int32_t max_value) {
1644+  if (quant_value == NULL) {
1645+    return NNACL_PARAM_INVALID;
1646+  }
1647+  const float inverse_scale = 1.0f / scale;
1648+  if (real_value == INFINITY) {
1649+    *quant_value = max_value;
1650+  } else if (real_value == -INFINITY) {
1651+    *quant_value = min_value;
1652+  } else {
1653+    int temp = round(real_value * inverse_scale + zp);
1654+    temp = temp < max_value ? temp : max_value;
1655+    temp = temp > min_value ? temp : min_value;
1656+    *quant_value = (int8_t)temp;
1657+  }
1658+  return NNACL_OK;
1659+}
1660+
1661 int DoQuantizeFp32ToInt8FromUint8Source(const float *real_values, int8_t *quant_values, float scale, int32_t zp,
1662                                         int size, int32_t min_value, int32_t max_value) {
1663   if (quant_values == NULL || real_values == NULL) {
1664diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.h
1665index 251e9716..950b4287 100644
1666--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.h
1667+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.h
1668@@ -31,12 +31,18 @@ extern "C" {
1669 int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
1670 int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size,
1671                          int32_t min_value, int32_t max_value);
1672+int DoPerchannelQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size,
1673+                                   int channel_length, int32_t min_value, int32_t max_value);
1674+int QuantizeDataFp32ToInt8(const float real_value, int8_t *quant_value, float scale, int32_t zp, int32_t min_value,
1675+                           int32_t max_value);
1676 int DoQuantizeFp32ToInt8FromUint8Source(const float *real_values, int8_t *quant_values, float scale, int32_t zp,
1677                                         int size, int32_t min_value, int32_t max_value);
1678 #ifdef ENABLE_ARM64
1679 void Fp32ToInt8_arm64(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size,
1680                       int32_t min_value, int32_t max_value);
1681 void Int8ToFp32_arm64(const int8_t *quant_values, float *dst, float scale, int32_t zp, int size);
1682+void Fp32ToInt8Perchannel_arm64(const float *real_values, int8_t *quant_values, float *scales, int32_t *zps,
1683+                                int size, int channel_length, int32_t min_value, int32_t max_value);
1684 #endif
1685 int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
1686 int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size);
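A minimal caller for the new per-channel entry point, with invented two-channel data; channel_length = 4 means four consecutive floats share one scale/zero-point pair:

#include <cstdint>
#include "nnacl/int8/quant_dtype_cast_int8.h"

int QuantizeTwoChannels() {
  float real[8] = {-1.0f, -0.5f, 0.0f, 1.0f, -8.0f, -4.0f, 0.0f, 8.0f};
  float scale[2] = {1.0f / 127.0f, 8.0f / 127.0f};  // one scale per channel
  int32_t zp[2] = {0, 0};                           // symmetric quantization
  int8_t quant[8] = {0};
  return DoPerchannelQuantizeFp32ToInt8(real, quant, scale, zp, /*size=*/8, /*channel_length=*/4,
                                        INT8_MIN, INT8_MAX);
}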
1687diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/matmul_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/matmul_parameter.h
1688index 1f1913e1..8116ac58 100644
1689--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/matmul_parameter.h
1690+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/matmul_parameter.h
1691@@ -1,5 +1,5 @@
1692 /**
1693- * Copyright 2020 Huawei Technologies Co., Ltd
1694+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
1695  *
1696  * Licensed under the Apache License, Version 2.0 (the "License");
1697  * you may not use this file except in compliance with the License.
1698@@ -35,6 +35,16 @@ typedef void (*MATMUL_OPT_DP_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst
1699
1700 typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2, OutType_NC4HW4 = 3 } OutType;
1701
1702+typedef enum MatmulType {
1703+  // reserve 0 for base op
1704+  kNotImplemented = 0,
1705+  kMatmulInt8Cpu,
1706+  kMatmulDynamicInt8Cpu,
1707+  kMatmulDynamicSdotInt8Cpu,
1708+  kMatmulFp32BaseCpu,
1709+  kMatmulFp32Arm64Cpu,
1710+} MatmulType;
1711+
1712 typedef struct MatMulParameter {
1713   // Primitive parameter
1714   OpParameter op_parameter_;
1715@@ -63,6 +73,7 @@ typedef struct MatMulParameter {
1716   ActType act_type_;
1717   bool use_axis_;
1718   int axis_;
1719+  MatmulType matmul_type_;
1720 } MatMulParameter;
1721
1722 typedef struct MatmulQuantParameter {
1723@@ -79,8 +90,8 @@ typedef struct MatmulQuantParameter {
1724 } MatmulQuantParameter;
1725
1726 typedef struct MatmulDynamicQuantParameter {
1727-  float input_scale_;
1728-  int32_t input_zp_;
1729+  float *input_scale_;
1730+  int32_t *input_zp_;
1731   float *filter_scale_;
1732   int32_t *filter_zp_;
1733 } MatmulDynamicQuantParameter;
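With input_scale_ and input_zp_ now pointers, callers own the backing storage; a minimal sketch (values invented) of binding per-tensor parameters as length-1 arrays:

#include <vector>
#include "nnacl/matmul_parameter.h"

void BindInputQuantParams(MatmulDynamicQuantParameter *qp, std::vector<float> *scales,
                          std::vector<int32_t> *zps) {
  scales->assign(1, 0.05f);  // per-tensor case degenerates to arrays of length 1
  zps->assign(1, 0);
  qp->input_scale_ = scales->data();  // the vectors must outlive the kernel run
  qp->input_zp_ = zps->data();
}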
1734diff --git a/mindspore/core/ops/dynamic_quant.cc b/mindspore/core/ops/dynamic_quant.cc
1735index 77cadbc0..d11ee4ff 100644
1736--- a/mindspore/core/ops/dynamic_quant.cc
1737+++ b/mindspore/core/ops/dynamic_quant.cc
1738@@ -30,9 +30,28 @@ bool DynamicQuant::get_symmetric() const {
1739 }
1740 void DynamicQuant::set_dst_type(const int64_t dst_type) { (void)AddAttr(kDstType, api::MakeValue(dst_type)); }
1741 int64_t DynamicQuant::get_dst_type() const { return GetValue<int64_t>(GetAttr(kDstType)); }
1742+void DynamicQuant::set_prefer_axis(const int64_t prefer_axis) {
1743+  (void)AddAttr(kPreferAxis, api::MakeValue(prefer_axis));
1744+}
1745+int64_t DynamicQuant::get_prefer_axis() const { return GetValue<int64_t>(GetAttr(kPreferAxis)); }
1746+void DynamicQuant::set_activation_perchannel(const bool activation_perchannel) {
1747+  (void)AddAttr(kActivationPerchannel, api::MakeValue(activation_perchannel));
1748+}
1749+bool DynamicQuant::get_activation_perchannel() const {
1750+  auto value_ptr = this->GetAttr(kActivationPerchannel);
1751+  return GetValue<bool>(value_ptr);
1752+}
1753+void DynamicQuant::set_transpose(const bool transpose) { (void)AddAttr(kTrans, api::MakeValue(transpose)); }
1754+bool DynamicQuant::get_transpose() const {
1755+  auto value_ptr = this->GetAttr(kTrans);
1756+  return GetValue<bool>(value_ptr);
1757+}
1758 void DynamicQuant::Init(const bool symmetric, const int64_t dst_type) {
1759   this->set_symmetric(symmetric);
1760   this->set_dst_type(dst_type);
1761+  this->set_activation_perchannel(false);
1762+  this->set_prefer_axis(0);
1763+  this->set_transpose(false);
1764 }
1765
1766 REGISTER_PRIMITIVE_C(kNameDynamicQuant, DynamicQuant);
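A short sketch of configuring the three new attributes on the C++ operator; Init keeps the backward-compatible defaults set above:

#include "ops/dynamic_quant.h"

void ConfigurePerChannelQuant(mindspore::ops::DynamicQuant *op) {
  op->Init(/*symmetric=*/true, /*dst_type=*/32);  // also resets the new attrs to their defaults
  op->set_activation_perchannel(true);            // quantize activations per channel
  op->set_prefer_axis(-1);                        // negative axes are normalized by the kernel
  op->set_transpose(false);                       // matrix-a transpose is rejected at runtime
}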
1767diff --git a/mindspore/core/ops/dynamic_quant.h b/mindspore/core/ops/dynamic_quant.h
1768index ade36b4f..e7f1b7e6 100644
1769--- a/mindspore/core/ops/dynamic_quant.h
1770+++ b/mindspore/core/ops/dynamic_quant.h
1771@@ -61,6 +61,36 @@ class MIND_API DynamicQuant : public BaseOperator {
1772   ///
1773   /// \return the data type of output.
1774   int64_t get_dst_type() const;
1775+
1776+  /// \brief Method to set prefer_axis attribute.
1777+  ///
1778+  /// \param[in] prefer_axis Define the preferred axis.
1779+  void set_prefer_axis(const int64_t prefer_axis);
1780+
1781+  /// \brief Method to get prefer_axis attribute.
1782+  ///
1783+  /// \return the preferred axis.
1784+  int64_t get_prefer_axis() const;
1785+
1786+  /// \brief Method to set activation perchannel attribute.
1787+  ///
1788+  /// \param[in] activation_perchannel Define whether activations are quantized per channel.
1789+  void set_activation_perchannel(const bool activation_perchannel);
1790+
1791+  /// \brief Method to get activation perchannel attribute.
1792+  ///
1793+  /// \return Whether activations are quantized per channel.
1794+  bool get_activation_perchannel() const;
1795+
1796+  /// \brief Method to set transpose attribute.
1797+  ///
1798+  /// \param[in] transpose Define whether the matrix is transposed.
1799+  void set_transpose(const bool transpose);
1800+
1801+  /// \brief Method to get transpose attribute.
1802+  ///
1803+  /// \return Whether the matrix is transposed.
1804+  bool get_transpose() const;
1805 };
1806 abstract::AbstractBasePtr DynamicQuantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
1807                                             const std::vector<abstract::AbstractBasePtr> &input_args);
1808diff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h
1809index 7a509840..8dc2e3f4 100644
1810--- a/mindspore/core/ops/op_name.h
1811+++ b/mindspore/core/ops/op_name.h
1812@@ -22,6 +22,7 @@ namespace mindspore::ops {
1813 constexpr auto kAlpha = "alpha";
1814 constexpr auto kActivation = "activation";
1815 constexpr auto kActivationType = "activation_type";
1816+constexpr auto kActivationPerchannel = "activation_perchannel";
1817 constexpr auto kAttentionQActType = "attention_q_act_type";
1818 constexpr auto kAttentionKActType = "attention_k_act_type";
1819 constexpr auto kAttentionVActType = "attention_v_act_type";
1820@@ -178,6 +179,7 @@ constexpr auto kDivisorOverride = "divisor_override";
1821 constexpr auto kPostNmsTopn = "post_nms_topn";
1822 constexpr auto kPower = "power";
1823 constexpr auto kPreNmsTopn = "pre_nms_topn";
1824+constexpr auto kPreferAxis = "prefer_axis";
1825 constexpr auto kRankSize = "rank_size";
1826 constexpr auto kRatio = "ratio";
1827 constexpr auto kReduction = "reduction";
1828@@ -209,6 +211,7 @@ constexpr auto kSummarize = "summarize";
1829 constexpr auto kTimeMajor = "time_major";
1830 constexpr auto kTolerance = "tolerance";
1831 constexpr auto kTopK = "top_k";
1832+constexpr auto kTrans = "trans";
1833 constexpr auto kTransposeA = "transpose_a";
1834 constexpr auto kTransposeB = "transpose_b";
1835 constexpr auto kNegativeSlope = "negative_slope";
1836diff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn
1837index a4d77b1c..86b80a28 100644
1838--- a/mindspore/lite/BUILD.gn
1839+++ b/mindspore/lite/BUILD.gn
1840@@ -142,11 +142,13 @@ all_lite_sources = [
1841   "src/common/utils.cc",
1842   "src/common/graph_util.cc",
1843   "src/common/log.cc",
1844+  "src/common/mmap_utils.cc",
1845   "src/common/prim_util.cc",
1846   "src/common/tensor_util.cc",
1847   "src/runtime/allocator.cc",
1848   "src/runtime/inner_allocator.cc",
1849   "src/runtime/runtime_allocator.cc",
1850+  "src/runtime/runtime_packed_node_pass.cc",
1851   "src/runtime/infer_manager.cc",
1852   "src/runtime/runtime_shape_fusion_pass.cc",
1853   "src/runtime/runtime_pass.cc",
1854diff --git a/mindspore/lite/schema/inner/ops_generated.h b/mindspore/lite/schema/inner/ops_generated.h
1855index e0614168..86fdbad1 100644
1856--- a/mindspore/lite/schema/inner/ops_generated.h
1857+++ b/mindspore/lite/schema/inner/ops_generated.h
1858@@ -19484,6 +19484,9 @@ struct DynamicQuantT : public flatbuffers::NativeTable {
1859   typedef DynamicQuant TableType;
1860   bool symmetric = false;
1861   int64_t dst_type = 32LL;
1862+  bool activation_perchannel = false;
1863+  int64_t prefer_axis = 0;
1864+  bool transpose = false;
1865 };
1866
1867 struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
1868@@ -19494,7 +19497,10 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
1869   }
1870   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
1871     VT_SYMMETRIC = 4,
1872-    VT_DST_TYPE = 6
1873+    VT_DST_TYPE = 6,
1874+    VT_ACTIVATION_PERCHANNEL = 8,
1875+    VT_PREFER_AXIS = 10,
1876+    VT_TRANSPOSE = 12
1877   };
1878   bool symmetric() const {
1879     return GetField<uint8_t>(VT_SYMMETRIC, 0) != 0;
1880@@ -19508,10 +19514,31 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
1881   bool mutate_dst_type(int64_t _dst_type) {
1882     return SetField<int64_t>(VT_DST_TYPE, _dst_type, 32LL);
1883   }
1884+  bool activation_perchannel() const {
1885+    return GetField<uint8_t>(VT_ACTIVATION_PERCHANNEL, 0) != 0;
1886+  }
1887+  bool mutate_activation_perchannel(bool _activation_perchannel) {
1888+    return SetField<uint8_t>(VT_ACTIVATION_PERCHANNEL, static_cast<uint8_t>(_activation_perchannel), 0);
1889+  }
1890+  int64_t prefer_axis() const {
1891+    return GetField<int64_t>(VT_PREFER_AXIS, 0);
1892+  }
1893+  bool mutate_prefer_axis(int64_t _prefer_axis) {
1894+    return SetField<int64_t>(VT_PREFER_AXIS, _prefer_axis, 0);
1895+  }
1896+  bool transpose() const {
1897+    return GetField<uint8_t>(VT_TRANSPOSE, 0) != 0;
1898+  }
1899+  bool mutate_transpose(bool _transpose) {
1900+    return SetField<uint8_t>(VT_TRANSPOSE, static_cast<uint8_t>(_transpose), 0);
1901+  }
1902   bool Verify(flatbuffers::Verifier &verifier) const {
1903     return VerifyTableStart(verifier) &&
1904            VerifyField<uint8_t>(verifier, VT_SYMMETRIC) &&
1905            VerifyField<int64_t>(verifier, VT_DST_TYPE) &&
1906+           VerifyField<uint8_t>(verifier, VT_ACTIVATION_PERCHANNEL) &&
1907+           VerifyField<int64_t>(verifier, VT_PREFER_AXIS) &&
1908+           VerifyField<uint8_t>(verifier, VT_TRANSPOSE) &&
1909            verifier.EndTable();
1910   }
1911   DynamicQuantT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
1912@@ -19529,6 +19556,15 @@ struct DynamicQuantBuilder {
1913   void add_dst_type(int64_t dst_type) {
1914     fbb_.AddElement<int64_t>(DynamicQuant::VT_DST_TYPE, dst_type, 32LL);
1915   }
1916+  void add_activation_perchannel(bool activation_perchannel) {
1917+    fbb_.AddElement<uint8_t>(DynamicQuant::VT_ACTIVATION_PERCHANNEL, static_cast<uint8_t>(activation_perchannel), 0);
1918+  }
1919+  void add_prefer_axis(int64_t prefer_axis) {
1920+    fbb_.AddElement<int64_t>(DynamicQuant::VT_PREFER_AXIS, prefer_axis, 0);
1921+  }
1922+  void add_transpose(bool transpose) {
1923+    fbb_.AddElement<uint8_t>(DynamicQuant::VT_TRANSPOSE, static_cast<uint8_t>(transpose), 0);
1924+  }
1925   explicit DynamicQuantBuilder(flatbuffers::FlatBufferBuilder &_fbb)
1926         : fbb_(_fbb) {
1927     start_ = fbb_.StartTable();
1928@@ -19543,9 +19579,15 @@ struct DynamicQuantBuilder {
1929 inline flatbuffers::Offset<DynamicQuant> CreateDynamicQuant(
1930     flatbuffers::FlatBufferBuilder &_fbb,
1931     bool symmetric = false,
1932-    int64_t dst_type = 32LL) {
1933+    int64_t dst_type = 32LL,
1934+    bool activation_perchannel = false,
1935+    int64_t prefer_axis = 0,
1936+    bool transpose = false) {
1937   DynamicQuantBuilder builder_(_fbb);
1938+  builder_.add_prefer_axis(prefer_axis);
1939   builder_.add_dst_type(dst_type);
1940+  builder_.add_transpose(transpose);
1941+  builder_.add_activation_perchannel(activation_perchannel);
1942   builder_.add_symmetric(symmetric);
1943   return builder_.Finish();
1944 }
1945@@ -26124,6 +26166,9 @@ inline void DynamicQuant::UnPackTo(DynamicQuantT *_o, const flatbuffers::resolve
1946   (void)_resolver;
1947   { auto _e = symmetric(); _o->symmetric = _e; }
1948   { auto _e = dst_type(); _o->dst_type = _e; }
1949+  { auto _e = activation_perchannel(); _o->activation_perchannel = _e; }
1950+  { auto _e = prefer_axis(); _o->prefer_axis = _e; }
1951+  { auto _e = transpose(); _o->transpose = _e; }
1952 }
1953
1954 inline flatbuffers::Offset<DynamicQuant> DynamicQuant::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DynamicQuantT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
1955@@ -26136,10 +26181,16 @@ inline flatbuffers::Offset<DynamicQuant> CreateDynamicQuant(flatbuffers::FlatBuf
1956   struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DynamicQuantT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
1957   auto _symmetric = _o->symmetric;
1958   auto _dst_type = _o->dst_type;
1959+  auto _activation_perchannel = _o->activation_perchannel;
1960+  auto _prefer_axis = _o->prefer_axis;
1961+  auto _transpose = _o->transpose;
1962   return mindspore::schema::CreateDynamicQuant(
1963       _fbb,
1964       _symmetric,
1965-      _dst_type);
1966+      _dst_type,
1967+      _activation_perchannel,
1968+      _prefer_axis,
1969+      _transpose);
1970 }
1971
1972 inline LSTMGradDataT *LSTMGradData::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
1973@@ -33528,10 +33579,13 @@ inline const flatbuffers::TypeTable *ReduceScatterTypeTable() {
1974 inline const flatbuffers::TypeTable *DynamicQuantTypeTable() {
1975   static const flatbuffers::TypeCode type_codes[] = {
1976     { flatbuffers::ET_BOOL, 0, -1 },
1977-    { flatbuffers::ET_LONG, 0, -1 }
1978+    { flatbuffers::ET_LONG, 0, -1 },
1979+    { flatbuffers::ET_BOOL, 0, -1 },
1980+    { flatbuffers::ET_LONG, 0, -1 },
1981+    { flatbuffers::ET_BOOL, 0, -1 }
1982   };
1983   static const flatbuffers::TypeTable tt = {
1984-    flatbuffers::ST_TABLE, 2, type_codes, nullptr, nullptr, nullptr, nullptr
1985+    flatbuffers::ST_TABLE, 5, type_codes, nullptr, nullptr, nullptr, nullptr
1986   };
1987   return &tt;
1988 }
1989diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
1990index 023ee7e5..32775bac 100644
1991--- a/mindspore/lite/schema/ops.fbs
1992+++ b/mindspore/lite/schema/ops.fbs
1993@@ -1231,6 +1231,9 @@ table ReduceScatter {
1994 table DynamicQuant {
1995     symmetric: bool = false;
1996     dst_type: long = 32;
1997+    activation_perchannel: bool = false;
1998+    prefer_axis: long = 0;
1999+    transpose: bool = false;
2000 }
2001
2002 table LSTMGradData {
2003diff --git a/mindspore/lite/schema/ops_generated.h b/mindspore/lite/schema/ops_generated.h
2004index 5b15211a..393cefcd 100644
2005--- a/mindspore/lite/schema/ops_generated.h
2006+++ b/mindspore/lite/schema/ops_generated.h
2007@@ -12939,7 +12939,10 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
2008   typedef DynamicQuantBuilder Builder;
2009   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
2010     VT_SYMMETRIC = 4,
2011-    VT_DST_TYPE = 6
2012+    VT_DST_TYPE = 6,
2013+    VT_ACTIVATION_PERCHANNEL = 8,
2014+    VT_PREFER_AXIS = 10,
2015+    VT_TRANSPOSE = 12
2016   };
2017   bool symmetric() const {
2018     return GetField<uint8_t>(VT_SYMMETRIC, 0) != 0;
2019@@ -12947,10 +12950,22 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
2020   int64_t dst_type() const {
2021     return GetField<int64_t>(VT_DST_TYPE, 32LL);
2022   }
2023+  bool activation_perchannel() const {
2024+    return GetField<uint8_t>(VT_ACTIVATION_PERCHANNEL, 0) != 0;
2025+  }
2026+  int64_t prefer_axis() const {
2027+    return GetField<int64_t>(VT_PREFER_AXIS, 0);
2028+  }
2029+  bool transpose() const {
2030+    return GetField<uint8_t>(VT_TRANSPOSE, 0) != 0;
2031+  }
2032   bool Verify(flatbuffers::Verifier &verifier) const {
2033     return VerifyTableStart(verifier) &&
2034            VerifyField<uint8_t>(verifier, VT_SYMMETRIC) &&
2035            VerifyField<int64_t>(verifier, VT_DST_TYPE) &&
2036+           VerifyField<uint8_t>(verifier, VT_ACTIVATION_PERCHANNEL) &&
2037+           VerifyField<int64_t>(verifier, VT_PREFER_AXIS) &&
2038+           VerifyField<uint8_t>(verifier, VT_TRANSPOSE) &&
2039            verifier.EndTable();
2040   }
2041 };
2042@@ -12965,6 +12980,15 @@ struct DynamicQuantBuilder {
2043   void add_dst_type(int64_t dst_type) {
2044     fbb_.AddElement<int64_t>(DynamicQuant::VT_DST_TYPE, dst_type, 32LL);
2045   }
2046+  void add_activation_perchannel(bool activation_perchannel) {
2047+    fbb_.AddElement<uint8_t>(DynamicQuant::VT_ACTIVATION_PERCHANNEL, static_cast<uint8_t>(activation_perchannel), 0);
2048+  }
2049+  void add_prefer_axis(int64_t prefer_axis) {
2050+    fbb_.AddElement<int64_t>(DynamicQuant::VT_PREFER_AXIS, prefer_axis, 0);
2051+  }
2052+  void add_transpose(bool transpose) {
2053+    fbb_.AddElement<uint8_t>(DynamicQuant::VT_TRANSPOSE, static_cast<uint8_t>(transpose), 0);
2054+  }
2055   explicit DynamicQuantBuilder(flatbuffers::FlatBufferBuilder &_fbb)
2056         : fbb_(_fbb) {
2057     start_ = fbb_.StartTable();
2058@@ -12979,9 +13003,15 @@ struct DynamicQuantBuilder {
2059 inline flatbuffers::Offset<DynamicQuant> CreateDynamicQuant(
2060     flatbuffers::FlatBufferBuilder &_fbb,
2061     bool symmetric = false,
2062-    int64_t dst_type = 32LL) {
2063+    int64_t dst_type = 32LL,
2064+    bool activation_perchannel = false,
2065+    int64_t prefer_axis = 0,
2066+    bool transpose = false) {
2067   DynamicQuantBuilder builder_(_fbb);
2068+  builder_.add_prefer_axis(prefer_axis);
2069   builder_.add_dst_type(dst_type);
2070+  builder_.add_transpose(transpose);
2071+  builder_.add_activation_perchannel(activation_perchannel);
2072   builder_.add_symmetric(symmetric);
2073   return builder_.Finish();
2074 }
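Building the extended table directly; the builder adds fields in size-sorted order internally, so callers just pass the new arguments positionally:

#include "schema/ops_generated.h"

flatbuffers::Offset<mindspore::schema::DynamicQuant> BuildDynamicQuant(flatbuffers::FlatBufferBuilder *fbb) {
  return mindspore::schema::CreateDynamicQuant(*fbb, /*symmetric=*/false, /*dst_type=*/32,
                                               /*activation_perchannel=*/true, /*prefer_axis=*/1,
                                               /*transpose=*/false);
}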
2075diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
2076index 31941fb1..d28c30d9 100644
2077--- a/mindspore/lite/src/CMakeLists.txt
2078+++ b/mindspore/lite/src/CMakeLists.txt
2079@@ -111,6 +111,7 @@ set(LITE_SRC
2080         ${API_SRC}
2081         ${CMAKE_CURRENT_SOURCE_DIR}/common/context_util.cc
2082         ${CMAKE_CURRENT_SOURCE_DIR}/common/file_utils.cc
2083+        ${CMAKE_CURRENT_SOURCE_DIR}/common/mmap_utils.cc
2084         ${CMAKE_CURRENT_SOURCE_DIR}/common/utils.cc
2085         ${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc
2086         ${CMAKE_CURRENT_SOURCE_DIR}/common/log.cc
2087@@ -137,6 +138,7 @@ set(LITE_SRC
2088         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/sub_graph_kernel.cc
2089         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/scheduler.cc
2090         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cc
2091+        ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_packed_node_pass.cc
2092         ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
2093         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_info.cc
2094         ${CMAKE_CURRENT_SOURCE_DIR}/runtime/pack_weight_manager.cc
2095diff --git a/mindspore/lite/src/common/mmap_utils.cc b/mindspore/lite/src/common/mmap_utils.cc
2096new file mode 100644
2097index 00000000..ca8f8d1e
2098--- /dev/null
2099+++ b/mindspore/lite/src/common/mmap_utils.cc
2100@@ -0,0 +1,63 @@
2101+/**
2102+ * Copyright 2023 Huawei Technologies Co., Ltd
2103+ *
2104+ * Licensed under the Apache License, Version 2.0 (the "License");
2105+ * you may not use this file except in compliance with the License.
2106+ * You may obtain a copy of the License at
2107+ *
2108+ * http://www.apache.org/licenses/LICENSE-2.0
2109+ *
2110+ * Unless required by applicable law or agreed to in writing, software
2111+ * distributed under the License is distributed on an "AS IS" BASIS,
2112+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2113+ * See the License for the specific language governing permissions and
2114+ * limitations under the License.
2115+ */
2116+
2117+#include "src/common/mmap_utils.h"
2118+#include "src/common/file_utils.h"
2119+#if !defined(_WIN32) && !defined(_WIN64)
2120+#include <sys/mman.h>
2121+#include <fcntl.h>
2122+#include <sys/stat.h>
2123+#endif
2124+
2125+namespace mindspore {
2126+namespace lite {
2127+void *ReadFileByMmap(const std::string &file, size_t *size) {
2128+#if !defined(_WIN32) && !defined(_WIN64) && !defined(MS_COMPILE_IOS)
2129+  auto real_path = RealPath(file.c_str());
2130+  auto fd = open(real_path.c_str(), O_RDONLY);
2131+  if (fd == -1) {
2132+    MS_LOG(ERROR) << "Could not open " << file;
2133+    return nullptr;
2134+  }
2135+  struct stat fd_stat;
2136+  if (fstat(fd, &fd_stat) != 0) {
2137+    MS_LOG(ERROR) << "Get fd stat error.";
2138+    close(fd);
2139+    return nullptr;
2140+  }
2141+  *size = fd_stat.st_size;
2142+  auto mmap_buffers = mmap(nullptr, *size, PROT_READ, MAP_SHARED | MAP_POPULATE, fd, 0);
2143+  close(fd);
2144+  if (mmap_buffers == MAP_FAILED) {
2145+    MS_LOG(ERROR) << "Model mmap failed.";
2146+    return nullptr;
2147+  }
2148+  return mmap_buffers;
2149+#else
2150+  MS_LOG(ERROR) << "Mmap is unsupported on this platform.";
2151+  return nullptr;
2152+#endif
2153+}
2154+
2155+void UnmapMmapBuffer(void *buffer, size_t size) {
2156+#if !defined(_WIN32) && !defined(_WIN64)
2157+  (void)munmap(buffer, size);
2158+#else
2159+  MS_LOG(ERROR) << "Mmap is unsupported on Windows.";
2160+#endif
2161+}
2162+}  // namespace lite
2163+}  // namespace mindspore
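Expected usage of the new helpers, sketched; the mapping is read-only and must be released with the size returned by ReadFileByMmap:

#include <string>
#include "src/common/mmap_utils.h"

int LoadModelByMmap(const std::string &path) {
  size_t size = 0;
  void *buf = mindspore::lite::ReadFileByMmap(path, &size);
  if (buf == nullptr) {
    return -1;  // open/fstat/mmap failure was already logged
  }
  // ... parse the model in place; no heap copy of the weights is made ...
  mindspore::lite::UnmapMmapBuffer(buf, size);
  return 0;
}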
2164diff --git a/mindspore/lite/src/common/mmap_utils.h b/mindspore/lite/src/common/mmap_utils.h
2165new file mode 100644
2166index 00000000..bdd7c9a5
2167--- /dev/null
2168+++ b/mindspore/lite/src/common/mmap_utils.h
2169@@ -0,0 +1,27 @@
2170+/**
2171+ * Copyright 2023 Huawei Technologies Co., Ltd
2172+ *
2173+ * Licensed under the Apache License, Version 2.0 (the "License");
2174+ * you may not use this file except in compliance with the License.
2175+ * You may obtain a copy of the License at
2176+ *
2177+ * http://www.apache.org/licenses/LICENSE-2.0
2178+ *
2179+ * Unless required by applicable law or agreed to in writing, software
2180+ * distributed under the License is distributed on an "AS IS" BASIS,
2181+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
2182+ * See the License for the specific language governing permissions and
2183+ * limitations under the License.
2184+ */
2185+#ifndef MINDSPORE_LITE_SRC_COMMON_MMAP_UTILS_H_
2186+#define MINDSPORE_LITE_SRC_COMMON_MMAP_UTILS_H_
2187+
2188+#include <string>
2189+
2190+namespace mindspore {
2191+namespace lite {
2192+void *ReadFileByMmap(const std::string &file, size_t *size);
2193+void UnmapMmapBuffer(void *buffer, size_t size);
2194+}  // namespace lite
2195+}  // namespace mindspore
2196+#endif
2197diff --git a/mindspore/lite/src/common/ops/ops_def.cc b/mindspore/lite/src/common/ops/ops_def.cc
2198index 011c37df..cfab3113 100644
2199--- a/mindspore/lite/src/common/ops/ops_def.cc
2200+++ b/mindspore/lite/src/common/ops/ops_def.cc
2201@@ -1231,6 +1231,9 @@ OP_SCHEMA_DEF_END(ReduceScatter)
2202 OP_SCHEMA_DEF(DynamicQuant)
2203 OP_ATTR_WITH_VALUE(symmetric, bool, false)
2204 OP_ATTR_WITH_VALUE(dst_type, long, 32)
2205+OP_ATTR_WITH_VALUE(activation_perchannel, bool, false)
2206+OP_ATTR_WITH_VALUE(prefer_axis, long, 0)
2207+OP_ATTR_WITH_VALUE(transpose, bool, false)
2208 OP_SCHEMA_DEF_END(DynamicQuant)
2209
2210 OP_SCHEMA_DEF(LSTMGradData)
2211diff --git a/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc b/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc
2212index b7e62e6c..fe8a939e 100644
2213--- a/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc
2214+++ b/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc
2215@@ -38,6 +38,9 @@ OpParameter *PopulateDynamicQuantParameter(const void *prim) {
2216   param->op_parameter_.type_ = primitive->value_type();
2217   param->dst_type_ = value->dst_type();
2218   param->symmetric_ = value->symmetric();
2219+  param->activation_perchannel_ = value->activation_perchannel();
2220+  param->prefer_axis_ = value->prefer_axis();
2221+  param->transpose_ = value->transpose();
2222   return reinterpret_cast<OpParameter *>(param);
2223 }
2224 REG_POPULATE(PrimitiveType_DynamicQuant, PopulateDynamicQuantParameter, SCHEMA_CUR);
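This populate code assumes DynamicQuantParameter gained matching fields; the diffstat records +3 lines in nnacl/dynamic_quant_parameter.h, which is not shown in this excerpt. An inferred sketch of that struct, with field types guessed from the kernel code:

// Inferred layout, for orientation only.
typedef struct DynamicQuantParameter {
  OpParameter op_parameter_;
  int32_t dst_type_;
  bool symmetric_;
  bool activation_perchannel_;  // new in this patch
  int32_t prefer_axis_;         // new in this patch
  bool transpose_;              // new in this patch
} DynamicQuantParameter;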
2225diff --git a/mindspore/lite/src/common/primitive_t_utils.cc b/mindspore/lite/src/common/primitive_t_utils.cc
2226index ad406562..db7c7ef0 100644
2227--- a/mindspore/lite/src/common/primitive_t_utils.cc
2228+++ b/mindspore/lite/src/common/primitive_t_utils.cc
2229@@ -22,7 +22,7 @@
2230 namespace mindspore {
2231 namespace lite {
2232 constexpr size_t INITIAL_SIZE = 1024;
2233-const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
2234+const schema::Primitive *ConvertToPrimitive(const schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) {
2235   if (primitive_t == nullptr || fbb == nullptr) {
2236     MS_LOG(ERROR) << "primitiveT or fbb is nullptr.";
2237     return nullptr;
2238@@ -71,6 +71,18 @@ std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const std::shared_ptr<mindspor
2239     return nullptr;
2240   }
2241 }
2242+
2243+schema::PrimitiveType GetSchemaPrimType(const schema::PrimitiveT *primitive_t) {
2244+  flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE);
2245+  auto primitive = ConvertToPrimitive(primitive_t, &fbb);
2246+  if (primitive == nullptr) {
2247+    MS_LOG(ERROR) << "Failed to convert Primitive.";
2248+    return schema::PrimitiveType::PrimitiveType_NONE;
2249+  }
2250+  fbb.Clear();
2251+  int prim_type = GetPrimitiveType(primitive, static_cast<int>(SCHEMA_VERSION::SCHEMA_CUR));
2252+  return static_cast<schema::PrimitiveType>(prim_type);
2253+}
2254 }  // namespace lite
2255 }  // namespace mindspore
2256 #endif
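Typical intended use of GetSchemaPrimType, sketched; a graph pass can classify an in-memory PrimitiveT without keeping the temporary flatbuffer around:

#include "src/common/primitive_t_utils.h"

bool IsMatMulFusion(const mindspore::schema::PrimitiveT *prim) {
  return mindspore::lite::GetSchemaPrimType(prim) == mindspore::schema::PrimitiveType_MatMulFusion;
}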
2257diff --git a/mindspore/lite/src/common/primitive_t_utils.h b/mindspore/lite/src/common/primitive_t_utils.h
2258index 7fe3e781..dba02777 100644
2259--- a/mindspore/lite/src/common/primitive_t_utils.h
2260+++ b/mindspore/lite/src/common/primitive_t_utils.h
2261@@ -24,9 +24,10 @@
2262
2263 namespace mindspore {
2264 namespace lite {
2265-const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
2266+const schema::Primitive *ConvertToPrimitive(const schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
2267 OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t);
2268 std::unique_ptr<schema::PrimitiveT> GetPrimitiveT(const std::shared_ptr<mindspore::ops::BaseOperator> &op);
2269+schema::PrimitiveType GetSchemaPrimType(const schema::PrimitiveT *primitive_t);
2270 }  // namespace lite
2271 }  // namespace mindspore
2272 #endif
2273diff --git a/mindspore/lite/src/runtime/inner_context.h b/mindspore/lite/src/runtime/inner_context.h
2274index ff58995f..52537e93 100644
2275--- a/mindspore/lite/src/runtime/inner_context.h
2276+++ b/mindspore/lite/src/runtime/inner_context.h
2277@@ -32,6 +32,13 @@
2278 #endif
2279
2280 namespace mindspore::lite {
2281+typedef struct InstructionsContext {
2282+  // Instructions should be checked in the beginning.
2283+  // Instruction-set support is detected once, up front, and cached here.
2284+  bool support_sdot = false;
2285+  bool support_sse = false;
2286+  bool support_avx512 = false;
2287+} InstructionsContext;
2288 #ifdef ENABLE_MINDRT
2289 constexpr int kDefaultParallelNum = 2;
2290 #endif
2291@@ -77,6 +84,8 @@ struct MS_API InnerContext : public Context {
2292
2293   void ReplaceLinkInfoSenderWithNewOne(void *new_sender, void *old_sender);
2294
2295+  InstructionsContext instructions_ctx_;
2296+
2297  private:
2298   bool IsAllDeviceTypeValid() const;
2299
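A sketch of how a kernel might consume the new InstructionsContext together with the MatmulType enum added above; the selection logic here is illustrative, not part of this patch:

#include "nnacl/matmul_parameter.h"
#include "src/runtime/inner_context.h"

MatmulType SelectDynamicMatmul(const mindspore::lite::InnerContext *ctx) {
  return ctx->instructions_ctx_.support_sdot ? kMatmulDynamicSdotInt8Cpu : kMatmulDynamicInt8Cpu;
}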
2300diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.cc
2301index 41b8b58b..f25bf288 100644
2302--- a/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.cc
2303+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.cc
2304@@ -47,6 +47,13 @@ int DynamicQuantCPUKernel::Prepare() {
2305   src_dtype_ = in_tensor->data_type();
2306   dst_dtype_ = param->dst_type_;
2307   symmetric_ = param->symmetric_;
2308+  activation_perchannel_ = param->activation_perchannel_;
2309+  prefer_axis_ = param->prefer_axis_;
2310+  transpose_ = param->transpose_;
2315   if (out_tensor->data_type() != dst_dtype_) {
2316     MS_LOG(ERROR) << "param data type and tensor data type do not match.";
2317     return RET_ERROR;
2318@@ -68,10 +75,33 @@ int DynamicQuantCPUKernel::ReSize() {
2319     // Limit for 8 thread
2320     thread_n_num_ = MSMIN(thread_n_num_, kBucketNums);
2321   }
2322-  for (int i = 0; i < kBucketNums; ++i) {
2323-    real_min_array_[i] = FLT_MAX;
2324-    real_max_array_[i] = FLT_MIN;
2325+
2326+  int min_max_array_size = 0;
2327+  if (activation_perchannel_) {
2328+    auto dims = in_tensor->shape();
2329+    if (prefer_axis_ < 0) {
2330+      prefer_axis_ += dims.size();
2331+    }
2332+    channel_num_ = dims[prefer_axis_];
2333+    MS_CHECK_GT(channel_num_, 0, RET_ERROR);
2334+    channel_length_ = num_unit_ / channel_num_;
2335+    thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
2336+    if (channel_length_ > thread_n_stride_) {
2337+      thread_n_num_ = 1;
2338+    }
2339+    min_max_array_size = channel_num_;
2340+  } else {
2341+    min_max_array_size = kBucketNums;
2342   }
+  // ReSize can be called repeatedly; release buffers from a previous call first.
+  free(real_min_);
+  free(real_max_);
2343+  real_min_ = reinterpret_cast<float *>(malloc(min_max_array_size * sizeof(float)));
2344+  real_max_ = reinterpret_cast<float *>(malloc(min_max_array_size * sizeof(float)));
2345+  if (real_min_ == nullptr || real_max_ == nullptr) {
2346+    return RET_NULL_PTR;
2347+  }
2348+  for (int i = 0; i < min_max_array_size; ++i) {
2349+    real_min_[i] = FLT_MAX;
2350+    real_max_[i] = -FLT_MAX;
2351+  }
2352   MS_CHECK_GT(thread_n_num_, 0, RET_ERROR);
2353   thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
2354   return RET_OK;
2355@@ -84,8 +114,20 @@ int DynamicQuantCPUKernel::CalculateMinMax(int task_id) {
2356   }
2357   int thread_offset = task_id * thread_n_stride_;
2358   float *data = float32_ptr_ + thread_offset;
2359-
2360-  CalculateMinMaxFp32(data, num_unit_thread, &real_min_array_[task_id], &real_max_array_[task_id]);
2361+  if (activation_perchannel_) {
2362+    if (transpose_) {
2363+      MS_LOG(ERROR) << "Transpose of matrix a is not supported.";
2364+      return RET_ERROR;
2365+    }
2366+    int channel_offset = task_id * thread_n_stride_ / channel_length_;
2367+    float *real_min = real_min_ + channel_offset;
2368+    float *real_max = real_max_ + channel_offset;
2369+    CalculateAllChannelMinMax(data, num_unit_thread, real_min, real_max, channel_length_);
2370+  } else {
2371+    float *real_min = real_min_ + task_id;
2372+    float *real_max = real_max_ + task_id;
2373+    CalculateMinMaxFp32(data, num_unit_thread, real_min, real_max);
2374+  }
2375   return RET_OK;
2376 }
2377
2378@@ -100,34 +142,33 @@ int CalculateMinMaxRun(void *cdata, int task_id, float, float) {
2379   return RET_OK;
2380 }
2381
2382-void DynamicQuantCPUKernel::ReduceMinMaxFp32() {
2383+void DynamicQuantCPUKernel::CalculatePerlayerScaleZp() {
2384+  float real_min = FLT_MAX;
2385+  float real_max = -FLT_MAX;
2386   for (int i = 0; i < kBucketNums; i++) {
2387-    if (real_min_array_[i] < real_min_) {
2388-      real_min_ = real_min_array_[i];
2389-    }
2390-    if (real_max_array_[i] > real_max_) {
2391-      real_max_ = real_max_array_[i];
2392-    }
2393+    if (real_min_[i] < real_min) {
2394+      real_min = real_min_[i];
2395+    }
2396+    if (real_max_[i] > real_max) {
2397+      real_max = real_max_[i];
2398+    }
2399   }
2400-  return;
2401-}
2402
2403-void DynamicQuantCPUKernel::CalculateScaleZp() {
2404   lite::LiteQuantParam quant_parm;
2405   double scale;
2406   int zp = 0;
2407   constexpr int kQSymmetricRange = 255;
2408   constexpr int kQAsymmetricRange = 254;
2409   if (!symmetric_) {
2410-    auto range = real_max_ - real_min_;
2411+    auto range = real_max - real_min;
2412     if (range <= 0) {
2413       range = kDefaultRange;
2414       MS_LOG(WARNING) << name_ << " range is 0 and set the range to 0.01.";
2415     }
2416     scale = range / kQSymmetricRange;  // -128 ~ 127
2417-    zp = static_cast<int>(std::round(INT8_MIN - real_min_ / scale));
2418+    zp = static_cast<int>(std::round(INT8_MIN - real_min / scale));
2419   } else {
2420-    auto max = std::max(abs(real_max_), abs(real_min_));
2421+    auto max = std::max(abs(real_max), abs(real_min));
2422     scale = 2 * max / kQAsymmetricRange;  // -127 ~ 127
2423   }
2424   quant_parm.scale = scale;
2425@@ -138,27 +179,87 @@ void DynamicQuantCPUKernel::CalculateScaleZp() {
2426   return;
2427 }
2428
2429+void DynamicQuantCPUKernel::CalculatePerChannelScaleZp() {
2430+  std::vector<lite::LiteQuantParam> quant_params;
2431+  for (int i = 0; i < channel_num_; ++i) {
2432+    float real_min = real_min_[i];
2433+    float real_max = real_max_[i];
2434+
2435+    lite::LiteQuantParam quant_parm;
2436+    double scale;
2437+    int zp = 0;
2438+    constexpr int kQSymmetricRange = 255;
2439+    constexpr int kQAsymmetricRange = 254;
2440+    if (!symmetric_) {
2441+      auto range = real_max - real_min;
2442+      if (range <= 0) {
2443+        range = kDefaultRange;
2444+        MS_LOG(WARNING) << name_ << " range is 0 and set the range to 0.01.";
2445+      }
2446+      scale = range / kQSymmetricRange;  // -128 ~ 127
2447+      zp = static_cast<int>(std::round(INT8_MIN - real_min / scale));
2448+    } else {
2449+      auto max = std::max(abs(real_max), abs(real_min));
2450+      scale = 2 * max / kQAsymmetricRange;  // -127 ~ 127
2451+    }
2452+    quant_parm.scale = scale;
2453+    quant_parm.zeroPoint = zp;
2454+    quant_parm.bitNum = k8Bit;
2455+    quant_parm.inited = true;
2456+    quant_params.push_back(quant_parm);
2457+  }
2458+  this->out_tensors_.front()->set_quant_params(quant_params);
2459+  return;
2460+}
2461+
2462 int DynamicQuantCPUKernel::QuantData(int task_id) {
2463   int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
2464   if (num_unit_thread <= 0) {
2465     return RET_OK;
2466   }
2467-  int thread_offset = task_id * thread_n_stride_;
2468-  auto quant_arg = out_tensors_.front()->quant_params().front();
2469-  int ret;
2470   TypeId data_type = out_tensors_.front()->data_type();
2471-  if (data_type == TypeId::kNumberTypeInt8) {
2472-    ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
2473-                               quant_arg.zeroPoint, num_unit_thread, (int32_t)INT8_MIN, (int32_t)INT8_MAX);
2474-  } else {
2475+  if (data_type != TypeId::kNumberTypeInt8) {
2476     MS_LOG(ERROR) << "Data type not supported:" << data_type;
2477     return RET_PARAM_INVALID;
2478   }
2479-  if (ret != RET_OK) {
2480-    MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
2481-    return RET_ERROR;
2482+  int thread_offset = task_id * thread_n_stride_;
2483+  int ret;
2484+  if (activation_perchannel_) {
2485+    if (out_tensors_.front()->quant_params().size() != static_cast<size_t>(channel_num_) || transpose_) {
2486+      MS_LOG(ERROR) << "Invalid per-channel quant params, or unsupported transpose of matrix a.";
2487+      return RET_ERROR;
2488+    }
2489+    // Gather the per-channel scales and zero points computed in CalculatePerChannelScaleZp.
2490+    float *scale = reinterpret_cast<float *>(malloc(channel_num_ * sizeof(float)));
2491+    int32_t *zero_point = reinterpret_cast<int32_t *>(malloc(channel_num_ * sizeof(int32_t)));
2492+    if (scale == nullptr || zero_point == nullptr) {
2493+      free(scale);
2494+      free(zero_point);
2495+      return RET_NULL_PTR;
2496+    }
2497+    for (int i = 0; i < channel_num_; i++) {
2498+      auto quant_arg = out_tensors_.front()->quant_params().at(i);
2499+      scale[i] = quant_arg.scale;
2500+      zero_point[i] = quant_arg.zeroPoint;
2501+    }
2502+    ret = DoPerchannelQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, scale, zero_point,
2503+                                         num_unit_thread, channel_length_, (int32_t)INT8_MIN, (int32_t)INT8_MAX);
2504+    free(scale);
2505+    free(zero_point);
2506+    if (ret != RET_OK) {
2507+      MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
2508+      return RET_ERROR;
2509+    }
2510+  } else {
2511+    auto quant_arg = out_tensors_.front()->quant_params().front();
2512+    ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
2513+                               quant_arg.zeroPoint, num_unit_thread, (int32_t)INT8_MIN, (int32_t)INT8_MAX);
2514+    if (ret != RET_OK) {
2515+      MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
2516+      return RET_ERROR;
2517+    }
2518   }
2519-  return RET_OK;
2520+  return RET_OK;
2521 }
2522
2523 int QuantDataRun(void *cdata, int task_id, float, float) {
2524@@ -182,8 +283,11 @@ int DynamicQuantCPUKernel::Run() {
2525     MS_LOG(ERROR) << "Run error error_code[" << ret << "]";
2526     return RET_ERROR;
2527   }
2528-  ReduceMinMaxFp32();
2529-  CalculateScaleZp();
2530+  if (activation_perchannel_) {
2531+    CalculatePerChannelScaleZp();
2532+  } else {
2533+    CalculatePerlayerScaleZp();
2534+  }
2535   ret = ParallelLaunch(this->ms_context_, QuantDataRun, this, thread_n_num_);
2536   if (ret != RET_OK) {
2537     MS_LOG(ERROR) << "Run error error_code[" << ret << "]";
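A worked instance of the asymmetric branch above, with invented min/max values, to make the 255-range and zero-point arithmetic concrete:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float real_min = -0.3f, real_max = 0.72f;    // invented activation range
  const double scale = (real_max - real_min) / 255;  // kQSymmetricRange: -128..127
  const int zp = static_cast<int>(std::round(INT8_MIN - real_min / scale));
  std::printf("scale=%.4f zp=%d\n", scale, zp);      // scale=0.0040 zp=-53
  return 0;
}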
2538diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.h b/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.h
2539index 6acb0d8d..e44c7643 100644
2540--- a/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.h
2541+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.h
2542@@ -19,6 +19,7 @@
2543
2544 #include <vector>
2545 #include <cfloat>
2546+#include <map>
2547 #include "src/runtime/lite_kernel.h"
2548
2549 namespace mindspore::kernel {
2550@@ -27,7 +28,10 @@ class DynamicQuantCPUKernel : public LiteKernel {
2551   DynamicQuantCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
2552                         const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
2553       : LiteKernel(parameter, inputs, outputs, ctx), thread_num_(ctx->thread_num_) {}
2554-  ~DynamicQuantCPUKernel() override = default;
2555+  ~DynamicQuantCPUKernel() override {
2556+    free(real_min_);
2557+    free(real_max_);
2558+  }
2559
2560   int Prepare() override;
2561   int ReSize() override;
2562@@ -37,8 +41,8 @@ class DynamicQuantCPUKernel : public LiteKernel {
2563   int CalculateMinMax(int task_id);
2564
2565  private:
2566-  void ReduceMinMaxFp32();
2567-  void CalculateScaleZp();
2568+  void CalculatePerlayerScaleZp();
2569+  void CalculatePerChannelScaleZp();
2570
2571  private:
2572   int thread_num_;
2573@@ -47,14 +51,19 @@ class DynamicQuantCPUKernel : public LiteKernel {
2574   int num_unit_{0};
2575   int8_t *int8_ptr_ = nullptr;
2576   float *float32_ptr_ = nullptr;
2577+  float *real_min_ = nullptr;
2578+  float *real_max_ = nullptr;
2579
2580-  float real_min_array_[8];
2581-  float real_max_array_[8];
2582-  float real_min_ = FLT_MAX;
2583-  float real_max_ = FLT_MIN;
2584   int32_t src_dtype_{0};
2585   int32_t dst_dtype_{0};
2586   bool symmetric_ = false;
2587+  bool activation_perchannel_ = false;
2588+  bool transpose_ = false;
2589+  int32_t prefer_axis_{-1};
2592+  int32_t channel_num_{0};
2593+  int32_t channel_length_{0};
2594 };
2595 }  // namespace mindspore::kernel
2596
2597diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_base_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_base_int8.h
2598index a9383eac..5e360789 100644
2599--- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_base_int8.h
2600+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_base_int8.h
2601@@ -36,6 +36,7 @@ class MatmulBaseInt8CPUKernel : public LiteKernel {
2602                           const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
2603       : LiteKernel(parameter, inputs, outputs, ctx) {
2604     param_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
2605+    param_->matmul_type_ = MatmulType::kNotImplemented;
2606   }
2607   ~MatmulBaseInt8CPUKernel() override;
2608   int Prepare() override;
2609diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.cc
2610index 166510ec..c51d7cc5 100644
2611--- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.cc
2612+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.cc
2613@@ -1,5 +1,5 @@
2614 /**
2615- * Copyright 2022 Huawei Technologies Co., Ltd
2616+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
2617  *
2618  * Licensed under the Apache License, Version 2.0 (the "License");
2619  * you may not use this file except in compliance with the License.
2620@@ -17,6 +17,9 @@
2621 #include "src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h"
2622 #include "nnacl/int8/dynamic_matmul_int8.h"
2623
2624+using mindspore::lite::kCHWDimNumber;
2625+using mindspore::lite::kHWDimNumber;
2626+using mindspore::lite::kNCHWDimNumber;
2627 using mindspore::lite::RET_ERROR;
2628 using mindspore::lite::RET_MEMORY_FAILED;
2629 using mindspore::lite::RET_OK;
2630@@ -79,20 +82,20 @@ int MatmulDynamicBaseInt8CPUKernel::InitFilterQuantParam() {
2631   }
2632   int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
2633   filter_per_channel_ = (weight_quant_params.size() > 1);
2634-  channel_num_ = filter_per_channel_ ? col : 1;
2635-  if (static_cast<int>(weight_quant_params.size()) != channel_num_) {
2636+  auto channel_num = filter_per_channel_ ? col : 1;
2637+  if (static_cast<int>(weight_quant_params.size()) != channel_num) {
2638     MS_LOG(ERROR) << weight_tensor->tensor_name() << " quant params size:" << weight_quant_params.size()
2639-                  << " != channel_num_:" << channel_num_;
2640+                  << " != channel_num:" << channel_num;
2641     return RET_ERROR;
2642   }
2643-  quant_param_->filter_scale_ = reinterpret_cast<float *>(malloc(channel_num_ * sizeof(float)));
2644+  quant_param_->filter_scale_ = reinterpret_cast<float *>(malloc(channel_num * sizeof(float)));
2645   CHECK_NULL_RETURN(quant_param_->filter_scale_);
2646-  memset(quant_param_->filter_scale_, 0, sizeof(channel_num_));
2647-  quant_param_->filter_zp_ = reinterpret_cast<int32_t *>(malloc(channel_num_ * sizeof(int32_t)));
2648+  memset(quant_param_->filter_scale_, 0, sizeof(channel_num));
2649+  quant_param_->filter_zp_ = reinterpret_cast<int32_t *>(malloc(channel_num * sizeof(int32_t)));
2650   CHECK_NULL_RETURN(quant_param_->filter_zp_);
2651-  memset(quant_param_->filter_zp_, 0, sizeof(channel_num_));
2652+  memset(quant_param_->filter_zp_, 0, sizeof(channel_num));
2653
2654-  for (int i = 0; i < channel_num_; i++) {
2655+  for (int i = 0; i < channel_num; i++) {
2656     quant_param_->filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale);
2657     quant_param_->filter_zp_[i] = weight_quant_params[i].zeroPoint;
2658   }
2659@@ -105,57 +108,68 @@ void MatmulDynamicBaseInt8CPUKernel::ResizeMatrixBParameter() {
2660   for (size_t i = 0; i < w_shape.size() - kSize2; ++i) {
2661     batch *= w_shape[i];
2662   }
2663-  param_->batch = batch;
2664+  b_batch_ = batch;
2665   param_->col_ = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
2666   param_->deep_ = param_->b_transpose_ ? w_shape[w_shape.size() - kSize1] : w_shape[w_shape.size() - kSize2];
2667
2668   param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
2669   param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_);
2670
2671-  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
2672-  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
2673+  thread_num_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
2674+  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_num_);
2675   return;
2676 }
2677
2678 void MatmulDynamicBaseInt8CPUKernel::FreeTmpBuffer() {
2679-  if (pack_a_ptr_ != nullptr) {
2680-    free(pack_a_ptr_);
2681-    pack_a_ptr_ = nullptr;
2682-  }
2683-  if (pack_b_ptr_ != nullptr) {
2684+  FreeMatrixABuffer();
2685+  if (pack_b_ptr_ != nullptr && !weight_is_packed_) {
2686     free(pack_b_ptr_);
2687     pack_b_ptr_ = nullptr;
2688   }
2689-  if (input_sums_ != nullptr) {
2690-    free(input_sums_);
2691-    input_sums_ = nullptr;
2692-  }
2693-  if (weight_sums_ != nullptr) {
2694+  if (weight_sums_ != nullptr && !weight_is_packed_) {
2695     free(weight_sums_);
2696     weight_sums_ = nullptr;
2697   }
2698-  if (fp32_bias_ptr_ != nullptr) {
2699-    free(fp32_bias_ptr_);
2700-    fp32_bias_ptr_ = nullptr;
2701+  if (bias_ptr_ != nullptr) {
2702+    free(bias_ptr_);
2703+    bias_ptr_ = nullptr;
2704   }
2705-  return;
2706 }
2707
2708-int MatmulDynamicBaseInt8CPUKernel::InitInputQuantParam() {
2709+int MatmulDynamicBaseInt8CPUKernel::InitInputQuantParam(std::vector<float> *scales, std::vector<int32_t> *zp) {
2710   auto in_quant_params = in_tensors_.at(kInputIndex)->quant_params();
2711   if (in_quant_params.empty()) {
2712     MS_LOG(ERROR) << "invalid in quant param";
2713     return RET_ERROR;
2714   }
2715-  quant_param_->input_zp_ = in_quant_params.front().zeroPoint;
2716-  quant_param_->input_scale_ = static_cast<float>(in_quant_params.front().scale);
2717+  input_per_channel_ = (in_quant_params.size() > 1);
2718+  auto channel_num = input_per_channel_ ? param_->row_ : 1;
2719+  if (static_cast<int>(in_quant_params.size()) != channel_num) {
2720+    MS_LOG(ERROR) << in_tensors_.at(kInputIndex)->tensor_name() << " quant params size:" << in_quant_params.size()
2721+                  << " != channel_num:" << channel_num;
2722+    return RET_ERROR;
2723+  }
2724+  scales->resize(channel_num);
2725+  zp->resize(channel_num);
2726+  for (int i = 0; i < channel_num; ++i) {
2727+    (*scales)[i] = in_quant_params[i].scale;
2728+    (*zp)[i] = in_quant_params[i].zeroPoint;
2729+  }
2730+  quant_param_->input_zp_ = zp->data();
2731+  quant_param_->input_scale_ = scales->data();
2732   return RET_OK;
2733 }
2734
2735 int MatmulDynamicBaseInt8CPUKernel::TransferB() {
2736+  if (weight_is_packed_) {
2737+    CHECK_NULL_RETURN(weight_sums_tensor_);
2738+    pack_b_ptr_ = static_cast<int8_t *>(in_tensors_.at(kWeightIndex)->data());
2739+    weight_sums_ = static_cast<int *>(weight_sums_tensor_->data());
2740+    return RET_OK;
2741+  }
2742   auto weight_data = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->data());
2743   CHECK_NULL_RETURN(weight_data);
2744-  for (int i = 0; i < param_->batch; i++) {
2745+  for (int i = 0; i < b_batch_; i++) {
2746     auto current_weight = weight_data + i * param_->deep_ * param_->col_;
2747     auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
2748     auto current_sums = weight_sums_ + i * param_->col_align_;
2749@@ -168,40 +182,51 @@ int MatmulDynamicBaseInt8CPUKernel::TransferB() {
2750       CalcWeightSums(current_weight, param_->deep_, param_->col_, current_sums, RowMajor);
2751     }
2752   }
2753+
2754   return RET_OK;
2755 }
2756
2757 int MatmulDynamicBaseInt8CPUKernel::InitMatrixABuffer() {
2758-  if (pack_a_ptr_ != nullptr) {
2759-    free(pack_a_ptr_);
2760-    pack_a_ptr_ = nullptr;
2761+  size_t pack_a_size = param_->row_align_ * param_->deep_align_ * sizeof(int8_t);
2762+  size_t sum_a_size = param_->row_align_ * sizeof(int);
2763+  if (ms_context_ != nullptr && ms_context_->allocator != nullptr) {
2764+    pack_a_ptr_ = reinterpret_cast<int8_t *>(ms_context_->allocator->Malloc(pack_a_size + sum_a_size));
2765+  } else {
2766+    pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(pack_a_size + sum_a_size));
2767   }
2768-  pack_a_ptr_ = reinterpret_cast<int8_t *>(malloc(param_->row_align_ * param_->deep_align_ * sizeof(int8_t)));
2769   if (pack_a_ptr_ == nullptr) {
2770-    FreeTmpBuffer();
2771-    return RET_ERROR;
2772+    MS_LOG(ERROR) << "alloc run-buffer for matrix-a failed.";
2773+    return lite::RET_NULL_PTR;
2774   }
2775-  if (input_sums_ != nullptr) {
2776-    free(input_sums_);
2777-    input_sums_ = nullptr;
2778+  input_sums_ = reinterpret_cast<int *>(pack_a_ptr_ + pack_a_size);
2779+  memset(pack_a_ptr_, 0, pack_a_size + sum_a_size);
2780+  return RET_OK;
2781+}
2782+
2783+void MatmulDynamicBaseInt8CPUKernel::FreeMatrixABuffer() {
2784+  if (pack_a_ptr_ == nullptr) {
2785+    return;
2786   }
2787-  input_sums_ = reinterpret_cast<int *>(malloc(param_->row_align_ * sizeof(int)));
2788-  if (input_sums_ == nullptr) {
2789-    FreeTmpBuffer();
2790-    return RET_ERROR;
2791+  if (ms_context_ != nullptr && ms_context_->allocator != nullptr) {
2792+    ms_context_->allocator->Free(pack_a_ptr_);
2793+  } else {
2794+    free(pack_a_ptr_);
2795   }
2796-  memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
2797-  memset(input_sums_, 0, param_->row_align_ * sizeof(int));
2798-  return RET_OK;
2799+  pack_a_ptr_ = nullptr;
2800+  input_sums_ = nullptr;
2801 }
2802
2803 int MatmulDynamicBaseInt8CPUKernel::InitMatrixBBuffer() {
2804+  if (weight_is_packed_) {
2805+    return RET_OK;
2806+  }
2807+
2808   if (pack_b_ptr_ != nullptr) {
2809     free(pack_b_ptr_);
2810     pack_b_ptr_ = nullptr;
2811   }
2812   pack_b_ptr_ =
2813-    reinterpret_cast<int8_t *>(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
2814+    reinterpret_cast<int8_t *>(malloc(b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)));
2815   if (pack_b_ptr_ == nullptr) {
2816     FreeTmpBuffer();
2817     return RET_ERROR;
2818@@ -210,28 +235,32 @@ int MatmulDynamicBaseInt8CPUKernel::InitMatrixBBuffer() {
2819     free(weight_sums_);
2820     weight_sums_ = nullptr;
2821   }
2822-  weight_sums_ = reinterpret_cast<int *>(malloc(param_->batch * param_->col_align_ * sizeof(int)));
2823+  weight_sums_ = reinterpret_cast<int *>(malloc(b_batch_ * param_->col_align_ * sizeof(int)));
2824   if (weight_sums_ == nullptr) {
2825     FreeTmpBuffer();
2826     return RET_ERROR;
2827   }
2828-  memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
2829-  memset(weight_sums_, 0, param_->batch * param_->col_align_ * sizeof(int));
2830+  memset(pack_b_ptr_, 0, b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
2831+  memset(weight_sums_, 0, b_batch_ * param_->col_align_ * sizeof(int));
2832   return RET_OK;
2833 }
2834
2835 int MatmulDynamicBaseInt8CPUKernel::CopyBias() {
2836   if (in_tensors_.size() == kHasBiasSize) {
2837+    CHECK_NULL_RETURN(in_tensors_[kBiasIndex]);
2838     auto bias_tensor = in_tensors_[kBiasIndex];
2839-    fp32_bias_ptr_ = static_cast<float *>(malloc(bias_tensor->Size()));
2840-    if (fp32_bias_ptr_ == nullptr) {
2841+    auto bias_shape = bias_tensor->shape();
2842+    MS_CHECK_TRUE_MSG(bias_shape.size() == 1, lite::RET_INPUT_TENSOR_ERROR, "bias is not 1D.");
2843+    size_t bias_pack_size = UP_ROUND(bias_shape.back(), col_tile_) * lite::DataTypeSize(bias_tensor->data_type());
2844+    bias_ptr_ = malloc(bias_pack_size);
2845+    if (bias_ptr_ == nullptr) {
2846       MS_LOG(ERROR) << "Memory allocation failed";
2847       FreeTmpBuffer();
2848       return RET_MEMORY_FAILED;
2849     }
2850-    memcpy(fp32_bias_ptr_, bias_tensor->data(), bias_tensor->ElementsNum() * sizeof(float));
+    memset(bias_ptr_, 0, bias_pack_size);  // zero the col_tile_ padding beyond the real bias
2851+    memcpy(bias_ptr_, bias_tensor->data(), bias_tensor->Size());
2852   } else {
2853-    fp32_bias_ptr_ = nullptr;
2854+    bias_ptr_ = nullptr;
2855   }
2856   return RET_OK;
2857 }
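CopyBias sizes the buffer with UP_ROUND so the vector kernels can always read whole output-channel tiles, and keeps it behind a void pointer so the same path serves fp32 and fp16 bias data. UP_ROUND is nnacl's round-up-to-a-multiple macro; a one-line equivalent with a worked value:

    // UP_ROUND(x, tile) pads x up to the next multiple of tile.
    constexpr int UpRound(int x, int tile) { return (x + tile - 1) / tile * tile; }
    static_assert(UpRound(10, 4) == 12, "a 10-channel bias pads to 12 entries when col_tile_ is 4");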
2858@@ -239,6 +268,18 @@ int MatmulDynamicBaseInt8CPUKernel::CopyBias() {
2859 int MatmulDynamicBaseInt8CPUKernel::Prepare() {
2860   CHECK_LESS_RETURN(in_tensors_.size(), kMinInputSize);
2861   CHECK_LESS_RETURN(out_tensors_.size(), kOutputSize);
2862+  CHECK_NULL_RETURN(in_tensors_[0]);
2863+  CHECK_NULL_RETURN(in_tensors_[1]);
2864+  CHECK_NULL_RETURN(out_tensors_[0]);
2865+  if (in_tensors_[0]->data_type() != mindspore::kNumberTypeInt8 ||
2866+      in_tensors_[1]->data_type() != mindspore::kNumberTypeInt8) {
2867+    MS_LOG(ERROR) << "Datatype error, input0 data_type is " << in_tensors_[0]->data_type() << ", input1 data_type is "
2868+                  << in_tensors_[1]->data_type();
2869+    return RET_ERROR;
2870+  }
2871+#ifdef ENABLE_FP16
2872+  enable_fp16_ = ms_context_->device_list_[0].device_info_.cpu_device_info_.enable_float16_;
2873+#endif
2874   InitParameter();
2875   auto ret = MallocQuantParam();
2876   if (ret != RET_OK) {
2877@@ -277,18 +318,24 @@ int MatmulDynamicBaseInt8CPUKernel::Prepare() {
2878 }
2879
2880 int MatmulDynamicBaseInt8CPUKernel::ReSize() {
2881+  // The framework forces the output tensor's data_type to kNumberTypeFloat32, so restore fp16 here when enabled
2882+  if (enable_fp16_) {
2883+    out_tensors_[0]->set_data_type(kNumberTypeFloat16);
2884+  }
2885   auto x_shape = in_tensors_.at(0)->shape();
2886   auto o_shape = out_tensors_.at(0)->shape();
2887   MS_ASSERT(o_shape.size() >= kSize2);
2888+
2889   param_->row_ = o_shape[o_shape.size() - kSize2];
2890   param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
2891   param_->deep_ = param_->a_transpose_ ? x_shape[x_shape.size() - kSize2] : x_shape[x_shape.size() - kSize1];
2892   param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_);
2893
2894-  auto ret = InitMatrixABuffer();
2895+  auto ret = InitBroadcastParams(in_tensors_[kInputIndex]->shape(), in_tensors_[kWeightIndex]->shape(), param_,
2896+                                 &a_offset_, &b_offset_);
2897   if (ret != RET_OK) {
2898-    FreeQuantParam();
2899-    return ret;
2900+    MS_LOG(ERROR) << "InitBroadcastParams failed.";
2901+    return RET_ERROR;
2902   }
2903
2904   if (!param_->b_const_) {
2905@@ -301,4 +348,80 @@ int MatmulDynamicBaseInt8CPUKernel::ReSize() {
2906   }
2907   return RET_OK;
2908 }
2909+
2910+int MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(const std::vector<int> &a_shape_const,
2911+                                                        const std::vector<int> &b_shape_const, MatMulParameter *params,
2912+                                                        std::vector<int> *a_offsets, std::vector<int> *b_offsets) {
2913+  std::vector<int> a_shape = a_shape_const;
2914+  if (a_shape.size() < kNCHWDimNumber) {
2915+    size_t add_nums = kNCHWDimNumber - a_shape.size();
2916+    for (size_t i = 0; i < add_nums; ++i) {
2917+      (void)a_shape.insert(a_shape.begin(), 1);
2918+    }
2919+  }
2920+  std::vector<int> b_shape = b_shape_const;
2921+  if (b_shape.size() < kNCHWDimNumber) {
2922+    size_t add_nums = kNCHWDimNumber - b_shape.size();
2923+    for (size_t i = 0; i < add_nums; ++i) {
2924+      (void)b_shape.insert(b_shape.begin(), 1);
2925+    }
2926+  }
2927+
2928+  int batch_sizes[MAX_SHAPE_SIZE] = {0};
2929+  int a_batch_sizes[MAX_SHAPE_SIZE] = {0};
2930+  int b_batch_sizes[MAX_SHAPE_SIZE] = {0};
2931+  for (int i = a_shape.size() - kCHWDimNumber; i >= 0; --i) {
2932+    if (static_cast<int>(a_shape.size() - kCHWDimNumber) == i) {
2933+      batch_sizes[i] = std::max(a_shape[i], b_shape[i]);
2934+      a_batch_sizes[i] = a_shape[i];
2935+      b_batch_sizes[i] = b_shape[i];
2936+    } else {
2937+      batch_sizes[i] = batch_sizes[i + 1] * std::max(a_shape[i], b_shape[i]);
2938+      a_batch_sizes[i] = a_batch_sizes[i + 1] * a_shape[i];
2939+      b_batch_sizes[i] = b_batch_sizes[i + 1] * b_shape[i];
2940+    }
2941+  }
2942+
2943+  int out_batch = 1;
2944+  for (size_t i = 0; i < a_shape.size() - kHWDimNumber; ++i) {
2945+    int max_v = MSMAX(a_shape[i], b_shape[i]);
2946+    int min_v = MSMIN(a_shape[i], b_shape[i]) > 0 ? MSMIN(a_shape[i], b_shape[i]) : 1;
2947+    out_batch *= max_v;
2948+    if (max_v != min_v && max_v % min_v != 0) {
2949+      MS_LOG(ERROR) << "matmul doesn't support broadcast between shapes " << a_shape << " and " << b_shape;
2950+      return RET_ERROR;
2951+    }
2952+  }
2953+  params->batch = out_batch;
2954+
2955+  a_offsets->resize(params->batch, 0);
2956+  b_offsets->resize(params->batch, 0);
2957+  for (int i = 0; i < params->batch; ++i) {
2958+    int64_t delta = i;
2959+    int a_offset = 0;
2960+    int b_offset = 0;
2961+    for (size_t j = 0; j < a_shape.size() - kHWDimNumber; ++j) {
2962+      if (j > 0) {
2963+        delta = delta % batch_sizes[j];
2964+      }
2965+      if (j < (a_shape.size() - kCHWDimNumber)) {
2966+        a_offset += (delta / batch_sizes[j + 1] * a_shape[j] / std::max(a_shape[j], b_shape[j])) * a_batch_sizes[j + 1];
2967+        b_offset += (delta / batch_sizes[j + 1] * b_shape[j] / std::max(a_shape[j], b_shape[j])) * b_batch_sizes[j + 1];
2968+      } else {
2969+        a_offset += (delta * a_shape[j] / std::max(a_shape[j], b_shape[j]));
2970+        b_offset += (delta * b_shape[j] / std::max(a_shape[j], b_shape[j]));
2971+      }
2972+    }
2973+    (*a_offsets)[i] = a_offset;
2974+    (*b_offsets)[i] = b_offset;
2975+  }
2976+
2977+  return RET_OK;
2978+}
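InitBroadcastParams left-pads both shapes to 4-D, verifies each batch dimension pair either matches or divides evenly, multiplies the maxima into params->batch, and records for every output batch which A batch and which B batch it reads. A usage sketch with hypothetical shapes (A batches [2, 1], B batches [1, 3], broadcast to 6 output batches; assumes matmul_dynamic_base_int8.h and <vector> are available):

    // Expected result for A = [2, 1, 16, 32] x B = [1, 3, 32, 8]:
    //   output batch i : 0 1 2 3 4 5
    //   a_offsets[i]   : 0 0 0 1 1 1   (A's first batch dim advances every 3 outputs)
    //   b_offsets[i]   : 0 1 2 0 1 2   (B's second batch dim cycles)
    std::vector<int> a_offsets;
    std::vector<int> b_offsets;
    MatMulParameter param;
    int ret = MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(
      {2, 1, 16, 32}, {1, 3, 32, 8}, &param, &a_offsets, &b_offsets);
    // ret == RET_OK, param.batch == 6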
2979+
2980+int MatmulDynamicBaseInt8CPUKernel::PreparePackedWeight(const lite::Tensor *tensor) {
2981+  weight_is_packed_ = true;
2982+  weight_sums_tensor_ = tensor;
2983+  return RET_OK;
2984+}
2985 }  // namespace mindspore::kernel
2986diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h
2987index 68b664af..6f86c07a 100644
2988--- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h
2989+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h
2990@@ -1,5 +1,5 @@
2991 /**
2992- * Copyright 2022 Huawei Technologies Co., Ltd
2993+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
2994  *
2995  * Licensed under the Apache License, Version 2.0 (the "License");
2996  * you may not use this file except in compliance with the License.
2997@@ -18,13 +18,14 @@
2998 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_MATMUL_DYNAMIC_BASE_INT8_H_
2999
3000 #include <vector>
3001+#include <algorithm>
3002 #include "include/errorcode.h"
3003-#include "include/context.h"
3004 #include "src/runtime/lite_kernel.h"
3005 #include "nnacl/matmul_parameter.h"
3006 #include "nnacl/common_func.h"
3007 #include "nnacl/int8/quantize.h"
3008 #include "nnacl/int8/common_func_int8.h"
3009+#include "src/common/common.h"
3010
3011 namespace mindspore::kernel {
3012 class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
3013@@ -37,44 +38,60 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
3014   ~MatmulDynamicBaseInt8CPUKernel() override;
3015   int Prepare() override;
3016   int ReSize() override;
3017+  static int InitBroadcastParams(const std::vector<int> &a_shape_const, const std::vector<int> &b_shape_const,
3018+                                 MatMulParameter *params, std::vector<int> *a_offsets, std::vector<int> *b_offsets);
3019+
3020+  const int8_t *GetPackBPtr() const { return pack_b_ptr_; }
3021+  const int *GetWeightSums() const { return weight_sums_; }
3022+  int GetBBatch() const { return b_batch_; }
3023+  int PreparePackedWeight(const lite::Tensor *tensor) override;
3024
3025  private:
3026   void ResizeMatrixBParameter();
3027   int CopyBias();
3028-  int InitMatrixABuffer();
3029   int InitMatrixBBuffer();
3030   int MallocQuantParam();
3031
3032  protected:
3033+  int a_batch_ = 1;
3034+  int b_batch_ = 1;
3035+  std::vector<int> a_offset_;
3036+  std::vector<int> b_offset_;
3037   typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col);
3038   virtual void InitParameter() = 0;
3039   int TransferA();
3040-  int InitInputQuantParam();
3041+  int InitInputQuantParam(std::vector<float> *scales, std::vector<int32_t> *zp);
3042   int InitFilterQuantParam();
3043   int TransferB();
3044   void FreeTmpBuffer();
3045   void FreeQuantParam();
3046+  int InitMatrixABuffer();
3047+  void FreeMatrixABuffer();
3048
3049  protected:
3050   MatMulParameter *param_ = nullptr;
3051   MatmulDynamicQuantParameter *quant_param_ = nullptr;
3052   int8_t *pack_a_ptr_ = nullptr;
3053   int8_t *pack_b_ptr_ = nullptr;
3054-  float *fp32_bias_ptr_ = nullptr;
3055+
3056+  bool input_per_channel_ = false;
3057   bool filter_per_channel_ = true;
3058   int8_t *batch_input_ptr_ = nullptr;
3059   int8_t *batch_weight_ptr_ = nullptr;
3060+  int8_t *batch_a_ptr_ = nullptr;
3061   int8_t *batch_b_ptr_ = nullptr;
3062-  float *batch_c_ptr_ = nullptr;
3063+  void *bias_ptr_ = nullptr;
3064+  void *batch_c_ptr_ = nullptr;
3065   int *input_sums_ = nullptr;
3066   int *weight_sums_ = nullptr;
3067   int row_tile_ = C4NUM;
3068   int col_tile_ = C4NUM;
3069   int deep_tile_ = C16NUM;
3070-  int channel_num_ = 0;
3071-  int thread_count_ = 1;
3072   int thread_stride_ = 0;
3073+  bool enable_fp16_ = false;
3074   PackFunc b_pack_func_ = nullptr;
3075+  bool weight_is_packed_ = false;
3076+  const lite::Tensor *weight_sums_tensor_ = nullptr;
3077 };
3078 }  // namespace mindspore::kernel
3079
3080diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.cc
3081index 766b7bb2..69c1baae 100644
3082--- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.cc
3083+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.cc
3084@@ -1,5 +1,5 @@
3085 /**
3086- * Copyright 2022 Huawei Technologies Co., Ltd
3087+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
3088  *
3089  * Licensed under the Apache License, Version 2.0 (the "License");
3090  * you may not use this file except in compliance with the License.
3091@@ -45,8 +45,8 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
3092   if (cur_oc <= 0) {
3093     return RET_OK;
3094   }
3095-  float *bias_ptr = fp32_bias_ptr_;
3096-  if (fp32_bias_ptr_ != nullptr) {
3097+  float *bias_ptr = static_cast<float *>(bias_ptr_);
3098+  if (bias_ptr_ != nullptr) {
3099     bias_ptr += cur_stride;
3100   }
3101   float *filter_scale = quant_param_->filter_scale_;
3102@@ -54,10 +54,12 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
3103   if (filter_per_channel_) {
3104     filter_scale += cur_stride;
3105   }
3106-  DynamicMatmul4x16x4AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
3107-                          batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_, param_->deep_align_,
3108-                          param_->col_, quant_param_->input_zp_, quant_param_->input_scale_, filter_scale, filter_zp,
3109-                          filter_per_channel_);
3110+  int64_t act_type = static_cast<int64_t>(param_->act_type_);
3111+
3112+  DynamicMatmul4x16x4AIWI(batch_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr,
3113+                          static_cast<float *>(batch_c_ptr_) + cur_stride, param_->row_, cur_oc, param_->deep_,
3114+                          param_->deep_align_, param_->col_, *quant_param_->input_zp_, quant_param_->input_scale_,
3115+                          filter_scale, filter_zp, input_per_channel_, filter_per_channel_, act_type);
3116   return RET_OK;
3117 }
3118
3119@@ -81,11 +83,18 @@ void MatmulDynamicInt8CPUKernel::InitParameter() {
3120 }
3121
3122 int MatmulDynamicInt8CPUKernel::Run() {
3123-  auto ret = InitInputQuantParam();
3124+  std::vector<float> input_scales;
3125+  std::vector<int32_t> input_zp;
3126+  auto ret = InitInputQuantParam(&input_scales, &input_zp);
3127   if (ret != RET_OK) {
3128     MS_LOG(ERROR) << "Init input quant param failed.";
3129     return ret;
3130   }
3131+  ret = InitMatrixABuffer();
3132+  if (ret != RET_OK) {
3133+    MS_LOG(ERROR) << "Alloc run-buffer for matrix-a failed.";
3134+    return ret;
3135+  }
3136   if (!param_->b_const_) {
3137     ret = InitFilterQuantParam();
3138     if (ret != RET_OK) {
3139@@ -104,8 +113,8 @@ int MatmulDynamicInt8CPUKernel::Run() {
3140   CHECK_NULL_RETURN(a_ptr);
3141   CHECK_NULL_RETURN(c_ptr);
3142   for (int i = 0; i < param_->batch; i++) {
3143-    memset(pack_a_ptr_, quant_param_->input_zp_, param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
3144-    auto current_src_a = a_ptr + i * param_->row_ * param_->deep_;
3145+    memset(pack_a_ptr_, *(quant_param_->input_zp_), param_->row_align_ * param_->deep_align_ * sizeof(int8_t));
3146+    auto current_src_a = a_ptr + a_offset_[i] * param_->row_ * param_->deep_;
3147     if (param_->a_transpose_) {
3148       MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR);
3149       a_pack_func_(current_src_a, pack_a_ptr_, param_->deep_, param_->row_);
3150@@ -114,15 +123,17 @@ int MatmulDynamicInt8CPUKernel::Run() {
3151       a_pack_func_(current_src_a, pack_a_ptr_, param_->row_, param_->deep_);
3152     }
3153
3154-    batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
3155+    batch_a_ptr_ = pack_a_ptr_;
3156+    batch_b_ptr_ = pack_b_ptr_ + b_offset_[i] * param_->col_align_ * param_->deep_align_;
3157     batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;
3158
3159-    ret = ParallelLaunch(this->ms_context_, MatmulDynamicInt8Run, this, thread_count_);
3160+    ret = ParallelLaunch(this->ms_context_, MatmulDynamicInt8Run, this, thread_num_);
3161     if (ret != RET_OK) {
3162       MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]";
3163       return ret;
3164     }
3165   }
3166+  FreeMatrixABuffer();
3167   return RET_OK;
3168 }
3169 }  // namespace mindspore::kernel
3170diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.h
3171index 71869275..86b2c009 100644
3172--- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.h
3173+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.h
3174@@ -25,7 +25,9 @@ class MatmulDynamicInt8CPUKernel : public MatmulDynamicBaseInt8CPUKernel {
3175  public:
3176   MatmulDynamicInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
3177                              const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
3178-      : MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {}
3179+      : MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {
3180+    param_->matmul_type_ = MatmulType::kMatmulDynamicInt8Cpu;
3181+  }
3182   ~MatmulDynamicInt8CPUKernel() override = default;
3183   int Run() override;
3184
3185diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc
3186index 132fa5d7..755b81e9 100644
3187--- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc
3188+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc
3189@@ -1,5 +1,5 @@
3190 /**
3191- * Copyright 2022 Huawei Technologies Co., Ltd
3192+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
3193  *
3194  * Licensed under the Apache License, Version 2.0 (the "License");
3195  * you may not use this file except in compliance with the License.
3196@@ -81,6 +81,41 @@ int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotPre(int task_id) {
3197   return RET_OK;
3198 }
3199
3200+void MatMulDynamicSdotInt8Kernel::ComputeMultiScaleAhead(std::vector<float> *multi_scale, int col_start,
3201+                                                         size_t col_num) {
3202+  auto &scales = *multi_scale;
3203+  if (!input_per_channel_) {
3204+    if (!filter_per_channel_) {
3205+      scales.resize(1);
3206+      scales[0] = quant_param_->input_scale_[0] * quant_param_->filter_scale_[0];
3207+    } else {
3208+      scales.resize(UP_ROUND(col_num, col_tile_));
3209+      float *filter_scales = quant_param_->filter_scale_ + col_start;
3210+      for (size_t i = 0; i < col_num; ++i) {
3211+        scales[i] = quant_param_->input_scale_[0] * filter_scales[i];
3212+      }
3213+    }
3214+  } else if (!filter_per_channel_) {
3215+    scales.resize(param_->row_align_);
3216+    for (int i = 0; i < param_->row_; ++i) {
3217+      scales[i] = quant_param_->input_scale_[i] * quant_param_->filter_scale_[0];
3218+    }
3219+  }
3220+}
3221+
3222+void MatMulDynamicSdotInt8Kernel::ComputeMultiScaleChannelByChannel(std::vector<float> *multi_scale, int row_start,
3223+                                                                    size_t row_num, int col_start, size_t col_num) {
3224+  auto &scales = *multi_scale;
3225+  scales.resize(row_tile_ * col_tile_, 0);
3226+  float *in_scales = quant_param_->input_scale_ + row_start;
3227+  float *filter_scales = quant_param_->filter_scale_ + col_start;
3228+  for (size_t i = 0; i < row_num; ++i) {
3229+    for (size_t j = 0; j < col_num; ++j) {
3230+      scales[i * col_tile_ + j] = in_scales[i] * filter_scales[j];
3231+    }
3232+  }
3233+}
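Between them, these two helpers cover the four dynamic-quant combinations that the impl below encodes as mode = input_per_channel_ * 2 + filter_per_channel_. Modes 0-2 let the input-by-filter scale products be computed once per column block (ComputeMultiScaleAhead); only mode 3, where both sides are per-channel, needs a fresh row_tile_ x col_tile_ product table per output tile (ComputeMultiScaleChannelByChannel). In sketch form, with plain bools standing in for the member flags:

    // mode encodes which operands carry per-channel quantization scales.
    //   0: per-tensor input, per-tensor filter  -> a single scale product
    //   1: per-tensor input, per-channel filter -> one product per output column
    //   2: per-channel input, per-tensor filter -> one product per output row
    //   3: per-channel both                     -> one product per (row, col) pair
    int mode = (input_per_channel ? 2 : 0) + (filter_per_channel ? 1 : 0);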
3234+
3235 int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotImpl(int task_id) {
3236   // Multi-thread split by col.
3237   int stride = thread_stride_ * col_tile_;
3238@@ -104,15 +139,13 @@ int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotImpl(int task_id) {
3239     }
3240   }
3241
3242-  std::vector<float> multi_scale(cur_oc);
3243-  for (int i = 0; i < cur_oc; ++i) {
3244-    if (!param_->b_const_) {
3245-      multi_scale[i] = quant_param_->input_scale_ * quant_param_->filter_scale_[0];
3246-    } else {
3247-      multi_scale[i] = quant_param_->input_scale_ * quant_param_->filter_scale_[cur_stride + i];
3248-    }
3249-  }
3250-  auto out_stride = param_->col_ * sizeof(float);
3251+  std::vector<float> multi_scale;
3252+  ComputeMultiScaleAhead(&multi_scale, cur_stride, cur_oc);
3253+  int64_t mode = input_per_channel_ * C2NUM + filter_per_channel_;
3254+
3255+  size_t data_type_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
3256+  auto out_stride = param_->col_ * data_type_size;
3257+  int64_t act_type = static_cast<int64_t>(param_->act_type_);
3258   for (int r = 0; r < param_->row_; r += C4NUM) {
3259     size_t row = MSMIN(C4NUM, param_->row_ - r);
3260     auto a_ptr = pack_a_ptr_ + r * param_->deep_align_;
3261@@ -122,21 +155,30 @@ int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotImpl(int task_id) {
3262       auto col_offset = cur_stride + c;
3263       auto b_ptr = batch_b_ptr_ + col_offset * param_->deep_align_;
3264       int *weight_sums_ptr = current_sums + c;
3265-      auto out_ptr = batch_c_ptr_ + r * param_->col_ + col_offset;
3266-      auto bias = fp32_bias_ptr_;
3267-      if (bias != nullptr) {
3268-        bias += col_offset;
3269-      }
3270
3271-#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && !defined(SUPPORT_34XX) && (!defined(MACHINE_LINUX_ARM64))
3272-      DynamicMatmulSdot4x4x16AIWI(a_ptr, b_ptr, out_ptr, param_->deep_align_, multi_scale.data() + c, bias, row, col,
3273-                                  out_stride, input_sums_ptr, weight_sums_ptr, quant_param_->input_zp_,
3274-                                  quant_param_->filter_zp_[0] * param_->deep_);
3275-#else
3276-      DynamicMatmul4x4x16AIWI(a_ptr, b_ptr, out_ptr, param_->deep_align_, multi_scale.data() + c, bias, row, col,
3277-                              out_stride, input_sums_ptr, weight_sums_ptr, quant_param_->input_zp_,
3278-                              quant_param_->filter_zp_[0] * param_->deep_);
3279+      void *out_ptr = static_cast<int8_t *>(batch_c_ptr_) + (r * param_->col_ + col_offset) * data_type_size;
3280+      auto bias = bias_ptr_;
3281+      if (bias_ptr_ != nullptr) {
3282+        bias = static_cast<int8_t *>(bias) + col_offset * data_type_size;
3283+      }
3284+      if (mode == C3NUM) {
3285+        ComputeMultiScaleChannelByChannel(&multi_scale, r, row, col_offset, col);
3286+      }
3287+      int multi_scale_offset =
3288+        (input_per_channel_ == filter_per_channel_ ? 0 : input_per_channel_ * r + filter_per_channel_ * c);
3289+      if (!enable_fp16_) {
3290+        dynamic_matmul_compute_fp32(a_ptr, b_ptr, reinterpret_cast<float *>(out_ptr), param_->deep_align_,
3291+                                    multi_scale.data() + multi_scale_offset, reinterpret_cast<float *>(bias), row, col,
3292+                                    out_stride, input_sums_ptr, weight_sums_ptr, quant_param_->input_zp_[0],
3293+                                    quant_param_->filter_zp_[0] * param_->deep_, act_type, mode);
3294+      } else {
3295+#ifdef ENABLE_FP16
3296+        dynamic_matmul_compute_fp16(a_ptr, b_ptr, reinterpret_cast<float16_t *>(out_ptr), param_->deep_align_,
3297+                                    multi_scale.data() + multi_scale_offset, reinterpret_cast<float16_t *>(bias), row,
3298+                                    col, out_stride, input_sums_ptr, weight_sums_ptr, quant_param_->input_zp_[0],
3299+                                    quant_param_->filter_zp_[0] * param_->deep_, act_type, mode);
3300 #endif
3301+      }
3302     }
3303   }
3304   return RET_OK;
3305@@ -155,31 +197,44 @@ void MatMulDynamicSdotInt8Kernel::InitParameter() {
3306   } else {
3307     b_pack_func_ = RowMajor2Col4x16MajorInt8;
3308   }
3309-  return;
3310+#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && !defined(SUPPORT_34XX) && (!defined(MACHINE_LINUX_ARM64)) && \
3311+  !defined(USE_AOS_GCC_TOOLCHAIN)
3312+  dynamic_matmul_compute_fp32 = DynamicMatmulSdot4x4x16AIWI;
3313+#else
3314+  dynamic_matmul_compute_fp32 = DynamicMatmul4x4x16AIWI;
3315+#endif
3316+#ifdef ENABLE_FP16
3317+#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && !defined(SUPPORT_34XX) && (!defined(MACHINE_LINUX_ARM64)) && \
3318+  !defined(USE_AOS_GCC_TOOLCHAIN)
3319+  dynamic_matmul_compute_fp16 = DynamicMatmulSdot4x4x16AIWIForFp16;
3320+#else
3321+  dynamic_matmul_compute_fp16 = DynamicMatmul4x4x16AIWIForFp16;
3322+#endif
3323+#endif
3324 }
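Since the sdot path only exists on some builds, InitParameter now binds the compute routine to a function-pointer member once, so MatMulDynamicArm64SdotImpl makes a plain indirect call instead of re-evaluating the platform macros inside the tile loop. The shape of the idea, with stand-in names and a trimmed macro condition:

    #include <cstddef>
    #include <cstdint>

    using TileKernel = void (*)(const int8_t *a, const int8_t *b, float *out, size_t deep4);

    void PortableTileKernel(const int8_t *a, const int8_t *b, float *out, size_t deep4) { /* C fallback */ }
    #if defined(ENABLE_ARM64)
    void SdotTileKernel(const int8_t *a, const int8_t *b, float *out, size_t deep4) { /* asm-backed */ }
    #endif

    // Resolve once at prepare time; hot loops never touch the #ifdefs again.
    TileKernel PickTileKernel() {
    #if defined(ENABLE_ARM64)
      return SdotTileKernel;
    #else
      return PortableTileKernel;
    #endif
    }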
3325
3326 int MatMulDynamicSdotInt8Kernel::MatMulDynamicRunArm64Sdot() {
3327   int8_t *a_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data());
3328   int8_t *b_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data());
3329-  float *c_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data());
3330+  void *c_ptr = out_tensors_.at(0)->data();
3331   CHECK_NULL_RETURN(a_ptr);
3332   CHECK_NULL_RETURN(b_ptr);
3333   CHECK_NULL_RETURN(c_ptr);
3334
3335+  size_t data_type_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
3336   for (int i = 0; i < param_->batch; i++) {
3337-    batch_input_ptr_ = a_ptr + i * param_->row_ * param_->deep_;
3338+    batch_input_ptr_ = a_ptr + a_offset_[i] * param_->row_ * param_->deep_;
3339     auto ret = ParallelLaunch(this->ms_context_, Arm64SdotPreRun, this, op_parameter_->thread_num_);
3340     if (ret != RET_OK) {
3341       MS_LOG(ERROR) << "Arm64SdotPreRun error: [" << ret << "]";
3342       return ret;
3343     }
3344
3345-    batch_weight_ptr_ = b_ptr + i * param_->col_ * param_->deep_;
3346-    batch_sums_ = weight_sums_ + i * param_->col_align_;
3347-    batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_;
3348-    batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_;
3349-
3350-    ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_count_);
3351+    batch_weight_ptr_ = b_ptr + b_offset_[i] * param_->col_ * param_->deep_;
3352+    batch_sums_ = weight_sums_ + b_offset_[i] * param_->col_align_;
3353+    batch_b_ptr_ = pack_b_ptr_ + b_offset_[i] * param_->col_align_ * param_->deep_align_;
3354+    batch_c_ptr_ = static_cast<uint8_t *>(c_ptr) + i * param_->row_ * param_->col_ * data_type_size;
3355+    ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_num_);
3356     if (ret != RET_OK) {
3357       MS_LOG(ERROR) << "Arm64SdotRun error: [" << ret << "]";
3358       return ret;
3359@@ -189,11 +244,18 @@ int MatMulDynamicSdotInt8Kernel::MatMulDynamicRunArm64Sdot() {
3360 }
3361
3362 int MatMulDynamicSdotInt8Kernel::Run() {
3363-  auto ret = InitInputQuantParam();
3364+  std::vector<float> input_scales;
3365+  std::vector<int32_t> input_zp;
3366+  auto ret = InitInputQuantParam(&input_scales, &input_zp);
3367   if (ret != RET_OK) {
3368     MS_LOG(ERROR) << "Init input quant param failed.";
3369     return ret;
3370   }
3371+  ret = InitMatrixABuffer();
3372+  if (ret != RET_OK) {
3373+    MS_LOG(ERROR) << "Alloc run-buffer for matrix-a failed.";
3374+    return ret;
3375+  }
3376   if (!param_->b_const_) {
3377     ret = InitFilterQuantParam();
3378     if (ret != RET_OK) {
3379@@ -202,6 +264,8 @@ int MatMulDynamicSdotInt8Kernel::Run() {
3380       return ret;
3381     }
3382   }
3383-  return MatMulDynamicRunArm64Sdot();
3384+  ret = MatMulDynamicRunArm64Sdot();
3385+  FreeMatrixABuffer();
3386+  return ret;
3387 }
3388 }  // namespace mindspore::kernel
3389diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.h
3390index fc6832bc..131af45b 100644
3391--- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.h
3392+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.h
3393@@ -1,5 +1,5 @@
3394 /**
3395- * Copyright 2022 Huawei Technologies Co., Ltd
3396+ * Copyright 2022-2023 Huawei Technologies Co., Ltd
3397  *
3398  * Licensed under the Apache License, Version 2.0 (the "License");
3399  * you may not use this file except in compliance with the License.
3400@@ -25,20 +25,35 @@ class MatMulDynamicSdotInt8Kernel : public MatmulDynamicBaseInt8CPUKernel {
3401  public:
3402   MatMulDynamicSdotInt8Kernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
3403                               const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
3404-      : MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {}
3405+      : MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {
3406+    param_->matmul_type_ = MatmulType::kMatmulDynamicSdotInt8Cpu;
3407+  }
3408   ~MatMulDynamicSdotInt8Kernel() override = default;
3409   int Run() override;
3410
3411  public:
3412-  int MatMulDynamicRunArm64Sdot();
3413   int MatMulDynamicArm64SdotPre(int task_id);
3414   int MatMulDynamicArm64SdotImpl(int task_id);
3415
3416- private:
3417+ protected:
3418   void InitParameter() override;
3419
3420  private:
3421+  template <typename T>
3422+  using DynamicMatmulComputer = void (*)(const int8_t *a, const int8_t *b, T *out, size_t deep4,
3423+                                         const float *multi_scales, const T *bias, size_t row, size_t col, size_t stride,
3424+                                         const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum,
3425+                                         int64_t act_type, int64_t mode);
3426+
3427+  int MatMulDynamicRunArm64Sdot();
3428+  void ComputeMultiScaleAhead(std::vector<float> *multi_scale, int col_start, size_t col_num);
3429+  void ComputeMultiScaleChannelByChannel(std::vector<float> *multi_scale, int row_start, size_t row_num, int col_start,
3430+                                         size_t col_num);
3431   int *batch_sums_ = nullptr;
3432+  DynamicMatmulComputer<float> dynamic_matmul_compute_fp32{nullptr};
3433+#ifdef ENABLE_FP16
3434+  DynamicMatmulComputer<float16_t> dynamic_matmul_compute_fp16{nullptr};
3435+#endif
3436 };
3437 }  // namespace mindspore::kernel
3438
3439diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.cc
3440index b539224f..5ad3fd8a 100644
3441--- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.cc
3442+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.cc
3443@@ -92,7 +92,7 @@ kernel::LiteKernel *MatmulInt8CPUKernelCreator(const std::vector<lite::Tensor *>
3444       MS_LOG(ERROR) << "kernel: " << parameter->name_ << " is unsupported A is const.";
3445       return nullptr;
3446     }
3447-    if (lite::IsSupportSDot()) {
3448+    if (lite::IsSupportSDot() || static_cast<const lite::InnerContext *>(ctx)->instructions_ctx_.support_sdot) {
3449       kernel = new (std::nothrow)
3450         MatMulDynamicSdotInt8Kernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
3451     } else {
3452diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.h
3453index 4e9e4e42..d711f727 100644
3454--- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.h
3455+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.h
3456@@ -29,7 +29,9 @@ class MatmulInt8CPUKernel : public MatmulBaseInt8CPUKernel {
3457  public:
3458   MatmulInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
3459                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
3460-      : MatmulBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {}
3461+      : MatmulBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {
3462+    param_->matmul_type_ = MatmulType::kMatmulInt8Cpu;
3463+  }
3464   ~MatmulInt8CPUKernel() override = default;
3465   int Prepare() override;
3466   int ReSize() override;
3467diff --git a/mindspore/lite/src/runtime/kernel_registry.h b/mindspore/lite/src/runtime/kernel_registry.h
3468index f563d82d..82f55c81 100644
3469--- a/mindspore/lite/src/runtime/kernel_registry.h
3470+++ b/mindspore/lite/src/runtime/kernel_registry.h
3471@@ -46,13 +46,13 @@ class MS_API KernelRegistry {
3472                     const InnerContext *ctx, const mindspore::Context *ms_ctx, const kernel::KernelKey &key,
3473                     OpParameter *op_parameter, kernel::KernelExec **kernel, const void *primitive = nullptr);
3474   int ReplaceKernelExec(kernel::KernelExec *kernel, const kernel::KernelKey &key);
3475+  kernel::LiteKernel *GetLiteKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
3476+                                    const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *parameter);
3477
3478  protected:
3479   int GetCustomKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
3480                       const mindspore::Context *ctx, const kernel::KernelKey &key, kernel::KernelExec **kernel,
3481                       const void *primitive = nullptr);
3482-  kernel::LiteKernel *GetLiteKernel(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
3483-                                    const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *parameter);
3484   static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1};
3485   static const int data_type_length_{kNumberTypeEnd - kNumberTypeBegin + 1};
3486   static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1};
3487diff --git a/mindspore/lite/src/runtime/lite_kernel.h b/mindspore/lite/src/runtime/lite_kernel.h
3488index a27f77d8..c85278f4 100644
3489--- a/mindspore/lite/src/runtime/lite_kernel.h
3490+++ b/mindspore/lite/src/runtime/lite_kernel.h
3491@@ -193,6 +193,8 @@ class MS_API LiteKernel : public Abstractkernel {
3492   const lite::Context *context() const { return this->ms_context_; }
3493   bool ws_allocated_ = false;
3494
3495+  virtual int PreparePackedWeight(const lite::Tensor *tensor) { return mindspore::lite::RET_OK; }
3496+
3497  protected:
3498   OpParameter *op_parameter_ = nullptr;
3499   // tensor will free in ~lite_session()
3500diff --git a/mindspore/lite/src/runtime/lite_model.cc b/mindspore/lite/src/runtime/lite_model.cc
3501index cd8e68d1..662a0856 100644
3502--- a/mindspore/lite/src/runtime/lite_model.cc
3503+++ b/mindspore/lite/src/runtime/lite_model.cc
3504@@ -29,6 +29,7 @@
3505 #include "src/common/file_utils.h"
3506 #include "src/tensor.h"
3507 #include "extendrt/mindir_loader/model_loader.h"
3508+#include "src/common/mmap_utils.h"
3509
3510 namespace mindspore::lite {
3511 namespace {
3512@@ -36,7 +37,11 @@ constexpr size_t kMaxModelBufferSize = static_cast<size_t>(1024) * 1024 * 1024 *
3513 }
3514
3515 void LiteModel::Free() {
3516-  if (this->buf != nullptr) {
3517+  if (this->model_buf_by_mmap_) {
3518+    UnmapMmapBuffer(static_cast<void *>(this->buf), this->buf_size_);
3519+    this->buf = nullptr;
3520+  }
3521+  if (this->buf != nullptr && !this->model_buf_by_mmap_) {
3522     delete[](this->buf);
3523     this->buf = nullptr;
3524   }
3525diff --git a/mindspore/lite/src/runtime/lite_model.h b/mindspore/lite/src/runtime/lite_model.h
3526index af62cb91..d18ae051 100644
3527--- a/mindspore/lite/src/runtime/lite_model.h
3528+++ b/mindspore/lite/src/runtime/lite_model.h
3529@@ -310,6 +310,7 @@ class MS_API LiteModel : public Model {
3530
3531  public:
3532   std::vector<void *> node_bufs_;
3533+  bool model_buf_by_mmap_ = false;
3534
3535  protected:
3536   std::vector<char *> attr_tensor_bufs_;
3537diff --git a/mindspore/lite/src/runtime/lite_session.cc b/mindspore/lite/src/runtime/lite_session.cc
3538index 6661b410..eb1b5ef7 100644
3539--- a/mindspore/lite/src/runtime/lite_session.cc
3540+++ b/mindspore/lite/src/runtime/lite_session.cc
3541@@ -33,10 +33,12 @@
3542 #include "src/common/graph_util.h"
3543 #include "src/common/tensor_util.h"
3544 #include "src/common/file_utils.h"
3545+#include "src/common/mmap_utils.h"
3546 #include "src/runtime/lite_model.h"
3547 #include "src/runtime/weight_decoder.h"
3548 #include "src/runtime/runtime_allocator.h"
3549 #include "src/runtime/kernel_exec_util.h"
3550+#include "src/runtime/runtime_packed_node_pass.h"
3551 #ifndef CUSTOM_KERNEL_REGISTRY_CLIP
3552 #include "src/registry/register_kernel_impl.h"
3553 #endif
3554@@ -561,7 +563,7 @@ int LiteSession::CompileGraph(Model *model) {
3555   }
3556   InitGraphInputTensors(model);
3557   InitGraphOutputTensors(model);
3558-
3559+  PackedNodePass::GetInstance().Run(model, tensors_);
3560   // scheduler kernels
3561   Scheduler scheduler(context_, ms_context_, model, &tensors_, &inputs_, &outputs_, is_train_session_, &is_infershape_,
3562                       &is_control_flow_, execution_plan_, delegate_, delegate_device_type_);
3563@@ -672,6 +674,11 @@ int LiteSession::PrepareKernels(const Model *model) {
3564         return RET_ERROR;
3565       }
3566       for (auto &node : subgraph_kernel->nodes()) {
3567+        ret = PackKernelExec(node, tensors_);
3568+        if (ret != RET_OK) {
3569+          MS_LOG(ERROR) << "Pack KernelExec failed.";
3570+          return ret;
3571+        }
3572         ret = node->Prepare();
3573         if (ret != RET_OK) {
3574           MS_LOG(ERROR) << "node: " << node->name() << " prepare failed.";
3575@@ -1707,9 +1714,14 @@ const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspor
3576 }
3577
3578 const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size,
3579-                                               const std::shared_ptr<mindspore::Context> &ms_context) {
3580+                                               const std::shared_ptr<mindspore::Context> &ms_context, bool use_mmap) {
3581   size_t buf_size;
3582-  auto model_buf = lite::ReadFile(file.c_str(), &buf_size);
3583+  char *model_buf;
3584+  if (use_mmap) {
3585+    model_buf = reinterpret_cast<char *>(lite::ReadFileByMmap(file.c_str(), &buf_size));
3586+  } else {
3587+    model_buf = lite::ReadFile(file.c_str(), &buf_size);
3588+  }
3589   if (model_buf == nullptr) {
3590     MS_LOG(ERROR) << "The model path is invalid";
3591     return model_buf;
3592@@ -1829,7 +1841,8 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path,
3593 int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type,
3594                                                  const std::shared_ptr<mindspore::Context> &ms_context) {
3595   size_t model_size;
3596-  auto model_buf = LoadModelByPath(model_path, model_type, &model_size, ms_context);
3597+  bool use_mmap = IsMmapEnable();
3598+  auto model_buf = LoadModelByPath(model_path, model_type, &model_size, ms_context, use_mmap);
3599   if (model_buf == nullptr) {
3600     MS_LOG(ERROR) << "Read model file failed";
3601     return RET_ERROR;
3602@@ -1837,17 +1850,26 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path,
3603   auto *model = lite::ImportFromBuffer(model_buf, model_size, true, model_type, model_path);
3604   if (model == nullptr) {
3605     MS_LOG(ERROR) << "Import model failed";
3606-    delete[] model_buf;
3607+    if (use_mmap) {
3608+      lite::UnmapMmapBuffer(const_cast<void *>(static_cast<const void *>(model_buf)), model_size);
3609+    } else {
3610+      delete[] model_buf;
3611+    }
3612     return RET_ERROR;
3613   }
3614   auto status = lite::PackWeightManager::GetInstance()->InitPackWeightByBuf(model_buf, model_size);
3615   MS_CHECK_FALSE_MSG(status != RET_OK, RET_ERROR, "InitPackWeightByBuf failed.");
3616
3617   (reinterpret_cast<lite::LiteModel *>(model))->set_keep_model_buf(true);
3618+  reinterpret_cast<lite::LiteModel *>(model)->model_buf_by_mmap_ = use_mmap;
3619   auto ret = CompileGraph(model);
3620   if (ret != lite::RET_OK) {
3621     MS_LOG(ERROR) << "Compile model failed";
3622-    delete[] model_buf;
3623+    if (use_mmap) {
3624+      lite::UnmapMmapBuffer(const_cast<void *>(static_cast<const void *>(model_buf)), model_size);
3625+    } else {
3626+      delete[] model_buf;
3627+    }
3628     model->buf = nullptr;
3629     delete model;
3630     return RET_ERROR;
3631@@ -1855,4 +1877,15 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path,
3632   set_model(model);
3633   return RET_OK;
3634 }
3635+
3636+bool lite::LiteSession::IsMmapEnable() {
3637+#if !defined(_WIN32) && !defined(_WIN64)
3638+  if (delegate_device_type_ == DT_NPU) {
3639+    return false;
3640+  }
3641+  return true;
3642+#else
3643+  return false;
3644+#endif
3645+}
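IsMmapEnable turns the mapping path off on Windows and for the NPU delegate; everywhere else the model file is mapped read-only instead of copied onto the heap. The bodies of ReadFileByMmap and UnmapMmapBuffer live in src/common/mmap_utils.cc, which this listing does not include, so the following is only a sketch of the usual POSIX shape such helpers take:

    #include <fcntl.h>
    #include <sys/mman.h>
    #include <sys/stat.h>
    #include <unistd.h>

    void *ReadFileByMmapSketch(const char *path, size_t *size) {
      int fd = open(path, O_RDONLY);
      if (fd < 0) {
        return nullptr;
      }
      struct stat sb;
      if (fstat(fd, &sb) != 0) {
        close(fd);
        return nullptr;
      }
      void *buf = mmap(nullptr, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
      close(fd);  // the mapping stays valid after the descriptor is closed
      if (buf == MAP_FAILED) {
        return nullptr;
      }
      *size = static_cast<size_t>(sb.st_size);
      return buf;
    }

    void UnmapMmapBufferSketch(void *buf, size_t size) {
      if (buf != nullptr) {
        munmap(buf, size);
      }
    }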
3646 }  // namespace mindspore
3647diff --git a/mindspore/lite/src/runtime/lite_session.h b/mindspore/lite/src/runtime/lite_session.h
3648index d5a672bb..c9edf63e 100644
3649--- a/mindspore/lite/src/runtime/lite_session.h
3650+++ b/mindspore/lite/src/runtime/lite_session.h
3651@@ -60,7 +60,7 @@ class MS_API LiteSession {
3652                                               const std::shared_ptr<mindspore::Context> &ms_context);
3653   static const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size);
3654   static const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size,
3655-                                     const std::shared_ptr<mindspore::Context> &ms_context);
3656+                                     const std::shared_ptr<mindspore::Context> &ms_context, bool use_mmap = false);
3657   virtual int Init(InnerContext *context);
3658   virtual void BindThread(bool if_bind);
3659   virtual int CompileGraph(Model *model);
3660@@ -154,6 +154,7 @@ class MS_API LiteSession {
3661   static void FreePackOpWeight(const std::vector<kernel::KernelExec *> &kernels);
3662   std::string ParseWeightPath();
3663   static void MarkSharedWeight(const std::vector<kernel::KernelExec *> &kernels);
3664+  bool IsMmapEnable();
3665
3666  private:
3667   int PreCheck(Model *model);
3668diff --git a/mindspore/lite/src/runtime/runtime_packed_node_pass.cc b/mindspore/lite/src/runtime/runtime_packed_node_pass.cc
3669new file mode 100644
3670index 00000000..81d50522
3671--- /dev/null
3672+++ b/mindspore/lite/src/runtime/runtime_packed_node_pass.cc
3673@@ -0,0 +1,358 @@
3674+/**
3675+ * Copyright 2023 Huawei Technologies Co., Ltd
3676+ *
3677+ * Licensed under the Apache License, Version 2.0 (the "License");
3678+ * you may not use this file except in compliance with the License.
3679+ * You may obtain a copy of the License at
3680+ *
3681+ * http://www.apache.org/licenses/LICENSE-2.0
3682+ *
3683+ * Unless required by applicable law or agreed to in writing, software
3684+ * distributed under the License is distributed on an "AS IS" BASIS,
3685+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
3686+ * See the License for the specific language governing permissions and
3687+ * limitations under the License.
3688+ */
3689+#include "src/runtime/runtime_packed_node_pass.h"
3690+#include "nnacl/op_base.h"
3691+#include "nnacl/matmul_parameter.h"
3692+
3693+using RecoveryWeightFunc = void (*)(void *, void *, int, int, bool);
3694+namespace mindspore {
3695+namespace {
3696+constexpr size_t kFlatbuffersBuilderInitSize = 1024;
3697+constexpr auto kActivationType = "activation_type";
3698+constexpr auto kTransposeA = "transpose_a";
3699+constexpr auto kTransposeB = "transpose_b";
3700+constexpr auto kArm64SimdDot = "ARM64SIMD_DOT";
3701+}  // namespace
3702+
3703+namespace lite {
3704+PackedNodePass::~PackedNodePass() {
3705+  for (auto &pack_info : node_pack_info_map_) {
3706+    delete pack_info.second;
3707+  }
3708+  node_pack_info_map_.clear();
3709+}
3710+
3711+void PackedNodePass::Run(Model *model, const std::vector<Tensor *> &tensors) {
3712+  for (auto &node : model->graph_.all_nodes_) {
3713+    MS_ASSERT(node != nullptr);
3714+    if (node->node_type_ != schema::PrimitiveType_Custom) {
3715+      continue;
3716+    }
3717+    auto *primitive = reinterpret_cast<const schema::Primitive *>(node->primitive_);
3718+    if (primitive == nullptr) {
3719+      MS_LOG(ERROR) << "Primitive of op " << node->name_ << " is nullptr!";
3720+      return;
3721+    }
3722+    auto custom = primitive->value_as_Custom();
3723+    if (custom == nullptr || custom->type() == nullptr) {
3724+      MS_LOG(ERROR) << "Custom node is nullptr";
3725+      return;
3726+    }
3727+    auto custom_type = custom->type()->str();
3728+    if (custom_type != "MatmulFusionPacked") {
3729+      continue;
3730+    }
3731+    flatbuffers::FlatBufferBuilder fbb(kFlatbuffersBuilderInitSize);
3732+
3733+    auto custom_attr = custom->attr();
3734+    std::map<std::string, std::string> attr_map;
3735+    for (size_t i = 0; i < custom_attr->size(); ++i) {
3736+      auto attr = custom_attr->Get(i);
3737+      auto attr_key = attr->name()->str();
3738+      auto data_bytes = attr->data();
3739+      int data_size = static_cast<int>(data_bytes->size());
3740+      std::string attr_value;
3741+      for (int j = 0; j < data_size; j++) {
3742+        attr_value.push_back(static_cast<char>(data_bytes->Get(j)));
3743+      }
3744+      attr_map[attr_key] = attr_value;
3745+    }
3746+    if (attr_map.find(kActivationType) == attr_map.end() || attr_map.find(kTransposeA) == attr_map.end() ||
3747+        attr_map.find(kTransposeB) == attr_map.end()) {
3748+      MS_LOG(ERROR) << "Custom node is missing a required attr (activation_type/transpose_a/transpose_b).";
3749+      return;
3750+    }
3751+    auto val_offset =
3752+      schema::CreateMatMulFusion(fbb, std::stoi(attr_map[kTransposeA]), std::stoi(attr_map[kTransposeB]),
3753+                                 static_cast<schema::ActivationType>(std::stoi(attr_map[kActivationType])));
3754+    auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_MatMulFusion, val_offset.o);
3755+    fbb.Finish(prim_offset);
3756+    void *prim = malloc(fbb.GetSize());
3757+    if (prim == nullptr) {
3758+      MS_LOG(ERROR) << "malloc primitive failed.";
3759+      return;
3760+    }
3761+    memcpy(prim, fbb.GetBufferPointer(), fbb.GetSize());
3762+    auto custom_primitive = flatbuffers::GetRoot<schema::Primitive>(prim);
3763+    fbb.Clear();
3764+    PackInfo *pack_info = new (std::nothrow) PackInfo();
3765+    if (pack_info == nullptr) {
3766+      free(prim);
3767+      MS_LOG(ERROR) << "new PackInfo failed.";
3768+      return;
3769+    }
3770+    node->primitive_ = custom_primitive;
3771+    pack_info->is_packed_ = true;
3772+    pack_info->b_batch_ = std::stoi(attr_map["b_batch"]);
3773+    pack_info->col_ = std::stoi(attr_map["col"]);
3774+    pack_info->deep_ = std::stoi(attr_map["deep"]);
3775+    pack_info->col_align_ = std::stoi(attr_map["col_align"]);
3776+    pack_info->deep_align_ = std::stoi(attr_map["deep_align"]);
3777+    pack_info->b_transpose_ = std::stoi(attr_map[kTransposeB]);
3778+    pack_info->cpu_option_ = attr_map["cpu_option"];
3779+    AddNodePackInfo(node->name_, pack_info);
3780+    if (node->quant_type_ == schema::QuantType_QUANT_DYNAMIC) {
3781+      pack_info->weight_sums_index_ = node->input_indices_.back();
3782+      node->input_indices_.pop_back();
3783+      if (!(reinterpret_cast<lite::LiteModel *>(model)->keep_model_buf())) {
3784+        auto index = static_cast<size_t>(pack_info->weight_sums_index_);
3785+        if (index >= tensors.size()) {
3786+          MS_LOG(ERROR) << "weight sums tensor index is out of range.";
3787+          return;
3788+        }
3789+        auto tensor = tensors[index];
3790+        CopyWeightBiasSumsTensor(tensor);
3791+      }
3792+    }
3793+
3794+    node->node_type_ = schema::PrimitiveType_MatMulFusion;
3795+  }
3796+  need_store_weight_ = !(reinterpret_cast<lite::LiteModel *>(model)->keep_model_buf());
3797+}
3798+
3799+void PackedNodePass::CopyWeightBiasSumsTensor(Tensor *tensor) {
3800+  if (!tensor->IsConst() && tensor->data() != nullptr) {
3801+    return;
3802+  }
3803+  if (!tensor->IsConst() || tensor->own_data()) {
3804+    return;
3805+  }
3806+  if (tensor->data_type() == kObjectTypeTensorType) {
3807+    MS_ASSERT(tensor->data() == nullptr);
3808+  } else {
3809+    auto copy_tensor = Tensor::CopyTensor(*tensor, true);
3810+    if (copy_tensor == nullptr) {
3811+      MS_LOG(ERROR) << "Copy tensor failed";
3812+      return;
3813+    }
3814+    tensor->FreeData();
3815+    tensor->set_data(copy_tensor->data());
3816+    tensor->set_own_data(true);
3817+    copy_tensor->set_data(nullptr);
3818+    delete copy_tensor;
3819+  }
3820+}
3821+
3822+int PackedNodePass::StoreWeightTensor(Tensor *tensor, size_t data_size) {
3823+  void *weight_data = malloc(data_size);
3824+  if (weight_data == nullptr) {
3825+    MS_LOG(ERROR) << "malloc weight tensor failed.";
3826+    return RET_NULL_PTR;
3827+  }
3828+  memcpy(weight_data, tensor->data(), data_size);
3829+  tensor->FreeData();
3830+  tensor->set_data(weight_data);
3831+  tensor->IncRefCount();
3832+  return RET_OK;
3833+}
3834+
3835+void MatmulDynamicSdotInt8Unpack(void *src, void *dst, int row, int col, bool transpose) {
3836+  auto src_int8 = static_cast<int8_t *>(src);
3837+  auto dst_int8 = static_cast<int8_t *>(dst);
3838+  if (!transpose) {
3839+    // RowMajor2Col4x16MajorInt8
3840+    int row_4 = UP_ROUND(row, C4NUM);
3841+    int stride = C16NUM * C4NUM;
3842+    for (int r = 0; r < row_4; ++r) {
3843+      for (int c = 0; c < col; ++c) {
3844+        int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM;
3845+        if (r < row) {
3846+          int src_idx = r * col + c;
3847+          src_int8[src_idx] = dst_int8[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM];
3848+        }
3849+      }
3850+    }
3851+  } else {
3852+    int temp = row;
3853+    row = col;
3854+    col = temp;
3855+    // RowMajor2Row4x16MajorInt8
3856+    int col4 = UP_ROUND(col, C4NUM);
3857+    for (int r = 0; r < row; r++) {
3858+      int rd16 = r / C16NUM;
3859+      int rm16 = r % C16NUM;
3860+      for (int c = 0; c < col; c++) {
3861+        int cd4 = c / C4NUM;
3862+        int cm4 = c % C4NUM;
3863+        int dst_index = rd16 * col4 * C16NUM + cd4 * C16NUM * C4NUM + rm16 * C4NUM + cm4;
3864+        int src_index = r * col + c;
3865+        src_int8[src_index] = dst_int8[dst_index];
3866+      }
3867+    }
3868+  }
3869+}
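The non-transposed branch above is the inverse of RowMajor2Col4x16MajorInt8: the converter stored B column-tiled for the 4x4x16 sdot kernel, and recovery copies each row-major element back out of its tile slot. The index arithmetic, pulled out with a worked value:

    // Element (r, c) of the row-major weight sits at this offset in the
    // 4x16-tiled buffer (tile stride = C4NUM * C16NUM = 64 int8 slots).
    int PackedIndex(int r, int c, int row4) {
      int stride = C16NUM * C4NUM;                             // 64
      int tile = (c / C16NUM) * (row4 / C4NUM) + (r / C4NUM);  // which 4x16 tile
      return stride * tile + (c % C16NUM) * C4NUM + (r % C4NUM);
    }
    // e.g. row = 8 (so row4 = 8), (r, c) = (5, 18): tile = 1 * 2 + 1 = 3,
    // offset = 64 * 3 + 2 * 4 + 1 = 201.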
3870+
3871+void MatmulFp32BaseUnpack(void *src, void *dst, int row, int col, bool transpose) {
3872+  if (!transpose) {
3873+    // RowMajor2Row8MajorParallel
3874+    auto src_r = static_cast<float *>(src);
3875+    auto dst_r = static_cast<float *>(dst);
3876+    for (int r = 0; r < row; r++) {
3877+      float *src_c = src_r + r * col;
3878+      int c = 0;
3879+      for (; c < col; c++) {
3880+        int cd8 = c / C8NUM;
3881+        int cm8 = c % C8NUM;
3882+        src_c[c] = dst_r[cd8 * C8NUM * row + r * C8NUM + cm8];
3883+      }
3884+    }
3885+    return;
3886+  }
3887+  // RowMajor2Col8MajorParallel
3888+  auto src_r = static_cast<float *>(src);
3889+  auto dst_r = static_cast<float *>(dst);
3890+  int row8 = row / C8NUM * C8NUM;
3891+  int col_skip = col / C4NUM * C4NUM;
3892+  int skip_size = C4NUM;
3893+
3894+  int ri = 0;
3895+  for (; ri < row8; ri += C8NUM) {
3896+    int ci = 0;
3897+    for (; ci < col_skip; ci += skip_size) {
3898+      float *src_c = src_r + ci;
3899+      float *dst_c = dst_r + ci * C8NUM;
3900+      for (int tr = 0; tr < C8NUM; tr++) {
3901+        for (int tc = 0; tc < C4NUM; tc++) {
3902+          src_c[tr * col + tc] = dst_c[tc * C8NUM + tr];
3903+        }
3904+      }
3905+    }
3906+    for (; ci < col; ci++) {
3907+      float *src_c = src_r + ci;
3908+      float *dst_c = dst_r + ci * C8NUM;
3909+      for (int i = 0; i < C8NUM; i++) {
3910+        src_c[i * col] = dst_c[i];
3911+      }
3912+    }
3913+    src_r += C8NUM * col;
3914+    dst_r += C8NUM * col;
3915+  }
3916+  for (; ri < row; ri++, src_r += col, dst_r++) {
3917+    for (int i = 0; i < col; i++) {
3918+      src_r[i] = dst_r[i * C8NUM];
3919+    }
3920+  }
3921+}
3922+
3923+RecoveryWeightFunc GetRecoveryWeightFunc(const int quant_type, const TypeId data_type, const int node_type,
3924+                                         const std::string &cpu_option) {
3925+  if (cpu_option == kArm64SimdDot && node_type == schema::PrimitiveType_MatMulFusion &&
3926+      quant_type == schema::QuantType_QUANT_DYNAMIC && data_type == kNumberTypeInt8) {
3927+    return MatmulDynamicSdotInt8Unpack;
3928+  }
3929+
3930+  if (cpu_option == kArm64SimdDot && node_type == schema::PrimitiveType_MatMulFusion &&
3931+      data_type == kNumberTypeFloat32) {
3932+    return MatmulFp32BaseUnpack;
3933+  }
3934+  return nullptr;
3935+}
3936+
3937+size_t GetWeightTensorSize(MatmulType matmul_type, const PackInfo &pack_info) {
3938+  size_t data_size = 0;
3939+  if (matmul_type == kMatmulDynamicSdotInt8Cpu) {
3940+    data_size = pack_info.b_batch_ * pack_info.col_align_ * pack_info.deep_align_ * sizeof(int8_t);
3941+  } else {
3942+    data_size = pack_info.b_batch_ * pack_info.col_align_ * pack_info.deep_ * sizeof(float);
3943+  }
3944+  return data_size;
3945+}
3946+
3947+int PackedMatmulKernelExec(kernel::KernelExec *kernel_exec, const std::vector<Tensor *> &tensors) {
3948+  auto pack_info = PackedNodePass::GetInstance().GetNodePackInfo(kernel_exec->name());
3949+  if (pack_info == nullptr) {
3950+    return RET_OK;
3951+  }
3952+  MS_CHECK_TRUE_MSG(kernel_exec->in_tensors().size() >= kInputSize1, lite::RET_ERROR,
3953+                    "kernel doesn't have weight tensor.");
3954+  auto dst_tensor = kernel_exec->in_tensors()[SECOND_INPUT];
3955+  auto kernel = kernel_exec->kernel();
3956+  MS_CHECK_TRUE_MSG(kernel != nullptr, lite::RET_NULL_PTR, "kernel is nullptr.");
3957+  auto param = reinterpret_cast<MatMulParameter *>(kernel_exec->op_parameter());
3958+  if (dst_tensor->data_type() == kNumberTypeFloat32) {
3959+    if (param->matmul_type_ == kNotImplemented) {
3960+      return RecoveryPackedWeight(dst_tensor, static_cast<int>(kernel->quant_type()), dst_tensor->data_type(),
3961+                                  schema::PrimitiveType_MatMulFusion, pack_info);
3962+    }
3963+  }
3964+
3965+  if (dst_tensor->data_type() == kNumberTypeInt8 && param->matmul_type_ != kMatmulDynamicSdotInt8Cpu &&
3966+      pack_info->cpu_option_ == kArm64SimdDot) {
3967+    return RecoveryPackedWeight(dst_tensor, static_cast<int>(kernel->quant_type()), dst_tensor->data_type(),
3968+                                schema::PrimitiveType_MatMulFusion, pack_info);
3969+  }
3970+
3971+  if (PackedNodePass::GetInstance().GetNeedStoreWeight()) {
3972+    size_t data_size = GetWeightTensorSize(param->matmul_type_, *pack_info);
3973+    int ret = PackedNodePass::GetInstance().StoreWeightTensor(dst_tensor, data_size);
3974+    if (ret != RET_OK) {
3975+      MS_LOG(ERROR) << "store weight tensor error.";
3976+      return ret;
3977+    }
3978+  }
3979+  auto lite_kernel = static_cast<kernel::LiteKernel *>(kernel);
3980+  lite::Tensor *weight_sums = nullptr;
3981+  auto index = static_cast<size_t>(pack_info->weight_sums_index_);
3982+  if (index < tensors.size()) {
3983+    weight_sums = tensors.at(index);
3984+  }
3985+  return lite_kernel->PreparePackedWeight(weight_sums);
3986+}
3987+
3988+int RecoveryPackedWeight(Tensor *weight, const int quant_type, const TypeId data_type, const int node_type,
3989+                         PackInfo *pack_info) {
3990+  auto recovery_func = GetRecoveryWeightFunc(quant_type, data_type, node_type, pack_info->cpu_option_);
3991+  if (recovery_func == nullptr) {
3992+    MS_LOG(ERROR) << "unsupported recovery func.";
3993+    return RET_NULL_PTR;
3994+  }
3995+  void *unpack_data = malloc(weight->Size());
3996+  if (unpack_data == nullptr) {
3997+    MS_LOG(ERROR) << "malloc unpack_data failed.";
3998+    return RET_NULL_PTR;
3999+  }
4000+  void *pack_b_ptr = weight->data();
4001+  for (int i = 0; i < pack_info->b_batch_; i++) {
4002+    void *current_weight;
4003+    void *current_b_pack;
4004+    if (weight->data_type() == kNumberTypeInt8) {
4005+      current_weight = static_cast<void *>(static_cast<int8_t *>(unpack_data) + i * pack_info->deep_ * pack_info->col_);
4006+      current_b_pack =
4007+        static_cast<void *>(static_cast<int8_t *>(pack_b_ptr) + i * pack_info->col_align_ * pack_info->deep_align_);
4008+    } else if (weight->data_type() == kNumberTypeFloat32) {
4009+      current_weight = static_cast<void *>(static_cast<float *>(unpack_data) + i * pack_info->deep_ * pack_info->col_);
4010+      current_b_pack =
4011+        static_cast<void *>(static_cast<float *>(pack_b_ptr) + i * pack_info->col_align_ * pack_info->deep_);
4012+    } else {
4013+      free(unpack_data);
4014+      MS_LOG(ERROR) << "unsupported data type.";
4015+      return RET_ERROR;
4016+    }
4017+    recovery_func(current_weight, current_b_pack, pack_info->deep_, pack_info->col_, pack_info->b_transpose_);
4018+  }
4019+  weight->FreeData();
4020+  weight->set_data(unpack_data);
4021+  return RET_OK;
4022+}
4023+
4024+int PackKernelExec(kernel::KernelExec *kernel_exec, const std::vector<Tensor *> &tensors) {
4025+  if (kernel_exec->type() == schema::PrimitiveType_MatMulFusion) {
4026+    return PackedMatmulKernelExec(kernel_exec, tensors);
4027+  }
4028+  return RET_OK;
4029+}
4030+}  // namespace lite
4031+}  // namespace mindspore
4032diff --git a/mindspore/lite/src/runtime/runtime_packed_node_pass.h b/mindspore/lite/src/runtime/runtime_packed_node_pass.h
4033new file mode 100644
4034index 00000000..0ba18eb7
4035--- /dev/null
4036+++ b/mindspore/lite/src/runtime/runtime_packed_node_pass.h
4037@@ -0,0 +1,83 @@
4038+/**
4039+ * Copyright 2023 Huawei Technologies Co., Ltd
4040+ *
4041+ * Licensed under the Apache License, Version 2.0 (the "License");
4042+ * you may not use this file except in compliance with the License.
4043+ * You may obtain a copy of the License at
4044+ *
4045+ * http://www.apache.org/licenses/LICENSE-2.0
4046+ *
4047+ * Unless required by applicable law or agreed to in writing, software
4048+ * distributed under the License is distributed on an "AS IS" BASIS,
4049+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4050+ * See the License for the specific language governing permissions and
4051+ * limitations under the License.
4052+ */
4053+
4054+#ifndef MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_PACKED_NODE_PASS_
4055+#define MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_PACKED_NODE_PASS_
4056+
4057+#include <string>
4058+#include <map>
4059+#include <vector>
4060+#include "src/runtime/lite_model.h"
4061+#include "src/tensor.h"
4062+#include "src/runtime/kernel_exec.h"
4063+
4064+namespace mindspore {
4065+namespace lite {
4066+struct PackInfo {
4067+  bool is_packed_{false};
4068+  int weight_sums_index_{-1};
4069+  int b_batch_;
4070+  int deep_;
4071+  int col_;
4072+  int deep_align_;
4073+  int col_align_;
4074+  bool b_transpose_{false};
4075+  std::string cpu_option_;
4076+};
4077+
4078+class PackedNodePass {
4079+ public:
4080+  static PackedNodePass &GetInstance() {
4081+    static PackedNodePass instance;
4082+    return instance;
4083+  }
4084+
4085+  PackInfo *GetNodePackInfo(const std::string &node_name) {
4086+    if (this->node_pack_info_map_.find(node_name) == this->node_pack_info_map_.end()) {
4087+      return nullptr;
4088+    }
4089+    return this->node_pack_info_map_[node_name];
4090+  }
4091+  void Run(Model *model, const std::vector<Tensor *> &tensors);
4092+  void CopyWeightBiasSumsTensor(Tensor *tensor);
4093+  int StoreWeightTensor(Tensor *tensor, size_t data_size);
4094+  bool GetNeedStoreWeight() const { return need_store_weight_; }
4095+
4096+ protected:
4097+  void AddNodePackInfo(const std::string &node_name, PackInfo *pack_info) {
4098+    if (this->node_pack_info_map_.find(node_name) != this->node_pack_info_map_.end()) {
4099+      MS_LOG(WARNING) << "Key conflict when adding node pack info.";
4100+    }
4101+    this->node_pack_info_map_[node_name] = pack_info;
4102+  }
4103+
4104+ private:
4105+  PackedNodePass() = default;
4106+  ~PackedNodePass();
4107+
4108+ private:
4109+  std::map<std::string, PackInfo *> node_pack_info_map_;
4110+  bool need_store_weight_{false};
4111+};
4112+
4113+int PackKernelExec(kernel::KernelExec *kernel_exec, const std::vector<Tensor *> &tensors);
4114+
4115+// Unpack packed weight data back to its original layout.
4116+int RecoveryPackedWeight(Tensor *weight, const int quant_type, const TypeId data_type, const int node_type,
4117+                         PackInfo *pack_info);
4118+}  // namespace lite
4119+}  // namespace mindspore
4120+#endif  // MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_PACKED_NODE_PASS_
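
PackedNodePass above is a process-wide registry keyed by node name. A usage sketch (the node name "MatMul_0" is hypothetical) of how a runtime kernel might check whether its weights were packed offline:

  #include "src/runtime/runtime_packed_node_pass.h"

  // Sketch: consult the pass before packing weights again at runtime.
  bool WeightAlreadyPacked(const std::string &node_name) {
    auto *info = mindspore::lite::PackedNodePass::GetInstance().GetNodePackInfo(node_name);
    // When packed, weight_sums_index_ locates the precomputed weight-sums tensor.
    return info != nullptr && info->is_packed_;
  }

  // e.g. WeightAlreadyPacked("MatMul_0")
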
4121diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc
4122index bf5de821..887a78e3 100644
4123--- a/mindspore/lite/tools/common/graph_util.cc
4124+++ b/mindspore/lite/tools/common/graph_util.cc
4125@@ -36,7 +36,32 @@ namespace mindspore {
4126 namespace lite {
4127 namespace {
4128 const int kZeroPointGap = 128;
4129+constexpr size_t kTupleGetItemInputSize = 3;
4130+constexpr size_t kSecondIndex = 1;
4131 }  // namespace
4132+static STATUS GetAbstractFromTupleGetItem(const CNodePtr &cnode, AbstractBasePtr *abstract, size_t *idx) {
4133+  MS_CHECK_TRUE_MSG(abstract != nullptr, lite::RET_ERROR, "Abstract is nullptr.");
4134+  MS_CHECK_TRUE_MSG(idx != nullptr, lite::RET_ERROR, "idx is nullptr.");
4135+  auto tuple_inputs = cnode->inputs();
4136+  MS_CHECK_TRUE_MSG(tuple_inputs.size() == kTupleGetItemInputSize, lite::RET_ERROR, "The node must have 3 inputs!");
4137+  auto get_item_input_cnode = tuple_inputs.at(kSecondIndex);
4138+  MS_CHECK_TRUE_MSG(get_item_input_cnode != nullptr, lite::RET_ERROR, "input node is nullptr.");
4139+  *idx = opt::GetTupleGetItemOutIndex(cnode);
4140+  if (!mindspore::utils::isa<mindspore::abstract::AbstractTuplePtr>(get_item_input_cnode->abstract())) {
4141+    MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple, cnode name: "
4142+                  << get_item_input_cnode->fullname_with_scope();
4143+    return lite::RET_ERROR;
4144+  }
4145+  auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(get_item_input_cnode->abstract());
4146+  auto abstract_list = abstract_tuple->elements();
4147+  if (abstract_list.size() <= *idx) {
4148+    MS_LOG(ERROR) << "AbstractTuple's size is smaller than expected.";
4149+    return lite::RET_ERROR;
4150+  }
4151+  *abstract = abstract_list[*idx];
4152+  return lite::RET_OK;
4153+}
4154+
4155 int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &outputs) {
4156   if (graph == nullptr || outputs.empty()) {
4157     MS_LOG(DEBUG) << "Input graph is nullptr or outputs is empty";
4158@@ -483,5 +508,83 @@ int TransferMetaGraph(const schema::MetaGraphT &graph, void **model_buf, size_t
4159   (void)memcpy(*model_buf, content, *size);
4160   return RET_OK;
4161 }
4162+
4163+STATUS GetShapeVectorFromParameter(const mindspore::ParameterPtr &param_node, std::vector<int64_t> *shape_vector) {
4164+  MS_CHECK_TRUE_MSG(shape_vector != nullptr, RET_ERROR, "shape vector is nullptr.");
4165+  auto abstract_base = param_node->abstract();
4166+  if (abstract_base == nullptr) {
4167+    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
4168+    return RET_ERROR;
4169+  }
4170+
4171+  if (!abstract_base->isa<abstract::AbstractTensor>()) {
4172+    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
4173+    return lite::RET_ERROR;
4174+  }
4175+  auto abstract_tensor = abstract_base->cast<abstract::AbstractTensorPtr>();
4176+  MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
4177+  *shape_vector = abstract_tensor->shape()->shape();
4178+  return lite::RET_OK;
4179+}
4180+
4181+
4182+STATUS GetShapeVectorAndIdxFromCNode(const CNodePtr &cnode, std::vector<int64_t> *shape_vector, size_t *idx) {
4183+  MS_CHECK_TRUE_MSG(shape_vector != nullptr, lite::RET_ERROR, "shape is nullptr");
4184+
4185+  AbstractBasePtr cnode_abstract = nullptr;
4186+  if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
4187+    // idx is only used when cnode is type of kPrimTupleGetItem.
4188+    MS_CHECK_TRUE_MSG(idx != nullptr, lite::RET_ERROR, "idx is nullptr");
4189+    if (GetAbstractFromTupleGetItem(cnode, &cnode_abstract, idx) != lite::RET_OK) {
4190+      MS_LOG(ERROR) << "Get abstract from tuple get item failed.";
4191+      return lite::RET_ERROR;
4192+    }
4193+  } else {
4194+    cnode_abstract = cnode->abstract();
4195+  }
4196+  // In control-flow models the abstract may be nullptr; treat it as an unknown (empty) shape.
4197+  if (cnode_abstract == nullptr) {
4198+    *shape_vector = std::vector<int64_t>();
4199+    return lite::RET_OK;
4200+  }
4201+  if (cnode_abstract->BuildShape() == mindspore::abstract::kNoShape) {
4202+    *shape_vector = std::vector<int64_t>();
4203+    return lite::RET_OK;
4204+  }
4205+  if (!utils::isa<mindspore::abstract::AbstractTensorPtr>(cnode_abstract)) {
4206+    MS_LOG(ERROR) << "Abstract is not abstract tensor. " << cnode->fullname_with_scope();
4207+    return lite::RET_ERROR;
4208+  }
4209+  auto cnode_abstract_tensor = cnode_abstract->cast<mindspore::abstract::AbstractTensorPtr>();
4210+  CHECK_NULL_RETURN(cnode_abstract_tensor);
4211+  if (!utils::isa<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape())) {
4212+    MS_LOG(ERROR) << "Shape of abstract tensor should be ShapePtr. " << cnode->fullname_with_scope();
4213+    return lite::RET_ERROR;
4214+  }
4215+  auto shape_ptr = utils::cast<mindspore::abstract::ShapePtr>(cnode_abstract_tensor->BuildShape());
4216+  CHECK_NULL_RETURN(shape_ptr);
4217+  if (shape_ptr->shape().empty()) {
4218+    MS_LOG(WARNING) << "Shape is empty " << cnode->fullname_with_scope();
4219+  }
4220+  *shape_vector = shape_ptr->shape();
4221+  return lite::RET_OK;
4222+}
4223+
4224+STATUS GetCNodeOrParameterShapeVec(const AnfNodePtr &anf_node, std::vector<int> *shape) {
4225+  auto int64_t_to_int_func = [](int64_t x) -> int { return static_cast<int>(x); };
4226+  std::vector<int64_t> in_shape;
4227+  if (anf_node->isa<CNode>()) {
4228+    (void)GetShapeVectorAndIdxFromCNode(anf_node->cast<CNodePtr>(), &in_shape);  // on failure in_shape stays empty
4229+  } else if (anf_node->isa<Parameter>()) {
4230+    auto param_node = anf_node->cast<ParameterPtr>();
4231+    (void)GetShapeVectorFromParameter(param_node, &in_shape);  // on failure in_shape stays empty
4232+  } else {
4233+    MS_LOG(ERROR) << "Node type is not recognized.";
4234+    return RET_ERROR;
4235+  }
4236+  shape->resize(in_shape.size());
4237+  std::transform(in_shape.begin(), in_shape.end(), shape->begin(), int64_t_to_int_func);
4238+  return RET_OK;
4239+}
4240 }  // namespace lite
4241 }  // namespace mindspore
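
GetCNodeOrParameterShapeVec narrows the int64_t shape taken from the abstract down to int. An equivalent standalone sketch of that narrowing step (it assumes, as the converter does, that every dimension fits in int):

  #include <algorithm>
  #include <cstdint>
  #include <vector>

  // Sketch: the int64_t -> int narrowing used by GetCNodeOrParameterShapeVec.
  std::vector<int> NarrowShape(const std::vector<int64_t> &in_shape) {
    std::vector<int> shape(in_shape.size());
    std::transform(in_shape.begin(), in_shape.end(), shape.begin(),
                   [](int64_t x) { return static_cast<int>(x); });
    return shape;
  }
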
4242diff --git a/mindspore/lite/tools/common/graph_util.h b/mindspore/lite/tools/common/graph_util.h
4243index 359af757..be239094 100644
4244--- a/mindspore/lite/tools/common/graph_util.h
4245+++ b/mindspore/lite/tools/common/graph_util.h
4246@@ -89,6 +89,12 @@ STATUS UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr &func_graph);
4247 STATUS UpdateGraphOutputName(schema::MetaGraphT *meta_graph);
4248
4249 int TransferMetaGraph(const schema::MetaGraphT &graph, void **model_buf, size_t *size);
4250+
4251+STATUS GetShapeVectorAndIdxFromCNode(const CNodePtr &cnode, std::vector<int64_t> *shape_vector, size_t *idx = nullptr);
4252+
4253+STATUS GetShapeVectorFromParameter(const mindspore::ParameterPtr &param_node, std::vector<int64_t> *shape_vector);
4254+
4255+STATUS GetCNodeOrParameterShapeVec(const AnfNodePtr &anf_node, std::vector<int> *shape);
4256 }  // namespace lite
4257 }  // namespace mindspore
4258
4259diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
4260index 8ce0304e..9031cb96 100644
4261--- a/mindspore/lite/tools/converter/CMakeLists.txt
4262+++ b/mindspore/lite/tools/converter/CMakeLists.txt
4263@@ -41,6 +41,8 @@ include_directories(${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel)
4264 file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
4265         ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc
4266         ${CMAKE_CURRENT_SOURCE_DIR}/converter.cc
4267+        ${CMAKE_CURRENT_SOURCE_DIR}/offline_packing_optimizer.cc
4268+        ${CMAKE_CURRENT_SOURCE_DIR}/converter_packed_node.cc
4269         ${CMAKE_CURRENT_SOURCE_DIR}/anf_transform.cc
4270         ${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc
4271         ${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc
4272@@ -125,6 +127,7 @@ set(LITE_SRC ${API_SRC}
4273         ${SRC_DIR}/common/ops/anf_utils.cc
4274         ${SRC_DIR}/common/utils.cc
4275         ${SRC_DIR}/common/file_utils.cc
4276+        ${SRC_DIR}/common/mmap_utils.cc
4277         ${SRC_DIR}/common/context_util.cc
4278         ${SRC_DIR}/common/graph_util.cc
4279         ${SRC_DIR}/common/string_util.cc
4280@@ -152,6 +155,7 @@ set(LITE_SRC ${API_SRC}
4281         ${SRC_DIR}/runtime/sub_graph_kernel.cc
4282         ${SRC_DIR}/runtime/sub_graph_split.cc
4283         ${SRC_DIR}/runtime/lite_session.cc
4284+        ${SRC_DIR}/runtime/runtime_packed_node_pass.cc
4285         ${SRC_DIR}/runtime/executor.cc
4286         ${SRC_DIR}/runtime/lite_model.cc
4287         ${SRC_DIR}/errorcode.cc
4288diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
4289index 03cac4c0..a4274202 100644
4290--- a/mindspore/lite/tools/converter/anf_transform.cc
4291+++ b/mindspore/lite/tools/converter/anf_transform.cc
4292@@ -501,6 +501,14 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
4293     return nullptr;
4294   }
4295
4296+  if (!param->cpuOptionCfgParam.architecture.empty()) {
4297+    // Do offline packing for the Android ARM CPU backend.
4298+    if (OfflinePackingOptimizer().Optimize(old_graph, "ANDROID_ARM_CPU") != RET_OK) {
4299+      MS_LOG(ERROR) << "Do offline packing failed.";
4300+      return nullptr;
4301+    }
4302+  }
4303+
4304   return old_graph;
4305 }
4306
4307diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h
4308index 42e26310..8d0f2f5d 100644
4309--- a/mindspore/lite/tools/converter/anf_transform.h
4310+++ b/mindspore/lite/tools/converter/anf_transform.h
4311@@ -27,6 +27,7 @@
4312 #include "ir/anf.h"
4313 #include "tools/converter/quantizer/quantizer.h"
4314 #include "tools/converter/converter_context.h"
4315+#include "tools/converter/offline_packing_optimizer.h"
4316
4317 namespace mindspore {
4318 namespace lite {
4319diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
4320index 03ca2ec4..595ce604 100644
4321--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
4322+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
4323@@ -31,6 +31,7 @@ constexpr auto kRegistry = "registry";
4324 constexpr auto kAclOptionParam = "acl_option_cfg_param";
4325 constexpr auto kMicroParam = "micro_param";
4326 constexpr auto kThirdPartyModelParam = "third_party_model";
4327+constexpr auto kCpuOptionParam = "cpu_option_cfg_param";
4328 }  // namespace
4329 int ConfigFileParser::ParseConfigFile(const std::string &config_file_path) {
4330   std::map<std::string, std::map<std::string, std::string>> maps;
4331@@ -101,6 +102,13 @@ int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::strin
4332     return ret;
4333   }
4334
4335+  ret = ParseCpuOptionCfgString(*maps);
4336+  (void)maps->erase(kCpuOptionParam);
4337+  if (ret != RET_OK) {
4338+    MS_LOG(ERROR) << "ParseCpuOptionCfgString failed.";
4339+    return ret;
4340+  }
4341+
4342   for (const auto &config_info : *maps) {
4343     ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second);
4344   }
4345@@ -152,6 +160,7 @@ int ConfigFileParser::ParseCommonQuantString(const std::map<std::string, std::ma
4346       {"min_quant_weight_channel", common_quant_string_.min_quant_weight_channel},
4347       {"skip_quant_node", common_quant_string_.skip_quant_node},
4348       {"debug_info_save_path", common_quant_string_.debug_info_save_path},
4349+      {"dynamic_quant_strategy", common_quant_string_.dynamic_quant_strategy},
4350     };
4351     return SetMapData(map, parse_map, kCommonQuantParam);
4352   }
4353@@ -253,5 +262,15 @@ int ConfigFileParser::ParseThirdPartyParamString(
4354   };
4355   return SetMapData(input_args, kValidArgs, kThirdPartyModelParam);
4356 }
4357+
4358+int ConfigFileParser::ParseCpuOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
4359+  if (maps.find(kCpuOptionParam) != maps.end()) {
4360+    const auto &map = maps.at(kCpuOptionParam);
4361+    std::map<std::string, std::string &> parse_map{{"architecture", cpu_option_cfg_string_.architecture},
4362+                                                   {"instruction", cpu_option_cfg_string_.instruction}};
4363+    return SetMapData(map, parse_map, kCpuOptionParam);
4364+  }
4365+  return RET_OK;
4366+}
4367 }  // namespace lite
4368 }  // namespace mindspore
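
Together the parser changes above add one section and one key. A minimal converter config excerpt exercising both (these values are the only ones the downstream parsers accept; see cpu_option_param_parser.cc and quant_param_parser.cc below):

  [common_quant_param]
  dynamic_quant_strategy=ACTIVATION_CHANNEL

  [cpu_option_cfg_param]
  architecture=ARM64
  instruction=SIMD_DOT
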
4369diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
4370index c407dcdd..36257b3a 100644
4371--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h
4372+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
4373@@ -45,6 +45,7 @@ struct CommonQuantString {
4374   std::string min_quant_weight_channel;
4375   std::string skip_quant_node;
4376   std::string debug_info_save_path;
4377+  std::string dynamic_quant_strategy;
4378 };
4379
4380 struct MixedBitWeightQuantString {
4381@@ -102,6 +103,11 @@ struct ThirdPartyModelString {
4382   std::string extended_parameters;  // format: {key1:value1;ker2:value2}
4383 };
4384
4385+struct CpuOptionCfgString {
4386+  std::string architecture;
4387+  std::string instruction;
4388+};
4389+
4390 class ConfigFileParser {
4391  public:
4392   int ParseConfigFile(const std::string &config_file_path);
4393@@ -115,6 +121,7 @@ class ConfigFileParser {
4394   AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; }
4395   MicroParamString GetMicroParamString() { return this->micro_param_string_; }
4396   lite::ThirdPartyModelString GetThirdPartyModelString() const { return this->third_party_model_string_; }
4397+  CpuOptionCfgString GetCpuOptionCfgString() { return this->cpu_option_cfg_string_; }
4398
4399  private:
4400   int ParseDataPreProcessString(const std::map<std::string, std::map<std::string, std::string>> &maps);
4401@@ -127,6 +134,7 @@ class ConfigFileParser {
4402                  const std::map<std::string, std::string &> &parse_map, const std::string &section);
4403   int ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps);
4404   int ParseThirdPartyParamString(const std::map<std::string, std::map<std::string, std::string>> &sections);
4405+  int ParseCpuOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &sections);
4406
4407  private:
4408   DataPreProcessString data_pre_process_string_;
4409@@ -137,6 +145,7 @@ class ConfigFileParser {
4410   AclOptionCfgString acl_option_cfg_string_;
4411   MicroParamString micro_param_string_;
4412   lite::ThirdPartyModelString third_party_model_string_;
4413+  CpuOptionCfgString cpu_option_cfg_string_;
4414 };
4415
4416 }  // namespace lite
4417diff --git a/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.cc b/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.cc
4418new file mode 100644
4419index 00000000..41528773
4420--- /dev/null
4421+++ b/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.cc
4422@@ -0,0 +1,41 @@
4423+/**
4424+ * Copyright 2023 Huawei Technologies Co., Ltd
4425+ *
4426+ * Licensed under the Apache License, Version 2.0 (the "License");
4427+ * you may not use this file except in compliance with the License.
4428+ * You may obtain a copy of the License at
4429+ *
4430+ * http://www.apache.org/licenses/LICENSE-2.0
4431+ *
4432+ * Unless required by applicable law or agreed to in writing, software
4433+ * distributed under the License is distributed on an "AS IS" BASIS,
4434+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4435+ * See the License for the specific language governing permissions and
4436+ * limitations under the License.
4437+ */
4438+
4439+#include "tools/converter/config_parser/cpu_option_param_parser.h"
4440+#include "common/log.h"
4441+namespace mindspore {
4442+namespace lite {
4443+STATUS CpuOptionParamParser::ParseCpuOptionCfg(const CpuOptionCfgString &cpu_option_string,
4444+                                               CpuOptionCfg *cpu_option_cfg) {
4445+  if (cpu_option_string.architecture.empty() || cpu_option_string.instruction.empty()) {
4446+    return RET_OK;
4447+  }
4448+
4449+  if (cpu_option_string.architecture != "ARM64") {
4450+    MS_LOG(ERROR) << "cpu architecture only supports ARM64. But got " << cpu_option_string.architecture;
4451+    return RET_INPUT_PARAM_INVALID;
4452+  }
4453+
4454+  if (cpu_option_string.instruction != "SIMD_DOT") {
4455+    MS_LOG(ERROR) << "cpu instruction only supports SIMD_DOT. But got " << cpu_option_string.instruction;
4456+    return RET_INPUT_PARAM_INVALID;
4457+  }
4458+  cpu_option_cfg->instruction = cpu_option_string.instruction;
4459+  cpu_option_cfg->architecture = cpu_option_string.architecture;
4460+  return RET_OK;
4461+}
4462+}  // namespace lite
4463+}  // namespace mindspore
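
A sketch of how this parser is driven (it mirrors the call added to converter.cc below; error handling elided). Note the asymmetry: an absent or empty section is silently accepted, while present-but-invalid values are rejected:

  lite::CpuOptionParamParser parser;
  CpuOptionCfg cfg;
  // RET_INPUT_PARAM_INVALID when architecture != "ARM64" or instruction != "SIMD_DOT".
  STATUS ret = parser.ParseCpuOptionCfg(config_parser.GetCpuOptionCfgString(), &cfg);
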
4464diff --git a/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.h b/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.h
4465new file mode 100644
4466index 00000000..c549477f
4467--- /dev/null
4468+++ b/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.h
4469@@ -0,0 +1,32 @@
4470+/**
4471+ * Copyright 2023 Huawei Technologies Co., Ltd
4472+ *
4473+ * Licensed under the Apache License, Version 2.0 (the "License");
4474+ * you may not use this file except in compliance with the License.
4475+ * You may obtain a copy of the License at
4476+ *
4477+ * http://www.apache.org/licenses/LICENSE-2.0
4478+ *
4479+ * Unless required by applicable law or agreed to in writing, software
4480+ * distributed under the License is distributed on an "AS IS" BASIS,
4481+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4482+ * See the License for the specific language governing permissions and
4483+ * limitations under the License.
4484+ */
4485+
4486+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_CPU_OPTION_PARAM_PARSER_H_
4487+#define MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_CPU_OPTION_PARAM_PARSER_H_
4488+#include <string>
4489+#include "tools/converter/cxx_api/converter_para.h"
4490+#include "tools/converter/config_parser/config_file_parser.h"
4491+#include "include/errorcode.h"
4492+
4493+namespace mindspore {
4494+namespace lite {
4495+class CpuOptionParamParser {
4496+ public:
4497+  STATUS ParseCpuOptionCfg(const CpuOptionCfgString &cpu_option_string, CpuOptionCfg *cpu_option_cfg);
4498+};
4499+}  // namespace lite
4500+}  // namespace mindspore
4501+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_CPU_OPTION_PARAM_PARSER_H_
4502diff --git a/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc b/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc
4503index cc9807e4..c0bd6219 100644
4504--- a/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc
4505+++ b/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc
4506@@ -111,6 +111,11 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
4507   if (!common_quant->debug_info_save_path.empty()) {
4508     common_quant->is_debug = true;
4509   }
4510+
4511+  if (!common_quant_string.dynamic_quant_strategy.empty() &&
4512+      ParseDynamicQuantStrategy(common_quant_string.dynamic_quant_strategy, &common_quant->dynamic_strategy) != RET_OK) {
4513+    return RET_INPUT_PARAM_INVALID;
4514+  }
4515   return RET_OK;
4516 }
4517
4518@@ -210,5 +215,20 @@ int QuantParamParser::ParseActivationQuantizedMethod(const std::string &activati
4519     return RET_INPUT_PARAM_INVALID;
4520   }
4521 }
4522+
4523+int QuantParamParser::ParseDynamicQuantStrategy(const std::string &dynamic_quant_strategy_str,
4524+                                                quant::DynamicQuantStrategy *dynamic_strategy) {
4525+  if (dynamic_quant_strategy_str == "ACTIVATION_LAYER") {
4526+    *dynamic_strategy = quant::ACTIVATION_LAYER;
4527+    return RET_OK;
4528+  } else if (dynamic_quant_strategy_str == "ACTIVATION_CHANNEL") {
4529+    *dynamic_strategy = quant::ACTIVATION_CHANNEL;
4530+    return RET_OK;
4531+  } else {
4532+    MS_LOG(ERROR) << "INPUT ILLEGAL: dynamic_quant_strategy must be ACTIVATION_LAYER or ACTIVATION_CHANNEL.";
4533+    return RET_INPUT_PARAM_INVALID;
4534+  }
4535+  return RET_OK;
4536+}
4537 }  // namespace lite
4538 }  // namespace mindspore
4539diff --git a/mindspore/lite/tools/converter/config_parser/quant_param_parser.h b/mindspore/lite/tools/converter/config_parser/quant_param_parser.h
4540index 4f9e3816..bbf3950c 100644
4541--- a/mindspore/lite/tools/converter/config_parser/quant_param_parser.h
4542+++ b/mindspore/lite/tools/converter/config_parser/quant_param_parser.h
4543@@ -36,6 +36,7 @@ class QuantParamParser {
4544                                             quant::ActivationQuantizedMethod *activation_quant_method);
4545   static int ParseFilter(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant);
4546   static int ParseBitNum(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant);
4547+  static int ParseDynamicQuantStrategy(const std::string &dynamic_quant_strategy_str, quant::DynamicQuantStrategy *dynamic_strategy);
4548 };
4549 }  // namespace lite
4550 }  // namespace mindspore
4551diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
4552index 6177d379..4ca303b5 100644
4553--- a/mindspore/lite/tools/converter/converter.cc
4554+++ b/mindspore/lite/tools/converter/converter.cc
4555@@ -47,6 +47,8 @@
4556 #include "tools/converter/config_parser/third_party_param_parser.h"
4557 #include "tools/common/string_util.h"
4558 #include "src/common/file_utils.h"
4559+#include "tools/converter/converter_packed_node.h"
4560+#include "tools/converter/config_parser/cpu_option_param_parser.h"
4561
4562 namespace mindspore {
4563 extern "C" {
4564@@ -396,6 +398,13 @@ int ConverterImpl::InitConfigParam(const std::shared_ptr<ConverterPara> &param)
4565     MS_LOG(ERROR) << "Parse micro param failed.";
4566     return ret;
4567   }
4568+
4569+  lite::CpuOptionParamParser cpu_param_parser;
4570+  ret = cpu_param_parser.ParseCpuOptionCfg(config_parser.GetCpuOptionCfgString(), &param->cpuOptionCfgParam);
4571+  if (ret != RET_OK) {
4572+    MS_LOG(ERROR) << "Parse cpu option param failed.";
4573+    return ret;
4574+  }
4575   return RET_OK;
4576 }
4577
4578@@ -795,6 +804,16 @@ int RunConverter(const std::shared_ptr<ConverterPara> &param, void **model_data,
4579     status = RET_ERROR;
4580     return status;
4581   }
4582+
4583+  if (!param->cpuOptionCfgParam.architecture.empty()) {
4584+    std::string cpu_option = param->cpuOptionCfgParam.architecture + param->cpuOptionCfgParam.instruction;
4585+    status = ConverterPackedNode(meta_graph, cpu_option);
4586+    if (status != RET_OK) {
4587+      MS_LOG(ERROR) << "save pack info failed.";
4588+      return status;
4589+    }
4590+  }
4591+
4592   //   save graph to file
4593   meta_graph->version = Version();
4594
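
The two converter.cc hunks above wire the feature end to end. A condensed sketch of the resulting flow (simplified, names as in this patch):

  // 1. InitConfigParam: [cpu_option_cfg_param] -> param->cpuOptionCfgParam.
  // 2. AnfTransform: OfflinePackingOptimizer pre-packs MatMul weights
  //    (see offline_packing_optimizer.cc below).
  // 3. RunConverter: rewrite the packed MatMulFusion nodes into Custom nodes
  //    before the graph is serialized.
  std::string cpu_option = param->cpuOptionCfgParam.architecture + param->cpuOptionCfgParam.instruction;
  int status = ConverterPackedNode(meta_graph, cpu_option);  // cpu_option is e.g. "ARM64SIMD_DOT"
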
4595diff --git a/mindspore/lite/tools/converter/converter_packed_node.cc b/mindspore/lite/tools/converter/converter_packed_node.cc
4596new file mode 100644
4597index 00000000..f632fec3
4598--- /dev/null
4599+++ b/mindspore/lite/tools/converter/converter_packed_node.cc
4600@@ -0,0 +1,179 @@
4601+/**
4602+ * Copyright 2023 Huawei Technologies Co., Ltd
4603+ *
4604+ * Licensed under the Apache License, Version 2.0 (the "License");
4605+ * you may not use this file except in compliance with the License.
4606+ * You may obtain a copy of the License at
4607+ *
4608+ * http://www.apache.org/licenses/LICENSE-2.0
4609+ *
4610+ * Unless required by applicable law or agreed to in writing, software
4611+ * distributed under the License is distributed on an "AS IS" BASIS,
4612+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4613+ * See the License for the specific language governing permissions and
4614+ * limitations under the License.
4615+ */
4616+
4617+#include <vector>
4618+#include <memory>
4619+#include <utility>
4620+#include "tools/converter/converter_packed_node.h"
4621+#include "tools/converter/offline_packing_optimizer.h"
4622+#include "src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h"
4623+#include "mindspore/core/ops/op_name.h"
4624+#include "src/runtime/kernel/cpu/fp32/matmul_fp32.h"
4625+
4626+namespace mindspore {
4627+namespace {
4628+constexpr auto kMatmulCustomType = "MatmulFusionPacked";
4629+}
4630+
4631+namespace lite {
4632+void AddCustomAttr(std::vector<std::unique_ptr<mindspore::schema::AttributeT>> *attrs, const std::string &key,
4633+                   const std::string &value) {
4634+  auto attr = std::make_unique<schema::AttributeT>();
4635+  attr->name = key;
4636+  std::vector<uint8_t> attr_data(value.begin(), value.end());
4637+  attr->data = attr_data;
4638+  attrs->emplace_back(std::move(attr));
4639+}
4640+
4641+int AddWeightSumsToInputs(const mindspore::kernel::MatmulDynamicBaseInt8CPUKernel *matmul_kernel,
4642+                          schema::MetaGraphT *meta_graph, const std::unique_ptr<schema::CNodeT> &cnode,
4643+                          size_t weight_sum_size) {
4644+  auto weight_sums_tensor = std::make_unique<schema::TensorT>();
4645+  weight_sums_tensor->nodeType = lite::NodeType_ValueNode;
4646+  weight_sums_tensor->format = schema::Format_NHWC;
4647+  weight_sums_tensor->dataType = TypeId::kNumberTypeInt32;
4648+  weight_sums_tensor->dims = {};
4649+  weight_sums_tensor->dims.emplace_back(static_cast<int>(weight_sum_size / sizeof(int)));
4650+  weight_sums_tensor->data.resize(weight_sum_size);
4651+  weight_sums_tensor->name = cnode->name + "_weight_sums";
4652+  if (memcpy_s(weight_sums_tensor->data.data(), weight_sums_tensor->data.size(), matmul_kernel->GetWeightSums(),
4653+               weight_sum_size) != EOK) {
4654+    MS_LOG(ERROR) << "new CustomT error.";
4655+    return RET_ERROR;
4656+  }
4657+  cnode->inputIndex.emplace_back(meta_graph->allTensors.size());
4658+  meta_graph->allTensors.emplace_back(std::move(weight_sums_tensor));
4659+  return RET_OK;
4660+}
4661+
4662+int ReplaceMatMulFusionToCustom(schema::MetaGraphT *meta_graph, const std::unique_ptr<schema::CNodeT> &cnode,
4663+                                const std::unique_ptr<mindspore::schema::TensorT> &b_input,
4664+                                const std::string &cpu_option) {
4665+  auto lite_kernel = PackDataWrapper::GetInstance().GetPackedKernel(cnode->name);
4666+  if (lite_kernel == nullptr) {
4667+    MS_LOG(ERROR) << "Get Packed Kernel error.";
4668+    return RET_ERROR;
4669+  }
4670+  auto param = lite_kernel->op_parameter();
4671+  if (param == nullptr) {
4672+    MS_LOG(ERROR) << "param is nullptr.";
4673+    return RET_ERROR;
4674+  }
4675+  auto matmul_param = reinterpret_cast<MatMulParameter *>(param);
4676+  if (matmul_param->matmul_type_ == kNotImplemented) {
4677+    MS_LOG(WARNING) << "Unsupported matmul type, only support dynamic quant int8.";
4678+    return RET_OK;
4679+  }
4680+  cnode->primitive->value.type = schema::PrimitiveType_Custom;
4681+  auto primitive = new (std::nothrow) schema::CustomT;
4682+  if (primitive == nullptr) {
4683+    MS_LOG(ERROR) << "new CustomT error.";
4684+    return RET_NULL_PTR;
4685+  }
4686+  primitive->type = kMatmulCustomType;
4687+
4688+  // activation_type
4689+  AddCustomAttr(&(primitive->attr), ops::kActivationType, std::to_string(matmul_param->act_type_));
4690+  // transpose_a
4691+  AddCustomAttr(&(primitive->attr), ops::kTransposeA, std::to_string(matmul_param->a_transpose_));
4692+  // transpose_b
4693+  AddCustomAttr(&(primitive->attr), ops::kTransposeB, std::to_string(matmul_param->b_transpose_));
4694+
4695+  int b_batch;
4696+  const void *pack_b_ptr = nullptr;
4697+  size_t pack_b_size;
4698+  if (matmul_param->matmul_type_ == kMatmulDynamicSdotInt8Cpu) {
4699+    // replace packed data
4700+    auto matmul_kernel = reinterpret_cast<const mindspore::kernel::MatmulDynamicBaseInt8CPUKernel *>(lite_kernel);
4701+    b_batch = matmul_kernel->GetBBatch();
4702+    pack_b_size = b_batch * matmul_param->col_align_ * matmul_param->deep_align_ * sizeof(int8_t);
4703+    pack_b_ptr = reinterpret_cast<const void *>(matmul_kernel->GetPackBPtr());
4704+    auto weight_sum_size = b_batch * matmul_param->col_align_ * sizeof(int);
4705+    int ret = AddWeightSumsToInputs(matmul_kernel, meta_graph, cnode, weight_sum_size);
4706+    if (ret != RET_OK) {
4707+      delete primitive;
4708+      MS_LOG(ERROR) << "add weight sums to inputs error.";
4709+      return ret;
4710+    }
4711+  } else {
4712+    MS_LOG(ERROR) << "matmul_type is error.";
4713+    return RET_ERROR;
4714+  }
4715+
4716+  if (pack_b_ptr == nullptr) {
4717+    delete primitive;
4718+    MS_LOG(ERROR) << "pack_b_ptr is nullptr.";
4719+    return RET_NULL_PTR;
4720+  }
4721+
4722+  // copy packed weight to meta graph
4723+  b_input->data.resize(pack_b_size);
4724+  if (memcpy_s(b_input->data.data(), b_input->data.size(), pack_b_ptr, pack_b_size) != EOK) {
4725+    delete primitive;
4726+    MS_LOG(ERROR) << "memcpy packed weight error.";
4727+    return RET_ERROR;
4728+  }
4729+
4730+  // add scalar to attr
4731+  AddCustomAttr(&(primitive->attr), "b_batch", std::to_string(b_batch));
4732+  AddCustomAttr(&(primitive->attr), "deep", std::to_string(matmul_param->deep_));
4733+  AddCustomAttr(&(primitive->attr), "col", std::to_string(matmul_param->col_));
4734+  AddCustomAttr(&(primitive->attr), "col_align", std::to_string(matmul_param->col_align_));
4735+  AddCustomAttr(&(primitive->attr), "deep_align", std::to_string(matmul_param->deep_align_));
4736+
4737+  // add cpu option
4738+  // cpu_option (architecture + instruction) records which packing layout was used
4739+  AddCustomAttr(&(primitive->attr), "cpu_option", cpu_option);
4740+
4741+  cnode->primitive->value.value = primitive;
4742+  return RET_OK;
4743+}
4744+
4745+int ConverterPackedNode(schema::MetaGraphT *meta_graph, const std::string &cpu_option) {
4746+  for (auto &dst_node : meta_graph->nodes) {
4747+    if (dst_node->primitive == nullptr || dst_node->primitive->value.type != schema::PrimitiveType_MatMulFusion) {
4748+      continue;
4749+    }
4750+    MS_CHECK_TRUE_MSG(dst_node->inputIndex.size() >= kInputSize1, RET_ERROR, "inputs size is wrong.");
4751+    auto a_index = dst_node->inputIndex[FIRST_INPUT];
4752+    MS_CHECK_TRUE_MSG(meta_graph->allTensors.size() > a_index, RET_ERROR, "allTensors size is wrong.");
4753+    auto &a_input = meta_graph->allTensors.at(a_index);
4754+    CHECK_NULL_RETURN(a_input);
4755+
4756+    auto b_index = dst_node->inputIndex[SECOND_INPUT];
4757+    MS_CHECK_TRUE_MSG(meta_graph->allTensors.size() > b_index, RET_ERROR, "allTensors size is wrong.");
4758+    auto &b_input = meta_graph->allTensors.at(b_index);
4759+    CHECK_NULL_RETURN(b_input);
4760+
4761+    if (a_input->dataType != b_input->dataType) {
4762+      MS_LOG(ERROR) << "input dataTypes are not the same: " << a_input->dataType << " vs " << b_input->dataType;
4763+      return RET_ERROR;
4764+    }
4765+
4766+    if (b_input->data.empty()) {
4767+      continue;
4768+    }
4769+    auto ret = ReplaceMatMulFusionToCustom(meta_graph, dst_node, b_input, cpu_option);
4770+    if (ret != RET_OK) {
4771+      MS_LOG(ERROR) << "ReplaceMatMulFusionToCustom error.";
4772+      return ret;
4773+    }
4774+  }
4775+
4776+  return RET_OK;
4777+}
4778+}  // namespace lite
4779+}  // namespace mindspore
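
ReplaceMatMulFusionToCustom stores every scalar attribute ("b_batch", "deep", "col", ...) as the raw bytes of its decimal string via AddCustomAttr, so a consumer of the Custom node has to decode them back. A minimal sketch:

  #include <cstdint>
  #include <string>
  #include <vector>

  // Sketch: decode a scalar attr written by AddCustomAttr.
  int DecodeIntAttr(const std::vector<uint8_t> &data) {
    return std::stoi(std::string(data.begin(), data.end()));
  }
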
4780diff --git a/mindspore/lite/tools/converter/converter_packed_node.h b/mindspore/lite/tools/converter/converter_packed_node.h
4781new file mode 100644
4782index 00000000..cee891fa
4783--- /dev/null
4784+++ b/mindspore/lite/tools/converter/converter_packed_node.h
4785@@ -0,0 +1,29 @@
4786+/**
4787+ * Copyright 2023 Huawei Technologies Co., Ltd
4788+ *
4789+ * Licensed under the Apache License, Version 2.0 (the "License");
4790+ * you may not use this file except in compliance with the License.
4791+ * You may obtain a copy of the License at
4792+ *
4793+ * http://www.apache.org/licenses/LICENSE-2.0
4794+ *
4795+ * Unless required by applicable law or agreed to in writing, software
4796+ * distributed under the License is distributed on an "AS IS" BASIS,
4797+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4798+ * See the License for the specific language governing permissions and
4799+ * limitations under the License.
4800+ */
4801+
4802+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERT_PACKED_NODE_H
4803+#define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERT_PACKED_NODE_H
4804+
4805+#include <string>
4806+#include "schema/inner/model_generated.h"
4807+
4808+namespace mindspore {
4809+namespace lite {
4810+int ConverterPackedNode(schema::MetaGraphT *meta_graph, const std::string &cpu_option);
4811+}  // namespace lite
4812+}  // namespace mindspore
4813+
4814+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_CONVERT_PACKED_NODE_H
4815diff --git a/mindspore/lite/tools/converter/cxx_api/converter_para.h b/mindspore/lite/tools/converter/cxx_api/converter_para.h
4816index 00b7fa3c..1d7d1aec 100644
4817--- a/mindspore/lite/tools/converter/cxx_api/converter_para.h
4818+++ b/mindspore/lite/tools/converter/cxx_api/converter_para.h
4819@@ -48,6 +48,11 @@ struct ThirdPartyModelParam {
4820   std::map<std::string, std::vector<uint8_t>> extended_parameters;
4821 };
4822
4823+struct CpuOptionCfg {
4824+  std::string architecture;
4825+  std::string instruction;
4826+};
4827+
4828 struct ConverterPara {
4829   converter::FmkType fmk_type;
4830   std::string model_file;
4831@@ -82,6 +87,7 @@ struct ConverterPara {
4832   lite::micro::MicroParam microParam;
4833   ParallelSplitConfig parallel_split_config;
4834   ThirdPartyModelParam thirdPartyModelParam;
4835+  CpuOptionCfg cpuOptionCfgParam;
4836 };
4837 }  // namespace mindspore
4838 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_
4839diff --git a/mindspore/lite/tools/converter/offline_packing_optimizer.cc b/mindspore/lite/tools/converter/offline_packing_optimizer.cc
4840new file mode 100644
4841index 00000000..d9a62c15
4842--- /dev/null
4843+++ b/mindspore/lite/tools/converter/offline_packing_optimizer.cc
4844@@ -0,0 +1,307 @@
4845+/**
4846+ * Copyright 2023 Huawei Technologies Co., Ltd
4847+ *
4848+ * Licensed under the Apache License, Version 2.0 (the "License");
4849+ * you may not use this file except in compliance with the License.
4850+ * You may obtain a copy of the License at
4851+ *
4852+ * http://www.apache.org/licenses/LICENSE-2.0
4853+ *
4854+ * Unless required by applicable law or agreed to in writing, software
4855+ * distributed under the License is distributed on an "AS IS" BASIS,
4856+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4857+ * See the License for the specific language governing permissions and
4858+ * limitations under the License.
4859+ */
4860+
4861+#include <memory>
4862+#include <algorithm>
4863+#include <vector>
4864+#include <string>
4865+#include <map>
4866+#include "tools/common/graph_util.h"
4867+#include "tools/converter/offline_packing_optimizer.h"
4868+#include "tools/converter/quantizer/quantize_util.h"
4869+#include "src/common/primitive_t_utils.h"
4870+#include "src/common/ops/anf_utils.h"
4871+#include "src/common/file_utils.h"
4872+#include "nnacl/matmul_parameter.h"
4873+#include "src/runtime//kernel/cpu/int8/matmul_dynamic_base_int8.h"
4874+#include "tools/optimizer/common/gllo_utils.h"
4875+
4876+using mindspore::kernel::MatmulDynamicBaseInt8CPUKernel;
4877+
4878+namespace mindspore::lite {
4879+namespace {
4880+constexpr const int kPrimIndex = 0;
4881+constexpr const int kSingleThread = 1;
4882+const char kAndroidArmCpuBackendOption[] = "ANDROID_ARM_CPU";
4883+}  // namespace
4884+
4885+mindspore::lite::InnerContext *InitInnerContextForAndroidArmCpu() {
4886+  // If an operation uses the thread pool of this inner context, it will throw an exception.
4887+  auto inner_context = new (std::nothrow) lite::InnerContext();
4888+  MS_CHECK_TRUE_MSG(inner_context != nullptr, nullptr, "Create InnerContext failed.");
4889+  inner_context->Init();
4890+  inner_context->thread_num_ = kSingleThread;
4891+  inner_context->instructions_ctx_.support_sdot = true;
4892+  return inner_context;
4893+}
4894+
4895+schema::PrimitiveType GetSchemaPrimitiveType(const AnfNodePtr &node) {
4896+  auto primitive_t = GetPrimitiveT(node);
4897+  if (primitive_t == nullptr) {
4898+    MS_LOG(ERROR) << "Failed to generate PrimitiveT.";
4899+    return schema::PrimitiveType::PrimitiveType_NONE;
4900+  }
4901+  return GetSchemaPrimType(primitive_t.get());
4902+}
4903+
4904+STATUS CreateMatmulPackDataIntoTable(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
4905+                                     OpParameter *op_parameter, const kernel::KernelKey &desc,
4906+                                     const mindspore::lite::InnerContext *ctx) {
4907+  if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
4908+    MS_LOG(ERROR) << op_parameter->name_ << " is not supported.";
4909+    return RET_ERROR;
4910+  }
4911+
4912+  kernel::LiteKernel *kernel =
4913+    KernelRegistry::GetInstance()->GetLiteKernel(in_tensors, out_tensors, ctx, desc, op_parameter);
4914+  if (kernel == nullptr) {
4915+    MS_LOG(ERROR) << "Anf node cannot be nullptr.";
4916+    return RET_ERROR;
4917+  }
4918+  kernel->set_name(op_parameter->name_);
4919+
4920+  if (kernel->Prepare() != RET_OK) {
4921+    MS_LOG(ERROR) << "Failed to generate pack data for " << op_parameter->name_ << ".";
4922+    return RET_ERROR;
4923+  }
4924+
4925+  PackDataWrapper::GetInstance().AddPackedKernel(op_parameter->name_, kernel);
4926+  return RET_OK;
4927+}
4928+
4929+schema::QuantType GetQuantType(const CNodePtr &cnode) {
4930+  MS_CHECK_TRUE_MSG(cnode != nullptr, schema::QuantType::QuantType_QUANT_NONE, "cnode cannot be nullptr.");
4931+  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
4932+  if (primitive == nullptr) {
4933+    MS_LOG(INFO) << "primitive is nullptr";
4934+    return schema::QuantType::QuantType_QUANT_NONE;
4935+  }
4936+  auto quant_param_holder = quant::GetCNodeQuantHolder(primitive);
4937+  if (quant_param_holder != nullptr) {
4938+    return quant_param_holder->quant_type();
4939+  }
4940+  return schema::QuantType::QuantType_QUANT_NONE;
4941+}
4942+
4943+TypeId GetDataType(const CNodePtr &cnode, const std::vector<Tensor *> &in_tensors,
4944+                   const std::vector<Tensor *> &out_tensors) {
4945+  if (in_tensors.empty()) {
4946+    MS_LOG(ERROR) << "in tensor is empty.";
4947+    return kTypeUnknown;
4948+  }
4949+
4950+  // Currently, fp16 is not a supported option.
4951+  TypeId data_type =
4952+    in_tensors[0]->data_type() == kObjectTypeTensorType ? kNumberTypeFloat32 : in_tensors[0]->data_type();
4953+  // Weight-only quant keeps float activations, so fall back to fp32 compute below.
4954+  auto quant_type = GetQuantType(cnode);
4955+  if (quant_type == schema::QuantType_QUANT_WEIGHT) {
4956+    data_type =
4957+      in_tensors.front()->data_type() == kNumberTypeBool ? TypeId::kNumberTypeBool : TypeId::kNumberTypeFloat32;
4958+  }
4959+  return data_type;
4960+}
4961+
4962+void QuantParamTToQuantParam(const schema::QuantParamT &quant_param_t, lite::LiteQuantParam *quant_param) {
4963+  quant_param->inited = true;
4964+  quant_param->bitNum = quant_param_t.numBits;
4965+  quant_param->scale = quant_param_t.scale;
4966+  quant_param->zeroPoint = quant_param_t.zeroPoint;
4967+  quant_param->var_corr = quant_param_t.varCorr;
4968+  quant_param->mean_corr = quant_param_t.meanCorr;
4969+  quant_param->roundType = quant_param_t.roundType;
4970+  quant_param->multiplier = quant_param_t.multiplier;
4971+  quant_param->dstDtype = quant_param_t.dstDtype;
4972+  quant_param->min = quant_param_t.min;
4973+  quant_param->max = quant_param_t.max;
4974+}
4975+
4976+void AddQuantParams(Tensor *in_tensor, const std::vector<schema::QuantParamT> &quant_param_t) {
4977+  std::vector<lite::LiteQuantParam> lite_quant_params(quant_param_t.size());
4978+  for (size_t i = 0; i < lite_quant_params.size(); i++) {
4979+    QuantParamTToQuantParam(quant_param_t[i], &lite_quant_params[i]);
4980+  }
4981+  in_tensor->set_quant_params(lite_quant_params);
4982+}
4983+
4984+STATUS CreateLiteTensor(const CNodePtr &cnode, std::vector<Tensor *> *in_tensors, std::vector<Tensor *> *out_tensors) {
4985+  std::vector<int> shape(0);
4986+  mindspore::TypeId type_id = TypeId::kTypeUnknown;
4987+  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
4988+  if (primitive == nullptr) {
4989+    MS_LOG(INFO) << "primitive is nullptr";
4990+    return RET_ERROR;
4991+  }
4992+  auto quant_param_holder = quant::GetCNodeQuantHolder(primitive);
4993+  std::vector<std::vector<schema::QuantParamT>> input_quant_params_vec;
4994+  std::vector<std::vector<schema::QuantParamT>> output_quant_params_vec;
4995+  if (quant_param_holder != nullptr) {
4996+    input_quant_params_vec = quant_param_holder->get_input_quant_params();
4997+    output_quant_params_vec = quant_param_holder->get_output_quant_params();
4998+  }
4999+
5000+  // Generate input tensor.
5001+  for (size_t i = kPrimIndex + 1; i < cnode->inputs().size(); i++) {
5002+    if (opt::GetDataTypeFromAnfNode(cnode->input(i), &type_id) != RET_OK) {
5003+      MS_LOG(ERROR) << "Cannot get data type from " << cnode->input(i)->fullname_with_scope();
5004+      return RET_ERROR;
5005+    }
5006+    void *tensor_data = nullptr;
5007+    Category category = cnode->input(i)->isa<Parameter>() ? lite::Category::CONST_TENSOR : lite::Category::VAR;
5008+
5009+    MS_CHECK_TRUE_MSG(GetCNodeOrParameterShapeVec(cnode->input(i), &shape) == RET_OK, RET_ERROR,
5010+                      "Infer shape must be done when using offline packing.");
5011+    MS_CHECK_TRUE_MSG(!shape.empty(), RET_ERROR, "Infer shape must be done when using offline packing.");
5012+    // Get tensor data from parameter node.
5013+    if (cnode->input(i)->isa<Parameter>()) {
5014+      auto param_node = cnode->input(i)->cast<ParameterPtr>();
5015+      if (param_node->has_default()) {
5016+        auto tensor_info = std::static_pointer_cast<tensor::Tensor>(param_node->default_param());
5017+        tensor_data = tensor_info->data().data();
5018+      }
5019+    }
5020+    auto in_tensor = new (std::nothrow) Tensor(type_id, shape);
5021+    MS_CHECK_TRUE_MSG(in_tensor != nullptr, RET_ERROR, "Create input tensor failed.");
5022+    in_tensor->set_category(category);
5023+    // Tensor data is managed by funcGraph.
5024+    in_tensor->set_data(tensor_data);
5025+    in_tensor->set_own_data(false);
5026+    // Setup quant params.
5027+    if (type_id == TypeId::kNumberTypeInt8 && !input_quant_params_vec.empty()) {
5028+      AddQuantParams(in_tensor, input_quant_params_vec.front());
5029+      input_quant_params_vec.erase(input_quant_params_vec.begin());
5030+    }
5031+    in_tensors->emplace_back(in_tensor);
5032+    shape.clear();
5033+    type_id = TypeId::kTypeUnknown;
5034+  }
5035+
5036+  if (!input_quant_params_vec.empty()) {
5037+    MS_LOG(WARNING) << cnode->fullname_with_scope() << " quant params' count is not equal to inputs' size.";
5038+  }
5039+
5040+  // Generate output tensor.
5041+  MS_CHECK_TRUE_MSG(GetCNodeOrParameterShapeVec(cnode, &shape) == RET_OK, RET_ERROR,
5042+                    "Infer shape must be done when using offline packing.");
5043+  MS_CHECK_TRUE_MSG(!shape.empty(), RET_ERROR, "Infer shape must be done when using offline packing.");
5044+  if (opt::GetDataTypeFromAnfNode(cnode, &type_id) != RET_OK) {
5045+    MS_LOG(ERROR) << "Cannot get data type from " + cnode->fullname_with_scope() + ".";
5046+    return RET_ERROR;
5047+  }
5048+  auto out_tensor = new (std::nothrow) Tensor(type_id, shape);
5049+  MS_CHECK_TRUE_MSG(out_tensor != nullptr, RET_ERROR, "Create output tensor failed.");
5050+  if (type_id == TypeId::kNumberTypeInt8 && !output_quant_params_vec.empty()) {
5051+    AddQuantParams(out_tensor, output_quant_params_vec.front());
5052+    output_quant_params_vec.erase(output_quant_params_vec.begin());
5053+  }
5054+  out_tensors->emplace_back(out_tensor);
5055+
5056+  if (in_tensors->size() != cnode->inputs().size() - 1 || out_tensors->empty()) {
5057+    MS_LOG(ERROR) << "Failed to populate input tensors for " << cnode->fullname_with_scope() << ".";
5058+    return RET_ERROR;
5059+  }
5060+
5061+  return RET_OK;
5062+}
5063+
5064+STATUS MatmulPacking(const mindspore::CNodePtr &cnode_ptr, const FuncGraphPtr &funcGraphPtr,
5065+                     const lite::InnerContext *ctx) {
5066+  if (cnode_ptr == nullptr) {
5067+    MS_LOG(ERROR) << "Matmul node cannot be nullptr.";
5068+    return RET_ERROR;
5069+  }
5070+  auto primT = mindspore::lite::GetPrimitiveT(cnode_ptr->input(kPrimIndex));
5071+  if (primT == nullptr) {
5072+    MS_LOG(ERROR) << "Failed to generate PrimitiveT for " << cnode_ptr->fullname_with_scope() << ".";
5073+    return RET_ERROR;
5074+  }
5075+  OpParameter *op_parameter = GetOpParameter(primT.get());
5076+  if (op_parameter == nullptr) {
5077+    MS_LOG(ERROR) << "Failed to generate op parameter for " << cnode_ptr->fullname_with_scope() << ".";
5078+    return RET_ERROR;
5079+  }
5080+  op_parameter->thread_num_ = kSingleThread;
5081+  op_parameter->quant_type_ = GetQuantType(cnode_ptr);
5082+
5083+  constexpr size_t max_name_len = 100;
5084+  if (memcpy_s(op_parameter->name_, max_name_len, cnode_ptr->fullname_with_scope().c_str(),
5085+               cnode_ptr->fullname_with_scope().length()) != EOK) {
5086+    MS_LOG(ERROR) << "Set op parameter name failed.";
5087+    return RET_ERROR;
5088+  }
5089+
5090+  std::vector<Tensor *> in_tensors;
5091+  std::vector<Tensor *> out_tensors;
5092+  if (CreateLiteTensor(cnode_ptr, &in_tensors, &out_tensors) != RET_OK) {
5093+    MS_LOG(ERROR) << "Failed to populate input tensors for " << cnode_ptr->fullname_with_scope() << ".";
5094+    return RET_ERROR;
5095+  }
5096+
5097+  TypeId data_type = GetDataType(cnode_ptr, in_tensors, out_tensors);
5098+  MS_CHECK_TRUE_MSG(data_type != TypeId::kTypeUnknown, RET_ERROR,
5099+                    "Can't get data type from " + cnode_ptr->fullname_with_scope() + ".");
5100+  kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, NHWC, op_parameter->type_};
5101+
5102+  return CreateMatmulPackDataIntoTable(in_tensors, out_tensors, op_parameter, desc, ctx);
5103+}
5104+
5105+BackendType FindBackend(const std::string &target_backend) {
5106+  if (target_backend == std::string(kAndroidArmCpuBackendOption)) {
5107+    return BackendType::kAndroidArmCpuBackend;
5108+  }
5109+  return BackendType::kUnknownBackend;
5110+}
5111+
5112+STATUS OfflinePackingOptimizer::Optimize(const FuncGraphPtr &func_graph, const std::string &target_backend) {
5113+  BackendType backend = FindBackend(target_backend);
5114+  if (backend == BackendType::kUnknownBackend ||
5115+      this->packing_strategies_selector_.find(backend) == this->packing_strategies_selector_.end() ||
5116+      this->ctx_creator_selector_.find(backend) == this->ctx_creator_selector_.end()) {
5117+    MS_LOG(ERROR) << target_backend << " does not support offline packing.";
5118+    return RET_ERROR;
5119+  }
5120+
5121+  // Get built-in backend optimizer.
5122+  std::map<schema::PrimitiveType, OfflinePackingFunc> selected_backend_op_cvt =
5123+    this->packing_strategies_selector_[backend];
5124+  mindspore::lite::InnerContext *inner_context = this->ctx_creator_selector_[backend]();
5125+  MS_CHECK_TRUE_MSG(inner_context != nullptr, RET_ERROR, "Failed to initialize runtime context.");
5126+
5127+  auto anf_nodes = mindspore::TopoSort(func_graph->get_return());
5128+  for (auto &anf_node : anf_nodes) {
5129+    if (!utils::isa<CNodePtr>(anf_node)) {
5130+      continue;
5131+    }
5132+    if (mindspore::opt::CheckPrimitiveType(anf_node, prim::kPrimReturn) ||
5133+        mindspore::opt::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple) ||
5134+        mindspore::opt::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
5135+      continue;
5136+    }
5137+    auto cnode = anf_node->cast<CNodePtr>();
5138+    schema::PrimitiveType op_type = GetSchemaPrimitiveType(cnode->input(kPrimIndex));
5139+    if (selected_backend_op_cvt.find(op_type) != selected_backend_op_cvt.end()) {
5140+      OfflinePackingFunc packing_func = selected_backend_op_cvt[op_type];
5141+      if (packing_func(cnode, func_graph, inner_context) != RET_OK) {
5142+        MS_LOG(ERROR) << "Failed to pack for " << anf_node->fullname_with_scope();
5143+        delete inner_context;
5144+        return RET_ERROR;
5145+      }
5146+    }
5147+  }
5148+  delete inner_context;
5149+  return RET_OK;
5150+}
5151+}  // namespace mindspore::lite
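
Optimize dispatches on (backend, primitive type) through the two maps filled in the OfflinePackingOptimizer constructor (see the header below). Extending coverage would be one more map entry; a sketch (Conv2DFusion support is hypothetical, only MatMulFusion is wired up in this patch):

  this->packing_strategies_selector_[BackendType::kAndroidArmCpuBackend] =
    std::map<schema::PrimitiveType, OfflinePackingFunc>{
      {schema::PrimitiveType::PrimitiveType_MatMulFusion, MatmulPacking},
      // {schema::PrimitiveType::PrimitiveType_Conv2DFusion, Conv2dPacking},
    };
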
5152diff --git a/mindspore/lite/tools/converter/offline_packing_optimizer.h b/mindspore/lite/tools/converter/offline_packing_optimizer.h
5153new file mode 100644
5154index 00000000..2590f542
5155--- /dev/null
5156+++ b/mindspore/lite/tools/converter/offline_packing_optimizer.h
5157@@ -0,0 +1,87 @@
5158+/**
5159+ * Copyright 2023 Huawei Technologies Co., Ltd
5160+ *
5161+ * Licensed under the Apache License, Version 2.0 (the "License");
5162+ * you may not use this file except in compliance with the License.
5163+ * You may obtain a copy of the License at
5164+ *
5165+ * http://www.apache.org/licenses/LICENSE-2.0
5166+ *
5167+ * Unless required by applicable law or agreed to in writing, software
5168+ * distributed under the License is distributed on an "AS IS" BASIS,
5169+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5170+ * See the License for the specific language governing permissions and
5171+ * limitations under the License.
5172+ */
5173+
5174+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_OFFLINE_PACKING_OPTIMIZER_H_
5175+#define MINDSPORE_LITE_TOOLS_CONVERTER_OFFLINE_PACKING_OPTIMIZER_H_
5176+#include <string>
5177+#include <map>
5178+#include "base/base.h"
5179+#include "ir/anf.h"
5180+#include "ops/core_ops.h"
5181+#include "runtime/lite_kernel.h"
5182+#include "runtime/kernel_registry.h"
5183+
5184+namespace mindspore::lite {
5185+using OfflinePackingFunc = STATUS (*)(const mindspore::CNodePtr &cnode_ptr, const FuncGraphPtr &funcGraphPtr,
5186+                                      const lite::InnerContext *ctx);
5187+using InnerContextCreatorFunc = mindspore::lite::InnerContext *(*)();
5188+
5189+STATUS MatmulPacking(const mindspore::CNodePtr &cnode_ptr, const FuncGraphPtr &funcGraphPtr,
5190+                     const lite::InnerContext *ctx);
5191+mindspore::lite::InnerContext *InitInnerContextForAndroidArmCpu();
5192+
5193+enum class BackendType : uint8_t {
5194+  kUnknownBackend = 0,
5195+  kAndroidArmCpuBackend,
5196+};
5197+
5198+class PackDataWrapper {
5199+ public:
5200+  static PackDataWrapper &GetInstance() {
5201+    static PackDataWrapper instance;
5202+    return instance;
5203+  }
5204+
5205+  const kernel::LiteKernel *GetPackedKernel(const std::string &node_name) {
5206+    if (this->pack_mapping_.find(node_name) == this->pack_mapping_.end()) {
5207+      return nullptr;
5208+    }
5209+    return this->pack_mapping_[node_name];
5210+  }
5211+
5212+  void AddPackedKernel(const std::string &node_name, const kernel::LiteKernel *data) {
5213+    if (this->pack_mapping_.find(node_name) != this->pack_mapping_.end()) {
5214+      MS_LOG(WARNING) << "Key conflict when adding packed kernel.";
5215+    }
5216+    this->pack_mapping_[node_name] = data;
5217+  }
5218+
5219+ private:
5220+  PackDataWrapper() = default;
5221+  ~PackDataWrapper() = default;
5222+
5223+ private:
5224+  std::map<std::string, const kernel::LiteKernel *> pack_mapping_;
5225+};
5226+
5227+class OfflinePackingOptimizer {
5228+ public:
5229+  OfflinePackingOptimizer() {
5230+    this->packing_strategies_selector_[BackendType::kAndroidArmCpuBackend] =
5231+      std::map<schema::PrimitiveType, OfflinePackingFunc>{
5232+        {schema::PrimitiveType::PrimitiveType_MatMulFusion, MatmulPacking},
5233+      };
5234+    this->ctx_creator_selector_[BackendType::kAndroidArmCpuBackend] = InitInnerContextForAndroidArmCpu;
5235+  }
5236+
5237+  STATUS Optimize(const FuncGraphPtr &func_graph, const std::string &target_backend);
5238+
5239+ private:
5240+  std::map<BackendType, std::map<schema::PrimitiveType, OfflinePackingFunc>> packing_strategies_selector_;
5241+  std::map<BackendType, InnerContextCreatorFunc> ctx_creator_selector_;
5242+};
5243+}  // namespace mindspore::lite
5244+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_OFFLINE_PACKING_OPTIMIZER_H_
5245diff --git a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc
5246index 51d3d992..96eec450 100644
5247--- a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc
5248+++ b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc
5249@@ -27,7 +27,15 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
5250   auto quantizer = WeightQuantizer(param_);
5251   const std::set<PrimitivePtr> support_weight_quant_nodes = {prim::kPrimMatMulFusion, prim::kPrimGather};
5252   const std::set<PrimitivePtr> symmetric_nodes = {prim::kPrimMatMulFusion};
5253-  auto ret = quantizer.WeightQuant(func_graph, support_weight_quant_nodes, {}, symmetric_nodes);
5254+  int ret;
5255+  // When activations are quantized per channel, weights are quantized per layer.
5256+  if (activation_perchannel_) {
5257+    const std::set<PrimitivePtr> support_per_layers_nodes = {prim::kPrimMatMulFusion};
5258+    ret =
5259+      quantizer.WeightQuant(func_graph, support_weight_quant_nodes, support_per_layers_nodes, symmetric_nodes);
5260+  } else {
5261+    ret = quantizer.WeightQuant(func_graph, support_weight_quant_nodes, {}, symmetric_nodes);
5262+  }
5263   if (ret != RET_OK) {
5264     MS_LOG(ERROR) << "Weight Quant failed.";
5265     return ret;
5266@@ -36,7 +44,8 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
5267   const std::set<PrimitivePtr> support_dynamic_quant_ops = {
5268     prim::kPrimMatMulFusion,
5269   };
5270-  ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops, param_->commonQuantParam.skip_quant_node);
5271+  ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops, param_->commonQuantParam.skip_quant_node,
5272+                                       activation_perchannel_);
5273   if (ret != RET_OK) {
5274     MS_LOG(ERROR) << "Insert dynamic quant failed.";
5275     return ret;
5276diff --git a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h
5277index 8a172e7b..00ed204b 100644
5278--- a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h
5279+++ b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h
5280@@ -43,6 +43,7 @@ class DynamicQuantizer : public Quantizer {
5281  public:
5282   explicit DynamicQuantizer(const std::shared_ptr<ConverterPara> &param) : Quantizer(param) {
5283     bit_num_ = param->commonQuantParam.bit_num;
5284+    activation_perchannel_ = (param->commonQuantParam.dynamic_strategy == quant::ACTIVATION_CHANNEL);
5285   }
5286   ~DynamicQuantizer() = default;
5287
5288@@ -53,6 +54,7 @@ class DynamicQuantizer : public Quantizer {
5289   int quant_max_{127};
5290   int quant_min_{-128};
5291   TypeId type_id_{kNumberTypeInt8};
+  bool activation_perchannel_ = false;
 };
 }  // namespace mindspore::lite::quant
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
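
The constructor reduces the configured strategy enum to a single boolean that is then threaded through node insertion. A standalone sketch of that mapping, mirroring the DynamicQuantStrategy enum added to quant_params.h further below (stub types, illustrative only):

    #include <cstdio>

    // Stub mirroring the DynamicQuantStrategy enum added to quant_params.h.
    enum DynamicQuantStrategy { ACTIVATION_LAYER, ACTIVATION_CHANNEL };

    struct CommonQuantParamStub {
      DynamicQuantStrategy dynamic_strategy = ACTIVATION_LAYER;
    };

    // Same reduction the DynamicQuantizer constructor performs.
    bool DeriveActivationPerChannel(const CommonQuantParamStub &p) {
      return p.dynamic_strategy == ACTIVATION_CHANNEL;
    }

    int main() {
      CommonQuantParamStub p;
      p.dynamic_strategy = ACTIVATION_CHANNEL;
      std::printf("activation_perchannel = %d\n", DeriveActivationPerChannel(p));
      return 0;
    }
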
diff --git a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc
index 2f42240f..c528ffbd 100644
--- a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc
+++ b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc
@@ -24,11 +24,15 @@
 #include "tools/optimizer/common/gllo_utils.h"
 #include "tools/optimizer/common/format_utils.h"
 #include "tools/common/node_util.h"
+#include "ops/op_name.h"
+#include "ops/fusion/mat_mul_fusion.h"

 namespace mindspore::lite::quant {
 namespace {
 constexpr size_t kMinSize3 = 3;
 constexpr size_t kPrimitiveCOffset = 1;
+constexpr int kLastFirstIndex = -1;
+constexpr int kLastSecondIndex = -2;
 }  // namespace
 ValueNodePtr InsertQuantNodeManager::NewQuantCastValueNode(int src_type, int dst_type,
                                                            const std::vector<schema::QuantParamT> &quant_params) {
@@ -166,11 +170,17 @@ int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph)
 }

 int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode,
-                                                        size_t index) {
+                                                        size_t index, bool activation_perchannel) {
   auto primitive = std::make_shared<ops::DynamicQuant>();
   auto primitive_c = primitive->GetPrim();
   primitive->set_dst_type(dst_type_);
-  primitive->set_symmetric(symmetric_);
+  bool symmetric = activation_perchannel;  // per-channel activation quantization is symmetric
+  primitive->set_symmetric(symmetric);
+  primitive->set_activation_perchannel(activation_perchannel);
+  if (activation_perchannel && SetPreferAxis(cnode, index, primitive) != RET_OK) {
+    MS_LOG(ERROR) << "Set prefer axis failed, " << cnode->fullname_with_scope();
+    return RET_ERROR;
+  }
   auto dynamic_quant_cnode = graph->NewCNode(primitive_c, {cnode->input(index)});
   auto name = cnode->fullname_with_scope() + "_dynamic_cast_node_" + std::to_string(index);
   dynamic_quant_cnode->set_fullname_with_scope(name);
@@ -181,7 +191,8 @@ int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &grap
     return RET_NULL_PTR;
   }
   dynamic_quant_cnode->set_abstract(abstract);
-  auto ret = UpdateDataType(cnode, dst_type_);
+  abstract->set_shape(cnode->input(index)->Shape());
+  auto ret = UpdateDataType(dynamic_quant_cnode, dst_type_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << cnode->fullname_with_scope() << " set new dtype failed.";
     return ret;
@@ -191,7 +202,41 @@ int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &grap
   return RET_OK;
 }

-int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
+int InsertQuantNodeManager::SetPreferAxis(const CNodePtr &cnode, size_t index,
+                                          const std::shared_ptr<ops::DynamicQuant> &dynamic_primitive) {
+  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
+  CHECK_NULL_RETURN(primitive);
+  if (primitive->name() == ops::kNameMatMulFusion || primitive->name() == ops::kNameMatMul) {
+    auto matmul_prim = api::MakeShared<ops::MatMul>(primitive);
+    CHECK_NULL_RETURN(matmul_prim);
+    // For MatMul A
+    if (index == kInputIndex + kPrimOffset) {
+      if (matmul_prim->GetAttr(ops::kTransposeA) != nullptr && matmul_prim->get_transpose_a()) {
+        dynamic_primitive->set_prefer_axis(kLastFirstIndex);
+        dynamic_primitive->set_transpose(true);
+      } else {
+        dynamic_primitive->set_prefer_axis(kLastSecondIndex);
+        dynamic_primitive->set_transpose(false);
+      }
+    }
+    // For MatMul B
+    if (index == kWeightIndex + kPrimOffset) {
+      if (matmul_prim->GetAttr(ops::kTransposeB) != nullptr && matmul_prim->get_transpose_b()) {
+        dynamic_primitive->set_prefer_axis(kLastSecondIndex);
+        dynamic_primitive->set_transpose(true);
+      } else {
+        dynamic_primitive->set_prefer_axis(kLastFirstIndex);
+        dynamic_primitive->set_transpose(false);
+      }
+    }
+  } else {
+    MS_LOG(WARNING) << "cnode does not need a prefer axis, cnode name: " << cnode->fullname_with_scope();
+  }
+  return RET_OK;
+}
+
+int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
+                                                bool activation_perchannel) {
   auto op_name = cnode->fullname_with_scope();
   if (cnode->size() < kMinSize3) {
     MS_LOG(ERROR) << op_name << " cnode size:" << cnode->size() << " < 3.";
@@ -199,11 +244,11 @@ int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const
   }
   auto input = cnode->input(kInputIndex + kPrimitiveCOffset);
   if (input->isa<mindspore::CNode>() || IsGraphInput(input)) {
-    InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimitiveCOffset);
+    InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimitiveCOffset, activation_perchannel);
   }
   auto weight = cnode->input(kWeightIndex + kPrimitiveCOffset);
   if (weight->isa<mindspore::CNode>() || IsGraphInput(weight)) {
-    InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimitiveCOffset);
+    InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimitiveCOffset, activation_perchannel);
   }
   return RET_OK;
 }
@@ -222,7 +267,8 @@ int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode)

 int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
                                                    const std::set<PrimitivePtr> &support_dynamic_quant_ops,
-                                                   const std::set<std::string> &skip_quant_node) {
+                                                   const std::set<std::string> &skip_quant_node,
+                                                   bool activation_perchannel) {
   MS_ASSERT(graph != nullptr);
   auto cnodes = graph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
@@ -244,7 +290,7 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
       MS_LOG(INFO) << "node:" << op_name << " type:" << type << " will not quantify.";
       continue;
     }
-    ret = NewDynamicQuantNode(graph, cnode);
+    ret = NewDynamicQuantNode(graph, cnode, activation_perchannel);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "node:" << op_name << " new dynamic quant node failed.";
       return ret;
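
SetPreferAxis encodes one rule: quantize along the outer (non-reduced) dimension of each MatMul operand, counted from the end, and flip when the operand is transposed. For A (M x K) the outer dim M sits at axis -2, or -1 under transpose_a; for B (K x N) the outer dim N sits at axis -1, or -2 under transpose_b. A standalone restatement as a pure function (illustrative only):

    #include <cstdio>

    // For A (left operand): rows are the outer dim -> axis -2, or -1 when transposed.
    // For B (right operand): columns are the outer dim -> axis -1, or -2 when transposed.
    int PreferAxis(bool is_matmul_a, bool transposed) {
      if (is_matmul_a) {
        return transposed ? -1 : -2;
      }
      return transposed ? -2 : -1;
    }

    int main() {
      std::printf("A plain: %d, A transposed: %d\n", PreferAxis(true, false), PreferAxis(true, true));
      std::printf("B plain: %d, B transposed: %d\n", PreferAxis(false, false), PreferAxis(false, true));
      return 0;
    }
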
diff --git a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h
index 3555a42c..7c0410dd 100644
--- a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h
+++ b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h
@@ -36,7 +36,7 @@ class InsertQuantNodeManager {
   int InsertQuantDtypeCastNode(const FuncGraphPtr &graph);

   int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops,
-                             const std::set<std::string> &skip_quant_node);
+                             const std::set<std::string> &skip_quant_node, bool activation_perchannel = false);

  private:
   ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params);
@@ -45,15 +45,17 @@

   int CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id) const;

-  int NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode);
+  int NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode, bool activation_perchannel = false);

   int MarkDynamicQuantize(const CNodePtr &cnode);

-  int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index);
+  int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index,
+                                  bool activation_perchannel = false);

+  int SetPreferAxis(const CNodePtr &cnode, size_t index,
+                    const std::shared_ptr<ops::DynamicQuant> &dynamic_primitive);
  private:
   TypeId dst_type_ = kNumberTypeInt8;
-  bool symmetric_ = false;
 };
 }  // namespace mindspore::lite::quant
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H_
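
Note that activation_perchannel defaults to false in all three signatures above, so every pre-existing call site keeps the per-layer behavior unchanged; only the new DoQuantize path passes true. A standalone sketch of this compatibility pattern (stub function, illustrative only):

    #include <cstdio>

    // A new parameter with a default keeps every pre-existing call site source-compatible.
    int InsertDynamicQuantNodeStub(const char *graph_name, bool activation_perchannel = false) {
      std::printf("%s: per-channel=%d\n", graph_name, activation_perchannel);
      return 0;
    }

    int main() {
      InsertDynamicQuantNodeStub("g0");        // old call sites: per-layer path
      InsertDynamicQuantNodeStub("g0", true);  // new path enabled by DoQuantize
      return 0;
    }
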
diff --git a/mindspore/lite/tools/converter/quantizer/quant_params.h b/mindspore/lite/tools/converter/quantizer/quant_params.h
index d7656802..e08b70cb 100644
--- a/mindspore/lite/tools/converter/quantizer/quant_params.h
+++ b/mindspore/lite/tools/converter/quantizer/quant_params.h
@@ -22,6 +22,7 @@
 #include <set>
 #include "schema/inner/model_generated.h"
 namespace mindspore::lite::quant {
+constexpr int kPrimOffset = 1;
 enum ActivationQuantizedMethod {
   MAX_MIN = 0,
   KL = 1,
@@ -40,6 +41,11 @@ enum DebugMode {
   DETAIL,
 };

+enum DynamicQuantStrategy {
+  ACTIVATION_LAYER,
+  ACTIVATION_CHANNEL,
+};
+
 struct CommonQuantParam {
   schema::QuantType quant_type = schema::QuantType_QUANT_NONE;
   int bit_num = 8;
@@ -50,6 +56,7 @@ struct CommonQuantParam {
   DebugMode debug_mode = DETAIL;
   std::set<std::string> skip_quant_node;
   int thread_num = 4;
+  DynamicQuantStrategy dynamic_strategy = quant::ACTIVATION_LAYER;
 };

 struct MixedBitWeightQuantParam {
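
How dynamic_strategy gets populated is handled by the converter config parsing touched in this patch (quant_param_parser.cc in the diffstat), which is not shown in this excerpt. A hedged sketch of the likely shape of such a mapping; the actual option key and accepted spellings may differ:

    #include <cstdio>
    #include <string>

    enum DynamicQuantStrategy { ACTIVATION_LAYER, ACTIVATION_CHANNEL };

    // Hypothetical mapping; the real key/value spelling lives in quant_param_parser.cc.
    DynamicQuantStrategy ParseDynamicStrategy(const std::string &value) {
      return value == "ACTIVATION_CHANNEL" ? ACTIVATION_CHANNEL : ACTIVATION_LAYER;
    }

    int main() {
      std::printf("%d\n", ParseDynamicStrategy("ACTIVATION_CHANNEL"));  // prints 1
      return 0;
    }
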
--
2.17.1
