From 9ec126625b9b723925b1b0c488355e041888f8bc Mon Sep 17 00:00:00 2001 From: albert-yan Date: Thu, 13 Jul 2023 19:55:07 +0800 Subject: [PATCH] new dynamic quant algorigthm and init packed --- .../plugin/device/cpu/kernel/nnacl/BUILD.gn | 1 + .../opt/DynamicMatmulSdot4x4x16AIWI.S | 240 +++++- .../opt/DynamicMatmulSdot4x4x16AIWIForFp16.S | 789 ++++++++++++++++++ .../kernel/nnacl/dynamic_quant_parameter.h | 3 + .../kernel/nnacl/int8/dynamic_matmul_int8.c | 80 +- .../kernel/nnacl/int8/dynamic_matmul_int8.h | 48 +- .../kernel/nnacl/int8/dynamic_quant_int8.c | 24 +- .../kernel/nnacl/int8/dynamic_quant_int8.h | 2 + .../kernel/nnacl/int8/quant_dtype_cast_int8.c | 134 +++ .../kernel/nnacl/int8/quant_dtype_cast_int8.h | 6 + .../cpu/kernel/nnacl/matmul_parameter.h | 17 +- mindspore/core/ops/dynamic_quant.cc | 19 + mindspore/core/ops/dynamic_quant.h | 30 + mindspore/core/ops/op_name.h | 3 + mindspore/lite/BUILD.gn | 2 + mindspore/lite/schema/inner/ops_generated.h | 64 +- mindspore/lite/schema/ops.fbs | 3 + mindspore/lite/schema/ops_generated.h | 34 +- mindspore/lite/src/CMakeLists.txt | 2 + mindspore/lite/src/common/mmap_utils.cc | 63 ++ mindspore/lite/src/common/mmap_utils.h | 27 + mindspore/lite/src/common/ops/ops_def.cc | 3 + .../ops/populate/dynamic_quant_populate.cc | 3 + .../lite/src/common/primitive_t_utils.cc | 14 +- mindspore/lite/src/common/primitive_t_utils.h | 3 +- mindspore/lite/src/runtime/inner_context.h | 9 + .../runtime/kernel/cpu/int8/dynamic_quant.cc | 166 +++- .../runtime/kernel/cpu/int8/dynamic_quant.h | 23 +- .../kernel/cpu/int8/matmul_base_int8.h | 1 + .../cpu/int8/matmul_dynamic_base_int8.cc | 237 ++++-- .../cpu/int8/matmul_dynamic_base_int8.h | 33 +- .../kernel/cpu/int8/matmul_dynamic_int8.cc | 35 +- .../kernel/cpu/int8/matmul_dynamic_int8.h | 4 +- .../cpu/int8/matmul_dynamic_sdot_int8.cc | 132 ++- .../cpu/int8/matmul_dynamic_sdot_int8.h | 23 +- .../runtime/kernel/cpu/int8/matmul_int8.cc | 2 +- .../src/runtime/kernel/cpu/int8/matmul_int8.h | 4 +- mindspore/lite/src/runtime/kernel_registry.h | 4 +- mindspore/lite/src/runtime/lite_kernel.h | 2 + mindspore/lite/src/runtime/lite_model.cc | 7 +- mindspore/lite/src/runtime/lite_model.h | 1 + mindspore/lite/src/runtime/lite_session.cc | 45 +- mindspore/lite/src/runtime/lite_session.h | 3 +- .../src/runtime/runtime_packed_node_pass.cc | 358 ++++++++ .../src/runtime/runtime_packed_node_pass.h | 83 ++ mindspore/lite/tools/common/graph_util.cc | 103 +++ mindspore/lite/tools/common/graph_util.h | 6 + mindspore/lite/tools/converter/CMakeLists.txt | 4 + .../lite/tools/converter/anf_transform.cc | 8 + .../lite/tools/converter/anf_transform.h | 1 + .../config_parser/config_file_parser.cc | 19 + .../config_parser/config_file_parser.h | 9 + .../config_parser/cpu_option_param_parser.cc | 41 + .../config_parser/cpu_option_param_parser.h | 32 + .../config_parser/quant_param_parser.cc | 20 + .../config_parser/quant_param_parser.h | 1 + mindspore/lite/tools/converter/converter.cc | 19 + .../tools/converter/converter_packed_node.cc | 179 ++++ .../tools/converter/converter_packed_node.h | 29 + .../tools/converter/cxx_api/converter_para.h | 6 + .../converter/offline_packing_optimizer.cc | 307 +++++++ .../converter/offline_packing_optimizer.h | 87 ++ .../converter/quantizer/dynamic_quantizer.cc | 13 +- .../converter/quantizer/dynamic_quantizer.h | 2 + .../quantizer/insert_quant_node_manager.cc | 60 +- .../quantizer/insert_quant_node_manager.h | 9 +- .../tools/converter/quantizer/quant_params.h | 7 + 67 files changed, 3497 insertions(+), 251 deletions(-) 
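The act_type and mode arguments that appear throughout this diffstat are the core of the dynamic-quant rework: mode tells the matmul kernels how to index multi_scales when dequantizing the int32 accumulators, so one kernel covers per-tensor and per-channel quantization on either operand. Below is a minimal C sketch of that indexing convention, inferred from the register comments and the reference C kernel later in this patch (ScaleOffset and col_tile are illustrative names, not part of the patch):

/* mode: 0 = TensorByTensor, 1 = TensorByChannel, 2 = ChannelByTensor, 3 = ChannelByChannel */
static inline int ScaleOffset(int mode, int r, int c, int col_tile) {
  if (mode == 0) return 0;              /* one scale for the whole tile                  */
  if (mode == 1) return c;              /* one scale per output column (weight channel)  */
  if (mode == 2) return r;              /* one scale per output row (activation channel) */
  return r * col_tile + c;              /* full row-by-column scale matrix               */
}

In the reference kernel the same selection shows up as mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? r : r * C16NUM + c)), and the ARM64 code branches to TensorXTensor / TensorXChannel / ChannelXTensor / ChannelXChannel on x22.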
create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S create mode 100644 mindspore/lite/src/common/mmap_utils.cc create mode 100644 mindspore/lite/src/common/mmap_utils.h create mode 100644 mindspore/lite/src/runtime/runtime_packed_node_pass.cc create mode 100644 mindspore/lite/src/runtime/runtime_packed_node_pass.h create mode 100644 mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.cc create mode 100644 mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.h create mode 100644 mindspore/lite/tools/converter/converter_packed_node.cc create mode 100644 mindspore/lite/tools/converter/converter_packed_node.h create mode 100644 mindspore/lite/tools/converter/offline_packing_optimizer.cc create mode 100644 mindspore/lite/tools/converter/offline_packing_optimizer.h diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn index 3427a8a4..64188a68 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn @@ -619,6 +619,7 @@ arm64_fp16_assembly_sources = [ optimizing_assembly_sources = [ "assembly/opt/DynamicMatmulSdot4x4x16AIWI.S", + "assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S", "assembly/opt/MatmulDpInt8Opt.S", "assembly/opt/MatmulDpInt8.S", "assembly/opt/MatmulOptR4Int8.S", diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S index efacd61b..bf646f32 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2021-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ // void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales, // float *bias, size_t row, size_t col, size_t stride, const int *a_sums, -// const int *b_sums, int64_t a_zp, int64_t b_zp_sum); +// const int *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode); // x0: a(left matrix ptr) // x1: b(right matrix ptr) // x2: out ptr @@ -34,18 +34,23 @@ // x10: b_sums // x19/w19: a_zp // x19/w20: b_zp_sum +// x21: act_type -> 0: none, 1:Relu, 3:Relu6 +// x22: mode -> 0: TensorByTensor, 1:TensorByChannel, 2:ChannelByTensor, 3:ChannelByChannel asm_function DynamicMatmulSdot4x4x16AIWI - sub sp, sp, #144 + sub sp, sp, #160 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 ldr x8, [sp] ldr x9, [sp, #8] ldr x10, [sp, #16] ldr x19, [sp, #24] ldr x20, [sp, #32] + ldr x21, [sp, #40] + ldr x22, [sp, #48] dup v16.4s, wzr // dup:Duplicate general-purpose register to vector. 
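The prologue above grows the frame from 144 to 160 bytes so x21/x22 can be preserved for the two new arguments. For reference, this is a sketch of the C-side view of the updated entry point, matching the prototype in dynamic_matmul_int8.h later in this patch: under the AArch64 calling convention the first eight arguments arrive in x0-x7, and the rest land on the caller's stack, which is why the code above reloads x8, x9, x10, x19, x20, x21 and x22 from [sp] through [sp, #48].

/* Arguments 1-8 (a ... col) are passed in x0-x7; the remaining seven are read from the stack:
 *   [sp]       stride        [sp, #24]  a_zp          [sp, #40]  act_type
 *   [sp, #8]   a_sums        [sp, #32]  b_zp_sum      [sp, #48]  mode
 *   [sp, #16]  b_sums
 */
void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4,
                                 const float *multi_scales, const float *bias, size_t row,
                                 size_t col, size_t stride, const int32_t *a_sums,
                                 const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum,
                                 int64_t act_type, int64_t mode);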
dup v17.4s, wzr @@ -64,7 +69,7 @@ asm_function DynamicMatmulSdot4x4x16AIWI dup v30.4s, wzr dup v31.4s, wzr - mov x18, x1 // reload rhs ptr + mov x11, x1 // reload rhs ptr mov x17, x0 // reload lhs ptr mov x16, x3 // reload depth @@ -75,7 +80,7 @@ asm_function DynamicMatmulSdot4x4x16AIWI LoopDepth: ld1 {v0.16b}, [x17], #16 - ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x18], #64 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x11], #64 sdot v16.4s, v1.16b, v0.4b[0] sdot v17.4s, v2.16b, v0.4b[0] @@ -100,8 +105,8 @@ LoopDepth: LoopDepthHalf: ld1 {v0.16b}, [x17], #16 - ld1 {v1.16b, v2.16b}, [x18] - add x18, x18, #64 + ld1 {v1.16b, v2.16b}, [x11] + add x11, x11, #64 sdot v16.4s, v1.16b, v0.4b[0] sdot v17.4s, v2.16b, v0.4b[0] sdot v20.4s, v1.16b, v0.4b[1] @@ -117,8 +122,8 @@ LoopDepthHalf: LoopDepthQuarter: ld1 {v0.16b}, [x17], #16 - ld1 {v1.16b}, [x18] - add x18, x18, #64 + ld1 {v1.16b}, [x11] + add x11, x11, #64 sdot v16.4s, v1.16b, v0.4b[0] sdot v20.4s, v1.16b, v0.4b[1] sdot v24.4s, v1.16b, v0.4b[2] @@ -225,28 +230,108 @@ Convert2Float: MultiplyScale: // multi_scale * input_matrix - ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x4] - - fmul v16.4s,v16.4s,v1.4s - fmul v17.4s,v17.4s,v2.4s - fmul v18.4s,v18.4s,v3.4s - fmul v19.4s,v19.4s,v4.4s - - fmul v20.4s,v20.4s,v1.4s - fmul v21.4s,v21.4s,v2.4s - fmul v22.4s,v22.4s,v3.4s - fmul v23.4s,v23.4s,v4.4s - - fmul v24.4s,v24.4s,v1.4s - fmul v25.4s,v25.4s,v2.4s - fmul v26.4s,v26.4s,v3.4s - fmul v27.4s,v27.4s,v4.4s - - fmul v28.4s,v28.4s,v1.4s - fmul v29.4s,v29.4s,v2.4s - fmul v30.4s,v30.4s,v3.4s - fmul v31.4s,v31.4s,v4.4s - + cbz x22, TensorXTensor + cmp x22, #1 + beq TensorXChannel + cmp x22, #2 + beq ChannelXTensor + ChannelXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64 + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x4], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x4] + + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fmul v22.4s,v22.4s,v6.4s + fmul v23.4s,v23.4s,v7.4s + + fmul v24.4s,v24.4s,v8.4s + fmul v25.4s,v25.4s,v9.4s + fmul v26.4s,v26.4s,v10.4s + fmul v27.4s,v27.4s,v11.4s + + fmul v28.4s,v28.4s,v12.4s + fmul v29.4s,v29.4s,v13.4s + fmul v30.4s,v30.4s,v14.4s + fmul v31.4s,v31.4s,v15.4s + b AddBias + + TensorXTensor: + ld1 {v0.s}[0], [x4] + + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[0] + fmul v21.4s,v21.4s,v0.s[0] + fmul v22.4s,v22.4s,v0.s[0] + fmul v23.4s,v23.4s,v0.s[0] + + fmul v24.4s,v24.4s,v0.s[0] + fmul v25.4s,v25.4s,v0.s[0] + fmul v26.4s,v26.4s,v0.s[0] + fmul v27.4s,v27.4s,v0.s[0] + + fmul v28.4s,v28.4s,v0.s[0] + fmul v29.4s,v29.4s,v0.s[0] + fmul v30.4s,v30.4s,v0.s[0] + fmul v31.4s,v31.4s,v0.s[0] + b AddBias + + TensorXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4] + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + fmul v20.4s,v20.4s,v0.4s + fmul v21.4s,v21.4s,v1.4s + fmul v22.4s,v22.4s,v2.4s + fmul v23.4s,v23.4s,v3.4s + + fmul v24.4s,v24.4s,v0.4s + fmul v25.4s,v25.4s,v1.4s + fmul v26.4s,v26.4s,v2.4s + fmul v27.4s,v27.4s,v3.4s + + fmul v28.4s,v28.4s,v0.4s + fmul v29.4s,v29.4s,v1.4s + fmul v30.4s,v30.4s,v2.4s + fmul v31.4s,v31.4s,v3.4s + b AddBias + + ChannelXTensor: + ld1 {v0.4s}, [x4] + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[1] + 
fmul v21.4s,v21.4s,v0.s[1] + fmul v22.4s,v22.4s,v0.s[1] + fmul v23.4s,v23.4s,v0.s[1] + + fmul v24.4s,v24.4s,v0.s[2] + fmul v25.4s,v25.4s,v0.s[2] + fmul v26.4s,v26.4s,v0.s[2] + fmul v27.4s,v27.4s,v0.s[2] + + fmul v28.4s,v28.4s,v0.s[3] + fmul v29.4s,v29.4s,v0.s[3] + fmul v30.4s,v30.4s,v0.s[3] + fmul v31.4s,v31.4s,v0.s[3] AddBias: // +bias cbz x5, StoreData @@ -272,6 +357,88 @@ AddBias: fadd v30.4s,v30.4s,v3.4s fadd v31.4s,v31.4s,v4.4s +Activate: + cmp x21, #1 + beq Relu + cmp x21, #3 + beq Relu6 + b StoreData + +Relu: + dup v1.4s, wzr + + smax v16.4s,v16.4s,v1.4s + smax v17.4s,v17.4s,v1.4s + smax v18.4s,v18.4s,v1.4s + smax v19.4s,v19.4s,v1.4s + + smax v20.4s,v20.4s,v1.4s + smax v21.4s,v21.4s,v1.4s + smax v22.4s,v22.4s,v1.4s + smax v23.4s,v23.4s,v1.4s + + smax v24.4s,v24.4s,v1.4s + smax v25.4s,v25.4s,v1.4s + smax v26.4s,v26.4s,v1.4s + smax v27.4s,v27.4s,v1.4s + + smax v28.4s,v28.4s,v1.4s + smax v29.4s,v29.4s,v1.4s + smax v30.4s,v30.4s,v1.4s + smax v31.4s,v31.4s,v1.4s + + b StoreData + +Relu6: + dup v1.4s, wzr + movi v2.4s, #6 + scvtf v2.4s, v2.4s + + // max (out, 0) + smax v16.4s,v16.4s,v1.4s + smax v17.4s,v17.4s,v1.4s + smax v18.4s,v18.4s,v1.4s + smax v19.4s,v19.4s,v1.4s + + smax v20.4s,v20.4s,v1.4s + smax v21.4s,v21.4s,v1.4s + smax v22.4s,v22.4s,v1.4s + smax v23.4s,v23.4s,v1.4s + + smax v24.4s,v24.4s,v1.4s + smax v25.4s,v25.4s,v1.4s + smax v26.4s,v26.4s,v1.4s + smax v27.4s,v27.4s,v1.4s + + smax v28.4s,v28.4s,v1.4s + smax v29.4s,v29.4s,v1.4s + smax v30.4s,v30.4s,v1.4s + smax v31.4s,v31.4s,v1.4s + + // min (out, 6) + + smin v16.4s,v16.4s,v2.4s + smin v17.4s,v17.4s,v2.4s + smin v18.4s,v18.4s,v2.4s + smin v19.4s,v19.4s,v2.4s + + smin v20.4s,v20.4s,v2.4s + smin v21.4s,v21.4s,v2.4s + smin v22.4s,v22.4s,v2.4s + smin v23.4s,v23.4s,v2.4s + + smin v24.4s,v24.4s,v2.4s + smin v25.4s,v25.4s,v2.4s + smin v26.4s,v26.4s,v2.4s + smin v27.4s,v27.4s,v2.4s + + smin v28.4s,v28.4s,v2.4s + smin v29.4s,v29.4s,v2.4s + smin v30.4s,v30.4s,v2.4s + smin v31.4s,v31.4s,v2.4s + + b StoreData + StoreData: cmp x7, #16 beq Write16 @@ -547,19 +714,19 @@ Write4: b StoreDataEnd Write3: - st1 {v16.1d}, [x15] + st1 {v16.1d}, [x15], #8 st1 {v16.s}[2], [x15] cmp x6, #1 beq StoreDataEnd - st1 {v20.1d}, [x14] + st1 {v20.1d}, [x14], #8 st1 {v20.s}[2], [x14] cmp x6, #2 beq StoreDataEnd - st1 {v24.1d}, [x13] + st1 {v24.1d}, [x13], #8 st1 {v24.s}[2], [x13] cmp x6, #3 beq StoreDataEnd - st1 {v28.1d}, [x12] + st1 {v28.1d}, [x12], #8 st1 {v28.s}[2], [x12] b StoreDataEnd @@ -589,9 +756,10 @@ Write1: st1 {v28.s}[0], [x12] b StoreDataEnd StoreDataEnd: - sub sp, sp, #144 + sub sp, sp, #160 ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 ret #endif diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S new file mode 100644 index 00000000..e22a572a --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S @@ -0,0 +1,789 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
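The Activate / Relu / Relu6 blocks above are the vectorized form of the new act_type argument; the scalar reference kernel in dynamic_matmul_int8.c later in this patch applies the same clamp per element. A minimal C sketch, assuming the encoding from the register comment (0 = none, 1 = Relu, 3 = Relu6; ApplyActType is an illustrative name):

static inline float ApplyActType(float value, int64_t act_type) {
  if (act_type == 1) {                       /* Relu  */
    return value > 0.0f ? value : 0.0f;
  }
  if (act_type == 3) {                       /* Relu6 */
    value = value > 0.0f ? value : 0.0f;
    return value < 6.0f ? value : 6.0f;
  }
  return value;                              /* no activation */
}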
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" +.text +.align 5 + +// void DynamicMatmulSdot4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, +// float16_t *multi_scales, float16_t *bias, size_t row, size_t col, size_t stride, +// const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, +// int64_t act_type, int64_t mode); +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: deep +// x4: multi_scales +// x5: bias +// x6: row +// x7: col +// x8: stride +// x9: a_sums +// x10: b_sums +// x19/w19: a_zp +// x19/w20: b_zp_sum +// x21: act_type -> 0: none, 1:Relu, 3:Relu6 +// x22: mode -> 0: TensorByTensor, 1:TensorByChannel, 2:ChannelByTensor, 3:ChannelByChannel + +asm_function DynamicMatmulSdot4x4x16AIWIForFp16 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x19, [sp, #24] + ldr x20, [sp, #32] + ldr x21, [sp, #40] + ldr x22, [sp, #48] + + dup v16.4s, wzr // dup:Duplicate general-purpose register to vector. + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + mov x11, x1 // reload rhs ptr + mov x17, x0 // reload lhs ptr + mov x16, x3 // reload depth + + cmp x7, #4 + ble LoopDepthQuarter + cmp x7, #8 + ble LoopDepthHalf + +LoopDepth: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x11], #64 + + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v18.4s, v3.16b, v0.4b[0] + sdot v19.4s, v4.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v22.4s, v3.16b, v0.4b[1] + sdot v23.4s, v4.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v26.4s, v3.16b, v0.4b[2] + sdot v27.4s, v4.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + sdot v30.4s, v3.16b, v0.4b[3] + sdot v31.4s, v4.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepth + b AddInputSum + +LoopDepthHalf: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthHalf + b AddInputSum + +LoopDepthQuarter: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthQuarter + b AddInputSum + +AddInputSum: + cmp w20, #0 + beq AddInputSumEnd + ld1 {v5.4s}, [x9], #16 + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + dup v8.4s, v5.s[2] + dup 
v9.4s, v5.s[3] + + sub v16.4s, v16.4s, v6.4s + sub v17.4s, v17.4s, v6.4s + sub v18.4s, v18.4s, v6.4s + sub v19.4s, v19.4s, v6.4s + sub v20.4s, v20.4s, v7.4s + sub v21.4s, v21.4s, v7.4s + sub v22.4s, v22.4s, v7.4s + sub v23.4s, v23.4s, v7.4s + sub v24.4s, v24.4s, v8.4s + sub v25.4s, v25.4s, v8.4s + sub v26.4s, v26.4s, v8.4s + sub v27.4s, v27.4s, v8.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v9.4s + sub v30.4s, v30.4s, v9.4s + sub v31.4s, v31.4s, v9.4s +AddInputSumEnd: + +AddWeightSum: + ld1 {v9.4s}, [x10], #16 + ld1 {v10.4s}, [x10], #16 + ld1 {v11.4s}, [x10], #16 + ld1 {v12.4s}, [x10], #16 + dup v13.4s, w19 + mul v9.4s, v9.4s, v13.4s + mul v10.4s, v10.4s, v13.4s + mul v11.4s, v11.4s, v13.4s + mul v12.4s, v12.4s, v13.4s + sub v16.4s, v16.4s, v9.4s + sub v17.4s, v17.4s, v10.4s + sub v18.4s, v18.4s, v11.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v9.4s + sub v21.4s, v21.4s, v10.4s + sub v22.4s, v22.4s, v11.4s + sub v23.4s, v23.4s, v12.4s + sub v24.4s, v24.4s, v9.4s + sub v25.4s, v25.4s, v10.4s + sub v26.4s, v26.4s, v11.4s + sub v27.4s, v27.4s, v12.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v10.4s + sub v30.4s, v30.4s, v11.4s + sub v31.4s, v31.4s, v12.4s + +AddZpSum: + mul w15, w19, w20 + cmp w15, #0 + beq AddZpSumEnd + dup v14.4s, w15 + add v16.4s, v16.4s, v14.4s + add v17.4s, v17.4s, v14.4s + add v18.4s, v18.4s, v14.4s + add v19.4s, v19.4s, v14.4s + add v20.4s, v20.4s, v14.4s + add v21.4s, v21.4s, v14.4s + add v22.4s, v22.4s, v14.4s + add v23.4s, v23.4s, v14.4s + add v24.4s, v24.4s, v14.4s + add v25.4s, v25.4s, v14.4s + add v26.4s, v26.4s, v14.4s + add v27.4s, v27.4s, v14.4s + add v28.4s, v28.4s, v14.4s + add v29.4s, v29.4s, v14.4s + add v30.4s, v30.4s, v14.4s + add v31.4s, v31.4s, v14.4s +AddZpSumEnd: + +Convert2Float: + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + scvtf v27.4s, v27.4s + scvtf v28.4s, v28.4s + scvtf v29.4s, v29.4s + scvtf v30.4s, v30.4s + scvtf v31.4s, v31.4s + +MultiplyScale: + // multi_scale * input_matrix + cbz x22, TensorXTensor + cmp x22, #1 + beq TensorXChannel + cmp x22, #2 + beq ChannelXTensor + ChannelXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64 + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x4], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x4] + + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fmul v22.4s,v22.4s,v6.4s + fmul v23.4s,v23.4s,v7.4s + + fmul v24.4s,v24.4s,v8.4s + fmul v25.4s,v25.4s,v9.4s + fmul v26.4s,v26.4s,v10.4s + fmul v27.4s,v27.4s,v11.4s + + fmul v28.4s,v28.4s,v12.4s + fmul v29.4s,v29.4s,v13.4s + fmul v30.4s,v30.4s,v14.4s + fmul v31.4s,v31.4s,v15.4s + b ConvertHalfPrecision + + TensorXTensor: + ld1 {v0.s}[0], [x4] + + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[0] + fmul v21.4s,v21.4s,v0.s[0] + fmul v22.4s,v22.4s,v0.s[0] + fmul v23.4s,v23.4s,v0.s[0] + + fmul v24.4s,v24.4s,v0.s[0] + fmul v25.4s,v25.4s,v0.s[0] + fmul v26.4s,v26.4s,v0.s[0] + fmul v27.4s,v27.4s,v0.s[0] + + fmul v28.4s,v28.4s,v0.s[0] + fmul v29.4s,v29.4s,v0.s[0] + fmul v30.4s,v30.4s,v0.s[0] + fmul v31.4s,v31.4s,v0.s[0] + b ConvertHalfPrecision + + TensorXChannel: + ld1 {v0.4s, v1.4s, v2.4s, 
v3.4s}, [x4] + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + fmul v20.4s,v20.4s,v0.4s + fmul v21.4s,v21.4s,v1.4s + fmul v22.4s,v22.4s,v2.4s + fmul v23.4s,v23.4s,v3.4s + + fmul v24.4s,v24.4s,v0.4s + fmul v25.4s,v25.4s,v1.4s + fmul v26.4s,v26.4s,v2.4s + fmul v27.4s,v27.4s,v3.4s + + fmul v28.4s,v28.4s,v0.4s + fmul v29.4s,v29.4s,v1.4s + fmul v30.4s,v30.4s,v2.4s + fmul v31.4s,v31.4s,v3.4s + b ConvertHalfPrecision + + ChannelXTensor: + ld1 {v0.4s}, [x4] + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[1] + fmul v21.4s,v21.4s,v0.s[1] + fmul v22.4s,v22.4s,v0.s[1] + fmul v23.4s,v23.4s,v0.s[1] + + fmul v24.4s,v24.4s,v0.s[2] + fmul v25.4s,v25.4s,v0.s[2] + fmul v26.4s,v26.4s,v0.s[2] + fmul v27.4s,v27.4s,v0.s[2] + + fmul v28.4s,v28.4s,v0.s[3] + fmul v29.4s,v29.4s,v0.s[3] + fmul v30.4s,v30.4s,v0.s[3] + fmul v31.4s,v31.4s,v0.s[3] + +ConvertHalfPrecision: +// from single-precision convert to half-precision + fcvtn v16.4h,v16.4s + fcvtn v17.4h,v17.4s + fcvtn v18.4h,v18.4s + fcvtn v19.4h,v19.4s + + fcvtn v20.4h,v20.4s + fcvtn v21.4h,v21.4s + fcvtn v22.4h,v22.4s + fcvtn v23.4h,v23.4s + + fcvtn v24.4h,v24.4s + fcvtn v25.4h,v25.4s + fcvtn v26.4h,v26.4s + fcvtn v27.4h,v27.4s + + fcvtn v28.4h,v28.4s + fcvtn v29.4h,v29.4s + fcvtn v30.4h,v30.4s + fcvtn v31.4h,v31.4s + +AddBias: + // +bias + cbz x5, StoreData + ld1 {v1.4h, v2.4h, v3.4h, v4.4h}, [x5] + + fadd v16.4h,v16.4h,v1.4h + fadd v17.4h,v17.4h,v2.4h + fadd v18.4h,v18.4h,v3.4h + fadd v19.4h,v19.4h,v4.4h + + fadd v20.4h,v20.4h,v1.4h + fadd v21.4h,v21.4h,v2.4h + fadd v22.4h,v22.4h,v3.4h + fadd v23.4h,v23.4h,v4.4h + + fadd v24.4h,v24.4h,v1.4h + fadd v25.4h,v25.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v27.4h,v27.4h,v4.4h + + fadd v28.4h,v28.4h,v1.4h + fadd v29.4h,v29.4h,v2.4h + fadd v30.4h,v30.4h,v3.4h + fadd v31.4h,v31.4h,v4.4h + +Activate: + cmp x21, #1 + beq Relu + cmp x21, #3 + beq Relu6 + b StoreData + +Relu: + dup v1.4h, wzr + + smax v16.4h,v16.4h,v1.4h + smax v17.4h,v17.4h,v1.4h + smax v18.4h,v18.4h,v1.4h + smax v19.4h,v19.4h,v1.4h + + smax v20.4h,v20.4h,v1.4h + smax v21.4h,v21.4h,v1.4h + smax v22.4h,v22.4h,v1.4h + smax v23.4h,v23.4h,v1.4h + + smax v24.4h,v24.4h,v1.4h + smax v25.4h,v25.4h,v1.4h + smax v26.4h,v26.4h,v1.4h + smax v27.4h,v27.4h,v1.4h + + smax v28.4h,v28.4h,v1.4h + smax v29.4h,v29.4h,v1.4h + smax v30.4h,v30.4h,v1.4h + smax v31.4h,v31.4h,v1.4h + + b StoreData + +Relu6: + dup v1.4h, wzr + movi v2.4h, #6 + scvtf v2.4h, v2.4h + + // max (out, 0) + smax v16.4h,v16.4h,v1.4h + smax v17.4h,v17.4h,v1.4h + smax v18.4h,v18.4h,v1.4h + smax v19.4h,v19.4h,v1.4h + + smax v20.4h,v20.4h,v1.4h + smax v21.4h,v21.4h,v1.4h + smax v22.4h,v22.4h,v1.4h + smax v23.4h,v23.4h,v1.4h + + smax v24.4h,v24.4h,v1.4h + smax v25.4h,v25.4h,v1.4h + smax v26.4h,v26.4h,v1.4h + smax v27.4h,v27.4h,v1.4h + + smax v28.4h,v28.4h,v1.4h + smax v29.4h,v29.4h,v1.4h + smax v30.4h,v30.4h,v1.4h + smax v31.4h,v31.4h,v1.4h + + // min (out, 6) + + smin v16.4h,v16.4h,v2.4h + smin v17.4h,v17.4h,v2.4h + smin v18.4h,v18.4h,v2.4h + smin v19.4h,v19.4h,v2.4h + + smin v20.4h,v20.4h,v2.4h + smin v21.4h,v21.4h,v2.4h + smin v22.4h,v22.4h,v2.4h + smin v23.4h,v23.4h,v2.4h + + smin v24.4h,v24.4h,v2.4h + smin v25.4h,v25.4h,v2.4h + smin v26.4h,v26.4h,v2.4h + smin v27.4h,v27.4h,v2.4h + + smin v28.4h,v28.4h,v2.4h + smin v29.4h,v29.4h,v2.4h + smin v30.4h,v30.4h,v2.4h + smin v31.4h,v31.4h,v2.4h + + b StoreData + +StoreData: + cmp x7, #16 + beq Write16 + + 
mov x15, x2 // reload out ptr + add x14, x15, x8 + add x13, x14, x8 + add x12, x13, x8 + + cmp x7, #15 + beq Write15 + cmp x7, #14 + beq Write14 + cmp x7, #13 + beq Write13 + cmp x7, #12 + beq Write12 + cmp x7, #11 + beq Write11 + cmp x7, #10 + beq Write10 + cmp x7, #9 + beq Write9 + cmp x7, #8 + beq Write8 + cmp x7, #7 + beq Write7 + cmp x7, #6 + beq Write6 + cmp x7, #5 + beq Write5 + cmp x7, #4 + beq Write4 + cmp x7, #3 + beq Write3 + cmp x7, #2 + beq Write2 + cmp x7, #1 + beq Write1 + b StoreDataEnd + +Write16: + cmp x6, #4 + beq Write16Row4 + cmp x6, #3 + beq Write16Row3 + cmp x6, #2 + beq Write16Row2 + cmp x6, #1 + beq Write16Row1 + + Write16Row4: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2], x8 + st1 {v24.4h,v25.4h,v26.4h,v27.4h}, [x2], x8 + st1 {v28.4h,v29.4h,v30.4h,v31.4h}, [x2] + b StoreDataEnd + Write16Row3: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2], x8 + st1 {v24.4h,v25.4h,v26.4h,v27.4h}, [x2] + b StoreDataEnd + Write16Row2: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2] + b StoreDataEnd + Write16Row1: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2] + b StoreDataEnd + +Write15: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.s}[0], [x15], #4 + st1 {v19.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.s}[0], [x14], #4 + st1 {v23.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.s}[0], [x13], #4 + st1 {v27.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.s}[0], [x12], #4 + st1 {v31.h}[2], [x12] + b StoreDataEnd + +Write14: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.s}[0], [x12] + b StoreDataEnd + +Write13: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.h}[0], [x12] + b StoreDataEnd + +Write12: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + b StoreDataEnd + +Write11: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.s}[0], [x15], #4 + st1 {v18.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.s}[0], [x14], #4 + st1 {v22.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.s}[0], [x13], #4 + st1 {v26.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.s}[0], [x12], #4 + st1 {v30.h}[2], [x12] + b StoreDataEnd + +Write10: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], 
#16 + st1 {v30.s}[0], [x12] + b StoreDataEnd + +Write9: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.h}[0], [x12] + b StoreDataEnd + +Write8: + st1 {v16.4h,v17.4h}, [x15], #16 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + b StoreDataEnd + +Write7: + st1 {v16.4h}, [x15], #8 + st1 {v17.s}[0], [x15], #4 + st1 {v17.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.s}[0], [x14], #4 + st1 {v21.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.s}[0], [x13], #4 + st1 {v25.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.s}[0], [x12], #4 + st1 {v29.h}[2], [x12] + b StoreDataEnd + +Write6: + st1 {v16.4h}, [x15], #8 + st1 {v17.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.s}[0], [x12] + b StoreDataEnd + +Write5: + st1 {v16.4h}, [x15], #8 + st1 {v17.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.h}[0], [x12] + b StoreDataEnd + +Write4: + st1 {v16.4h}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12] + b StoreDataEnd + +Write3: + st1 {v16.s}[0], [x15], #4 + st1 {v16.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14], #4 + st1 {v20.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13], #4 + st1 {v24.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12], #4 + st1 {v28.h}[2], [x12] + b StoreDataEnd + +Write2: + st1 {v16.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12] + b StoreDataEnd + +Write1: + st1 {v16.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.h}[0], [x12] + b StoreDataEnd +StoreDataEnd: + sub sp, sp, #160 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h index 627b9ee6..dfc05f28 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h @@ -22,6 +22,9 @@ typedef struct DynamicQuantParameter { OpParameter op_parameter_; bool symmetric_; int64_t dst_type_; + bool activation_perchannel_; + int64_t prefer_axis_; + bool transpose_; } DynamicQuantParameter; #endif // MINDSPORE_NNACL_DYNAMIC_QUANT_PARAMETER_H_ diff 
--git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.c index 0bfa6475..a09a4359 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.c @@ -1,5 +1,5 @@ /** - * Copyright 2022 Huawei Technologies Co., Ltd + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,15 +17,15 @@ #include "nnacl/int8/dynamic_matmul_int8.h" #include "nnacl/int8/fixed_point.h" -void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales, - float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, - const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum) { +void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales, + const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, + const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode) { /* * * row4x4-major * row4x16-major => (int8)row-major * support activation per-layer symmetric && weight per-layer/per-channel symmetric * */ for (int r = 0; r < row; r++) { - int64_t s2 = a_sums[r] * b_zp_sum; + int64_t s2 = a_sums[r]; for (int c = 0; c < col; c++) { int r4div = r / C4NUM, r4mod = r % C4NUM; int c16div = c / C16NUM, c16mod = c % C16NUM; @@ -39,18 +39,67 @@ void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_ int64_t s3 = b_sums[c] * a_zp; int64_t s4 = a_zp * b_zp_sum; size_t ci = r * stride / sizeof(float) + c; - out[ci] = multi_scales[c] * (s1 - s2 - s3 + s4); + int scale_offset = mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? r : r * C16NUM + c)); + out[ci] = multi_scales[scale_offset] * (s1 - s2 - s3 + s4); if (bias != NULL) { out[ci] += bias[c]; } + if (act_type == ActType_Relu) { + out[ci] = MSMAX(0, out[ci]); + } else if (act_type == ActType_Relu6) { + out[ci] = MSMAX(0, out[ci]); + out[ci] = MSMIN(C6NUM, out[ci]); + } } } return; } +#ifdef ENABLE_FP16 +void DynamicMatmul4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, + const float *multi_scales, const float16_t *bias, size_t row, size_t col, + size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, + int64_t b_zp_sum, int64_t act_type, int64_t mode) { + /* * + * row4x4-major * row4x16-major => (int8)row-major + * support activation per-layer symmetric && weight per-layer/per-channel symmetric + * */ + for (int r = 0; r < row; r++) { + int64_t s2 = a_sums[r]; + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + int32_t s1 = 0; + for (int d = 0; d < deep4; d++) { + int d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep4 * C16NUM + d4div * C4NUM * C16NUM + c16mod * C4NUM + d4mod; + s1 += a[ai] * b[bi]; + } + int64_t s3 = b_sums[c] * a_zp; + int64_t s4 = a_zp * b_zp_sum; + size_t ci = r * stride / sizeof(float16_t) + c; + int scale_offset = mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? 
r : r * C16NUM + c)); + out[ci] = multi_scales[scale_offset] * (s1 - s2 - s3 + s4); + if (bias != NULL) { + out[ci] += bias[c]; + } + if (act_type == ActType_Relu) { + out[ci] = MSMAX(0, out[ci]); + } else if (act_type == ActType_Relu6) { + out[ci] = MSMAX(0, out[ci]); + out[ci] = MSMIN(C6NUM, out[ci]); + } + } + } + return; +} +#endif + void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, - int deep, int deep16, size_t stride, int input_zp, float input_scale, - const float *filter_scale, const int filter_zp, bool filter_per_channel) { + int deep, int deep16, size_t stride, int input_zp, const float *input_scale, + const float *filter_scale, int filter_zp, bool input_per_channel, bool filter_per_channel, + int64_t act_type) { /* * * row4x16-major * row16x4-major => (int8)row-major * support activation per-layer symmetric && weight per-layer/per-channel symmetric @@ -74,13 +123,20 @@ void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias s3 += input_zp * filter_zp; } value = s0 - s1 - s2 + s3; + int input_quant_index = input_per_channel ? r : 0; int filter_quant_index = filter_per_channel ? c : 0; - float multi_scale = input_scale * filter_scale[filter_quant_index]; + float multi_scale = input_scale[input_quant_index] * filter_scale[filter_quant_index]; size_t ci = r * stride + c; dst[ci] = multi_scale * value; if (bias != NULL) { dst[ci] += bias[c]; } + if (act_type == ActType_Relu) { + dst[ci] = MSMAX(0, dst[ci]); + } else if (act_type == ActType_Relu6) { + dst[ci] = MSMAX(0, dst[ci]); + dst[ci] = MSMIN(C6NUM, dst[ci]); + } } } return; @@ -166,8 +222,8 @@ void PackInput4x4Asm(const int8_t *src_ic, int8_t *pack_ic, size_t ic_4div, size "6: \n" : - : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), - [ ic_4res ] "r"(ic_4res) + : [src_ic] "r"(src_ic), [pack_ic] "r"(pack_ic), [src_stride] "r"(src_stride), [ic_4div] "r"(ic_4div), + [ic_4res] "r"(ic_4res) : "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3"); } #endif @@ -276,7 +332,7 @@ void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, in "1:\n" : - : [ src_ic ] "r"(src_ic), [ packed_ic ] "r"(packed_ic), [ row ] "r"(row_div), [ row_stride ] "r"(row_stride_int64) + : [src_ic] "r"(src_ic), [packed_ic] "r"(packed_ic), [row] "r"(row_div), [row_stride] "r"(row_stride_int64) : "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12"); packed_ic += C4NUM * row_div; src_ic += row_div * row_stride; diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.h index ef835898..77e861bb 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_matmul_int8.h @@ -1,5 +1,5 @@ /** - * Copyright 2022 Huawei Technologies Co., Ltd + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
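The reference kernels above (and the AddInputSum / AddWeightSum / AddZpSum blocks in the assembly) evaluate each output element as multi_scales[...] * (s1 - s2 - s3 + s4). The correction terms come from expanding the zero-point-shifted product; a compact derivation, with the mapping this patch appears to use:

/*
 *   sum_d (a_d - a_zp) * (b_d - b_zp)
 *     = sum_d a_d*b_d  -  b_zp * sum_d a_d  -  a_zp * sum_d b_d  +  deep * a_zp * b_zp
 *     =       s1       -         s2         -         s3         +         s4
 *
 * s1 is the raw int8 accumulation, s3 = b_sums[c] * a_zp and s4 = a_zp * b_zp_sum are
 * precomputed correction terms, and s2 is now read directly from a_sums[r]; before this
 * change the kernel multiplied a_sums[r] by b_zp_sum itself, so the b_zp_sum factor
 * presumably moves into the row-sum precomputation elsewhere in this patch.
 */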
@@ -27,18 +27,46 @@ extern "C" { void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride); void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size); void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, - int deep, int deep16, size_t stride, int input_zp, float input_scale, - const float *filter_scale, const int filter_zp, bool filter_per_channel); + int deep, int deep16, size_t stride, int input_zp, const float *input_scale, + const float *filter_scale, int filter_zp, bool input_per_channel, bool filter_per_channel, + int64_t act_type); void CalcWeightSums(const int8_t *weight, int row, int col, int32_t *dst, DataOrder order); void CalcPartWeightSums(const int8_t *weight, int row, int stride, int cur_col, int32_t *dst, DataOrder order); -#ifdef ENABLE_ARM64 -void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales, - float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, - const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum); +#if defined(ENABLE_ARM64) && !defined(USE_AOS_GCC_TOOLCHAIN) +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. + */ +void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales, + const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, + const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode); +#endif +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. + */ +void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales, + const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, + const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode); +#ifdef ENABLE_FP16 +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. + */ +void DynamicMatmul4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, + const float *multi_scales, const float16_t *bias, size_t row, size_t col, + size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, + int64_t b_zp_sum, int64_t act_type, int64_t mode); +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. 
+ */ +void DynamicMatmulSdot4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, + const float *multi_scales, const float16_t *bias, size_t row, size_t col, + size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, + int64_t b_zp_sum, int64_t act_type, int64_t mode); #endif -void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales, - float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, - const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum); + #ifdef __cplusplus } #endif diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.c index bca1cbca..4ec4ebb8 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.c @@ -16,6 +16,9 @@ #include "nnacl/int8/dynamic_quant_int8.h" void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max) { +if (count == 0) { + return; + } #ifndef ENABLE_ARM64 for (int i = 0; i < count; ++i) { if (data[i] < *real_min) { @@ -26,7 +29,7 @@ void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *r } } #else - // avoid to compile optimize. + // avoid to compile optimize. volatile int count_4 = DOWN_ROUND(count, C4NUM); asm volatile( "mov x4, %[data]\n" // reload data @@ -63,3 +66,22 @@ void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *r } #endif } + +void CalculateAllChannelMinMax(const float *data, int count, float *real_min, float *real_max, int channel_length) { + int channel_total = count / channel_length; + for (int i = 0; i < channel_total; i++) { + CalculateMinMaxFp32(data + i * channel_length, channel_length, real_min + i, real_max + i); + } +} + +int GetBucketIndex(int dims[], int dim_size, int prefer_axis, int data_index) { + int stride = 1; + int bucket_count = dims[prefer_axis]; + for (int i = prefer_axis + 1; i < dim_size; i++) { + stride *= dims[i]; + } + if (stride == 0 || bucket_count == 0) { + return 0; + } + return (data_index / stride) % bucket_count; +} diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.h index d4a63518..8fa0a9ed 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/dynamic_quant_int8.h @@ -25,6 +25,8 @@ extern "C" { #endif void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max); +void CalculateAllChannelMinMax(const float *data, int count, float *real_min, float *real_max, int channel_length); +int GetBucketIndex(int dims[], int dim_size, int prefer_axis, int data_index); #ifdef __cplusplus } #endif diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.c index 25050fda..753aa5dd 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.c @@ -202,6 +202,140 @@ int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float s return NNACL_OK; } +#ifdef ENABLE_ARM64 +inline void Fp32ToInt8Perchannel_arm64(const float *real_values, int8_t *quant_values, float *scales, 
int32_t *zps, + int size, int channel_length, int32_t min_value, int32_t max_value) { + volatile float ivs[size]; + for (int i = 0; i < size; i++) { + volatile int channel_index = i / channel_length; + ivs[i] = 1.0f / scales[channel_index]; + } + volatile int32_t zp = zps[0]; + + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "mov x4, %[ivs]\n" // reload ivs + "dup v13.4s, %w[min_value]\n" + "dup v14.4s, %w[max_value]\n" + "cmp w8, #16\n" + "blt 1f\n" + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[real_values]], #64\n" + "dup v8.4s, %w[zp]\n" + "dup v9.4s, %w[zp]\n" + "dup v10.4s, %w[zp]\n" + "dup v11.4s, %w[zp]\n" + "scvtf v4.4s, v8.4s\n" + "scvtf v5.4s, v9.4s\n" + "scvtf v6.4s, v10.4s\n" + "scvtf v7.4s, v11.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v4.4s, v0.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v5.4s, v1.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v6.4s, v2.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v7.4s, v3.4s, v12.4s\n" + + "fcvtas v0.4s, v4.4s\n" + "fcvtas v1.4s, v5.4s\n" + "fcvtas v2.4s, v6.4s\n" + "fcvtas v3.4s, v7.4s\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smax v1.4s, v1.4s, v13.4s\n" + "smax v2.4s, v2.4s, v13.4s\n" + "smax v3.4s, v3.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "smin v1.4s, v1.4s, v14.4s\n" + "smin v2.4s, v2.4s, v14.4s\n" + "smin v3.4s, v3.4s, v14.4s\n" + + "sqxtn v4.4h, v0.4s\n" + "sqxtn2 v4.8h, v1.4s\n" + "sqxtn v5.4h, v2.4s\n" + "sqxtn2 v5.8h, v3.4s\n" + "sqxtn v6.8b, v4.8h\n" + "sqxtn2 v6.16b, v5.8h\n" + "st1 {v6.16b}, [%[quant_values]], #16\n" + + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "scvtf s0, %w[zp]\n" + "subs w8, w8, #1\n" + "ldr s4, [%[real_values]], #4\n" + "fmul s4, s4, s12\n" + "fadd s0, s0, s4\n" + "fcvtas s0, s0\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "sqxtn v1.4h, v0.4s\n" + "sqxtn v0.8b, v1.8h\n" + "st1 {v0.b}[0], [%[quant_values]], #1\n" + + "bne 1b\n" + + "2:\n" + : + : [ quant_values ] "r"(quant_values), [ real_values ] "r"(real_values), [ scales ] "r"(scales), [ zp ] "r"(zp), + [ size ] "r"(size), [ channel_length ] "r"(channel_length), [ ivs ] "r"(ivs), [ min_value ] "r"(min_value), + [ max_value ] "r"(max_value) + : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "x4"); +} +#endif + +int DoPerchannelQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int channel_length, int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL || scale == NULL || zp == NULL) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Fp32ToInt8Perchannel_arm64(real_values, quant_values, scale, zp, size, channel_length, min_value, max_value); +#else + for (int i = 0; i < size; ++i) { + int channel_index = i / channel_length; + const float inverse_scale = 1.0f / scale[channel_index]; + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + } else if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + } else { + int temp = round(real_values[i] * inverse_scale + zp[channel_index]); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? 
temp : min_value; + quant_values[i] = (int8_t)temp; + } + } +#endif + return NNACL_OK; +} + +int QuantizeDataFp32ToInt8(const float real_value, int8_t *quant_value, float scale, int32_t zp, int32_t min_value, + int32_t max_value) { + if (quant_value == NULL) { + return NNACL_PARAM_INVALID; + } + const float inverse_scale = 1.0f / scale; + if (real_value == INFINITY) { + *quant_value = max_value; + } else if (real_value == -INFINITY) { + *quant_value = min_value; + } else { + int temp = round(real_value * inverse_scale + zp); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? temp : min_value; + *quant_value = (int8_t)temp; + } + return NNACL_OK; +} + int DoQuantizeFp32ToInt8FromUint8Source(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, int32_t min_value, int32_t max_value) { if (quant_values == NULL || real_values == NULL) { diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.h index 251e9716..950b4287 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/int8/quant_dtype_cast_int8.h @@ -31,12 +31,18 @@ extern "C" { int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, int32_t min_value, int32_t max_value); +int DoPerchannelQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int channel_length, int32_t min_value, int32_t max_value); +int QuantizeDataFp32ToInt8(const float real_value, int8_t *quant_value, float scale, int32_t zp, int32_t min_value, + int32_t max_value); int DoQuantizeFp32ToInt8FromUint8Source(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, int32_t min_value, int32_t max_value); #ifdef ENABLE_ARM64 void Fp32ToInt8_arm64(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, int32_t min_value, int32_t max_value); void Int8ToFp32_arm64(const int8_t *quant_values, float *dst, float scale, int32_t zp, int size); +void Fp32ToInt8Perchannel_arm64(const float *real_values, int8_t *quant_values, float *scales, int32_t *zps, + int size, int channel_length, int32_t min_value, int32_t max_value); #endif int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size); int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size); diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/matmul_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/matmul_parameter.h index 1f1913e1..8116ac58 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/matmul_parameter.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/matmul_parameter.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
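DoPerchannelQuantizeFp32ToInt8 above and CalculateAllChannelMinMax from dynamic_quant_int8.c are the two halves of the new per-channel activation quantization path. A hedged sketch of how they might be driven follows; the symmetric scale formula and the QuantizeActivationPerChannel wrapper are illustrative, and the actual scale/zero-point derivation lives in the DynamicQuant kernel, which is not shown in this hunk:

#include <float.h>
#include <math.h>
#include <stdint.h>
#include "nnacl/int8/dynamic_quant_int8.h"
#include "nnacl/int8/quant_dtype_cast_int8.h"

void QuantizeActivationPerChannel(const float *data, int8_t *out, int count, int channel_length) {
  int channel_num = count / channel_length;
  float real_min[channel_num], real_max[channel_num], scale[channel_num];
  int32_t zp[channel_num];
  for (int i = 0; i < channel_num; i++) {
    real_min[i] = FLT_MAX;    /* CalculateMinMaxFp32 only narrows these, so seed them first */
    real_max[i] = -FLT_MAX;
  }
  CalculateAllChannelMinMax(data, count, real_min, real_max, channel_length);
  for (int i = 0; i < channel_num; i++) {
    float abs_max = fmaxf(fabsf(real_min[i]), fabsf(real_max[i]));
    scale[i] = abs_max > 0.0f ? abs_max / 127.0f : 1.0f;  /* symmetric int8: zero point is 0 */
    zp[i] = 0;
  }
  DoPerchannelQuantizeFp32ToInt8(data, out, scale, zp, count, channel_length, -128, 127);
}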
@@ -35,6 +35,16 @@ typedef void (*MATMUL_OPT_DP_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2, OutType_NC4HW4 = 3 } OutType; +typedef enum MatmulType { + // reserve 0 for base op + kNotImplemented = 0, + kMatmulInt8Cpu, + kMatmulDynamicInt8Cpu, + kMatmulDynamicSdotInt8Cpu, + kMatmulFp32BaseCpu, + kMatmulFp32Arm64Cpu, +} MatmulType; + typedef struct MatMulParameter { // Primitive parameter OpParameter op_parameter_; @@ -63,6 +73,7 @@ typedef struct MatMulParameter { ActType act_type_; bool use_axis_; int axis_; + MatmulType matmul_type_; } MatMulParameter; typedef struct MatmulQuantParameter { @@ -79,8 +90,8 @@ typedef struct MatmulQuantParameter { } MatmulQuantParameter; typedef struct MatmulDynamicQuantParameter { - float input_scale_; - int32_t input_zp_; + float *input_scale_; + int32_t *input_zp_; float *filter_scale_; int32_t *filter_zp_; } MatmulDynamicQuantParameter; diff --git a/mindspore/core/ops/dynamic_quant.cc b/mindspore/core/ops/dynamic_quant.cc index 77cadbc0..d11ee4ff 100644 --- a/mindspore/core/ops/dynamic_quant.cc +++ b/mindspore/core/ops/dynamic_quant.cc @@ -30,9 +30,28 @@ bool DynamicQuant::get_symmetric() const { } void DynamicQuant::set_dst_type(const int64_t dst_type) { (void)AddAttr(kDstType, api::MakeValue(dst_type)); } int64_t DynamicQuant::get_dst_type() const { return GetValue(GetAttr(kDstType)); } +void DynamicQuant::set_prefer_axis(const int64_t prefer_axis) { + (void)AddAttr(kPreferAxis, api::MakeValue(prefer_axis)); +} +int64_t DynamicQuant::get_prefer_axis() const { return GetValue(GetAttr(kPreferAxis)); } +void DynamicQuant::set_activation_perchannel(const bool activation_perchannel) { + (void)AddAttr(kActivationPerchannel, api::MakeValue(activation_perchannel)); +} +bool DynamicQuant::get_activation_perchannel() const { + auto value_ptr = this->GetAttr(kActivationPerchannel); + return GetValue(value_ptr); +} +void DynamicQuant::set_transpose(const bool transpose) { (void)AddAttr(kTrans, api::MakeValue(transpose)); } +bool DynamicQuant::get_transpose() const { + auto value_ptr = this->GetAttr(kTrans); + return GetValue(value_ptr); +} void DynamicQuant::Init(const bool symmetric, const int64_t dst_type) { this->set_symmetric(symmetric); this->set_dst_type(dst_type); + this->set_activation_perchannel(false); + this->set_prefer_axis(0); + this->set_transpose(false); } REGISTER_PRIMITIVE_C(kNameDynamicQuant, DynamicQuant); diff --git a/mindspore/core/ops/dynamic_quant.h b/mindspore/core/ops/dynamic_quant.h index ade36b4f..e7f1b7e6 100644 --- a/mindspore/core/ops/dynamic_quant.h +++ b/mindspore/core/ops/dynamic_quant.h @@ -61,6 +61,36 @@ class MIND_API DynamicQuant : public BaseOperator { /// /// \return the data type of output. int64_t get_dst_type() const; + + /// \brief Method to set prefer_axis attribute. + /// + /// \param[in] prefer_axis Define the preferred axis. + void set_prefer_axis(const int64_t prefer_axis); + + /// \brief Method to get prefer_axis attribute. + /// + /// \return the preferred axis. + int64_t get_prefer_axis() const; + + /// \brief Method to set activation perchannel attribute. + /// + /// \param[in] activation_perchannel Define whether activation perchannel quantization. + void set_activation_perchannel(const bool activation_perchannel); + + /// \brief Method to get activation perchannel attribute. + /// + /// \return Whether activation perchannel quantization. + bool get_activation_perchannel() const; + + /// \brief Method to set transpose attribute. 
+ /// + /// \param[in] symmetric Define whether transpose matrix. + void set_transpose(const bool transpose); + + /// \brief Method to get transpose attribute. + /// + /// \return Whether transpose matrix. + bool get_transpose() const; }; abstract::AbstractBasePtr DynamicQuantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args); diff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h index 7a509840..8dc2e3f4 100644 --- a/mindspore/core/ops/op_name.h +++ b/mindspore/core/ops/op_name.h @@ -22,6 +22,7 @@ namespace mindspore::ops { constexpr auto kAlpha = "alpha"; constexpr auto kActivation = "activation"; constexpr auto kActivationType = "activation_type"; +constexpr auto kActivationPerchannel = "activation_perchannel"; constexpr auto kAttentionQActType = "attention_q_act_type"; constexpr auto kAttentionKActType = "attention_k_act_type"; constexpr auto kAttentionVActType = "attention_v_act_type"; @@ -178,6 +179,7 @@ constexpr auto kDivisorOverride = "divisor_override"; constexpr auto kPostNmsTopn = "post_nms_topn"; constexpr auto kPower = "power"; constexpr auto kPreNmsTopn = "pre_nms_topn"; +constexpr auto kPreferAxis = "prefer_axis"; constexpr auto kRankSize = "rank_size"; constexpr auto kRatio = "ratio"; constexpr auto kReduction = "reduction"; @@ -209,6 +211,7 @@ constexpr auto kSummarize = "summarize"; constexpr auto kTimeMajor = "time_major"; constexpr auto kTolerance = "tolerance"; constexpr auto kTopK = "top_k"; +constexpr auto kTrans = "trans"; constexpr auto kTransposeA = "transpose_a"; constexpr auto kTransposeB = "transpose_b"; constexpr auto kNegativeSlope = "negative_slope"; diff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn index a4d77b1c..86b80a28 100644 --- a/mindspore/lite/BUILD.gn +++ b/mindspore/lite/BUILD.gn @@ -142,11 +142,13 @@ all_lite_sources = [ "src/common/utils.cc", "src/common/graph_util.cc", "src/common/log.cc", + "src/common/mmap_utils.cc", "src/common/prim_util.cc", "src/common/tensor_util.cc", "src/runtime/allocator.cc", "src/runtime/inner_allocator.cc", "src/runtime/runtime_allocator.cc", + "src/runtime/runtime_packed_node_pass.cc", "src/runtime/infer_manager.cc", "src/runtime/runtime_shape_fusion_pass.cc", "src/runtime/runtime_pass.cc", diff --git a/mindspore/lite/schema/inner/ops_generated.h b/mindspore/lite/schema/inner/ops_generated.h index e0614168..86fdbad1 100644 --- a/mindspore/lite/schema/inner/ops_generated.h +++ b/mindspore/lite/schema/inner/ops_generated.h @@ -19484,6 +19484,9 @@ struct DynamicQuantT : public flatbuffers::NativeTable { typedef DynamicQuant TableType; bool symmetric = false; int64_t dst_type = 32LL; + bool activation_perchannel = false; + int64_t prefer_axis = 0; + bool transpose = false; }; struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -19494,7 +19497,10 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { } enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_SYMMETRIC = 4, - VT_DST_TYPE = 6 + VT_DST_TYPE = 6, + VT_ACTIVATION_PERCHANNEL = 8, + VT_PREFER_AXIS = 10, + VT_TRANSPOSE = 12 }; bool symmetric() const { return GetField(VT_SYMMETRIC, 0) != 0; @@ -19508,10 +19514,31 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { bool mutate_dst_type(int64_t _dst_type) { return SetField(VT_DST_TYPE, _dst_type, 32LL); } + bool activation_perchannel() const { + return GetField(VT_ACTIVATION_PERCHANNEL, 0) != 0; + } + bool mutate_activation_perchannel(bool 
_activation_perchannel) { + return SetField(VT_ACTIVATION_PERCHANNEL, static_cast(_activation_perchannel), 0); + } + int64_t prefer_axis() const { + return GetField(VT_PREFER_AXIS, 0); + } + bool mutate_prefer_axis(int64_t _prefer_axis) { + return SetField(VT_PREFER_AXIS, _prefer_axis, 0); + } + bool transpose() const { + return GetField(VT_TRANSPOSE, 0) != 0; + } + bool mutate_transpose(bool _transpose) { + return SetField(VT_TRANSPOSE, static_cast(_transpose), 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_SYMMETRIC) && VerifyField(verifier, VT_DST_TYPE) && + VerifyField(verifier, VT_ACTIVATION_PERCHANNEL) && + VerifyField(verifier, VT_PREFER_AXIS) && + VerifyField(verifier, VT_TRANSPOSE) && verifier.EndTable(); } DynamicQuantT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -19529,6 +19556,15 @@ struct DynamicQuantBuilder { void add_dst_type(int64_t dst_type) { fbb_.AddElement(DynamicQuant::VT_DST_TYPE, dst_type, 32LL); } + void add_activation_perchannel(bool activation_perchannel) { + fbb_.AddElement(DynamicQuant::VT_ACTIVATION_PERCHANNEL, static_cast(activation_perchannel), 0); + } + void add_prefer_axis(int64_t prefer_axis) { + fbb_.AddElement(DynamicQuant::VT_PREFER_AXIS, prefer_axis, 0); + } + void add_transpose(bool transpose) { + fbb_.AddElement(DynamicQuant::VT_TRANSPOSE, static_cast(transpose), 0); + } explicit DynamicQuantBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -19543,9 +19579,15 @@ struct DynamicQuantBuilder { inline flatbuffers::Offset CreateDynamicQuant( flatbuffers::FlatBufferBuilder &_fbb, bool symmetric = false, - int64_t dst_type = 32LL) { + int64_t dst_type = 32LL, + bool activation_perchannel = false, + int64_t prefer_axis = 0, + bool transpose = false) { DynamicQuantBuilder builder_(_fbb); + builder_.add_prefer_axis(prefer_axis); builder_.add_dst_type(dst_type); + builder_.add_transpose(transpose); + builder_.add_activation_perchannel(activation_perchannel); builder_.add_symmetric(symmetric); return builder_.Finish(); } @@ -26124,6 +26166,9 @@ inline void DynamicQuant::UnPackTo(DynamicQuantT *_o, const flatbuffers::resolve (void)_resolver; { auto _e = symmetric(); _o->symmetric = _e; } { auto _e = dst_type(); _o->dst_type = _e; } + { auto _e = activation_perchannel(); _o->activation_perchannel = _e; } + { auto _e = prefer_axis(); _o->prefer_axis = _e; } + { auto _e = transpose(); _o->transpose = _e; } } inline flatbuffers::Offset DynamicQuant::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DynamicQuantT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -26136,10 +26181,16 @@ inline flatbuffers::Offset CreateDynamicQuant(flatbuffers::FlatBuf struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DynamicQuantT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _symmetric = _o->symmetric; auto _dst_type = _o->dst_type; + auto _activation_perchannel = _o->activation_perchannel; + auto _prefer_axis = _o->prefer_axis; + auto _transpose = _o->transpose; return mindspore::schema::CreateDynamicQuant( _fbb, _symmetric, - _dst_type); + _dst_type, + _activation_perchannel, + _prefer_axis, + _transpose); } inline LSTMGradDataT *LSTMGradData::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -33528,10 +33579,13 @@ inline const flatbuffers::TypeTable *ReduceScatterTypeTable() { inline const flatbuffers::TypeTable 
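The generated flatbuffers API keeps the new fields optional with safe defaults, so buffers produced before this change still verify. A small sketch of serializing and reading the table (the generated-header include path is build-dependent):

#include "flatbuffers/flatbuffers.h"
#include "schema/ops_generated.h"  // generated from ops.fbs; exact path varies by target

void BuildAndReadDynamicQuant() {
  flatbuffers::FlatBufferBuilder fbb;
  auto op = mindspore::schema::CreateDynamicQuant(fbb,
                                                  /*symmetric=*/false,
                                                  /*dst_type=*/32LL,
                                                  /*activation_perchannel=*/true,
                                                  /*prefer_axis=*/1LL,
                                                  /*transpose=*/false);
  fbb.Finish(op);
  auto *table = flatbuffers::GetRoot<mindspore::schema::DynamicQuant>(fbb.GetBufferPointer());
  // Missing fields fall back to the schema defaults (false / 0), keeping old buffers readable.
  bool per_channel = table->activation_perchannel();
  (void)per_channel;
}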
*DynamicQuantTypeTable() { static const flatbuffers::TypeCode type_codes[] = { { flatbuffers::ET_BOOL, 0, -1 }, - { flatbuffers::ET_LONG, 0, -1 } + { flatbuffers::ET_LONG, 0, -1 }, + { flatbuffers::ET_BOOL, 0, -1 }, + { flatbuffers::ET_LONG, 0, -1 }, + { flatbuffers::ET_BOOL, 0, -1 } }; static const flatbuffers::TypeTable tt = { - flatbuffers::ST_TABLE, 2, type_codes, nullptr, nullptr, nullptr, nullptr + flatbuffers::ST_TABLE, 5, type_codes, nullptr, nullptr, nullptr, nullptr }; return &tt; } diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 023ee7e5..32775bac 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1231,6 +1231,9 @@ table ReduceScatter { table DynamicQuant { symmetric: bool = false; dst_type: long = 32; + activation_perchannel: bool = false; + prefer_axis: long = 0; + transpose: bool = false; } table LSTMGradData { diff --git a/mindspore/lite/schema/ops_generated.h b/mindspore/lite/schema/ops_generated.h index 5b15211a..393cefcd 100644 --- a/mindspore/lite/schema/ops_generated.h +++ b/mindspore/lite/schema/ops_generated.h @@ -12939,7 +12939,10 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef DynamicQuantBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_SYMMETRIC = 4, - VT_DST_TYPE = 6 + VT_DST_TYPE = 6, + VT_ACTIVATION_PERCHANNEL = 8, + VT_PREFER_AXIS = 10, + VT_TRANSPOSE = 12 }; bool symmetric() const { return GetField(VT_SYMMETRIC, 0) != 0; @@ -12947,10 +12950,22 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { int64_t dst_type() const { return GetField(VT_DST_TYPE, 32LL); } + bool activation_perchannel() const { + return GetField(VT_ACTIVATION_PERCHANNEL, 0) != 0; + } + int64_t prefer_axis() const { + return GetField(VT_PREFER_AXIS, 0); + } + bool transpose() const { + return GetField(VT_TRANSPOSE, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_SYMMETRIC) && VerifyField(verifier, VT_DST_TYPE) && + VerifyField(verifier, VT_ACTIVATION_PERCHANNEL) && + VerifyField(verifier, VT_PREFER_AXIS) && + VerifyField(verifier, VT_TRANSPOSE) && verifier.EndTable(); } }; @@ -12965,6 +12980,15 @@ struct DynamicQuantBuilder { void add_dst_type(int64_t dst_type) { fbb_.AddElement(DynamicQuant::VT_DST_TYPE, dst_type, 32LL); } + void add_activation_perchannel(bool activation_perchannel) { + fbb_.AddElement(DynamicQuant::VT_ACTIVATION_PERCHANNEL, static_cast(activation_perchannel), 0); + } + void add_prefer_axis(int64_t prefer_axis) { + fbb_.AddElement(DynamicQuant::VT_PREFER_AXIS, prefer_axis, 0); + } + void add_transpose(bool transpose) { + fbb_.AddElement(DynamicQuant::VT_TRANSPOSE, static_cast(transpose), 0); + } explicit DynamicQuantBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -12979,9 +13003,15 @@ struct DynamicQuantBuilder { inline flatbuffers::Offset CreateDynamicQuant( flatbuffers::FlatBufferBuilder &_fbb, bool symmetric = false, - int64_t dst_type = 32LL) { + int64_t dst_type = 32LL, + bool activation_perchannel = false, + int64_t prefer_axis = 0, + bool transpose = false) { DynamicQuantBuilder builder_(_fbb); + builder_.add_prefer_axis(prefer_axis); builder_.add_dst_type(dst_type); + builder_.add_transpose(transpose); + builder_.add_activation_perchannel(activation_perchannel); builder_.add_symmetric(symmetric); return builder_.Finish(); } diff --git a/mindspore/lite/src/CMakeLists.txt 
b/mindspore/lite/src/CMakeLists.txt index 31941fb1..d28c30d9 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -111,6 +111,7 @@ set(LITE_SRC ${API_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/common/context_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/file_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/common/mmap_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/log.cc @@ -137,6 +138,7 @@ set(LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/sub_graph_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/runtime/scheduler.cc ${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_packed_node_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/runtime/pack_weight_manager.cc diff --git a/mindspore/lite/src/common/mmap_utils.cc b/mindspore/lite/src/common/mmap_utils.cc new file mode 100644 index 00000000..ca8f8d1e --- /dev/null +++ b/mindspore/lite/src/common/mmap_utils.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/common/mmap_utils.h" +#include "src/common/file_utils.h" +#if !defined(_WIN32) && !defined(_WIN64) +#include +#include +#include +#endif + +namespace mindspore { +namespace lite { +void *ReadFileByMmap(const std::string &file, size_t *size) { +#if !defined(_WIN32) && !defined(_WIN64) && !defined(MS_COMPILE_IOS) + auto real_path = RealPath(file.c_str()); + auto fd = open(real_path.c_str(), O_RDONLY); + if (fd == -1) { + MS_LOG(ERROR) << "Could not open " << file; + return nullptr; + } + struct stat fd_stat; + if (fstat(fd, &fd_stat) != 0) { + MS_LOG(ERROR) << "Get fd stat error."; + close(fd); + return nullptr; + } + *size = fd_stat.st_size; + auto mmap_buffers = mmap(nullptr, *size, PROT_READ, MAP_SHARED | MAP_POPULATE, fd, 0); + close(fd); + if (mmap_buffers == MAP_FAILED) { + MS_LOG(ERROR) << "Model mmap failed."; + return nullptr; + } + return mmap_buffers; +#else + MS_LOG(ERROR) << "Mmap is unsupported on windows."; + return nullptr; +#endif +} + +void UnmapMmapBuffer(void *buffer, size_t size) { +#if !defined(_WIN32) && !defined(_WIN64) + (void)munmap(buffer, size); +#else + MS_LOG(ERROR) << "Mmap is unsupported on windows."; +#endif +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/common/mmap_utils.h b/mindspore/lite/src/common/mmap_utils.h new file mode 100644 index 00000000..bdd7c9a5 --- /dev/null +++ b/mindspore/lite/src/common/mmap_utils.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
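A usage sketch for the new mmap helpers; the key contract is that a buffer obtained from ReadFileByMmap must be released with UnmapMmapBuffer, never with delete[]:

#include <cstddef>
#include <string>
#include "src/common/mmap_utils.h"

bool LoadModelBufferByMmap(const std::string &path) {
  size_t size = 0;
  void *buf = mindspore::lite::ReadFileByMmap(path, &size);
  if (buf == nullptr) {
    return false;  // nullptr on Windows/iOS builds or on open/fstat/mmap failure
  }
  // ... hand the read-only mapping to the runtime (e.g. ImportFromBuffer) ...
  mindspore::lite::UnmapMmapBuffer(buf, size);
  return true;
}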
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_COMMON_MMAP_UTILS_H_ +#define MINDSPORE_LITE_SRC_COMMON_MMAP_UTILS_H_ + +#include + +namespace mindspore { +namespace lite { +void *ReadFileByMmap(const std::string &file, size_t *size); +void UnmapMmapBuffer(void *buffer, size_t size); +} // namespace lite +} // namespace mindspore +#endif diff --git a/mindspore/lite/src/common/ops/ops_def.cc b/mindspore/lite/src/common/ops/ops_def.cc index 011c37df..cfab3113 100644 --- a/mindspore/lite/src/common/ops/ops_def.cc +++ b/mindspore/lite/src/common/ops/ops_def.cc @@ -1231,6 +1231,9 @@ OP_SCHEMA_DEF_END(ReduceScatter) OP_SCHEMA_DEF(DynamicQuant) OP_ATTR_WITH_VALUE(symmetric, bool, false) OP_ATTR_WITH_VALUE(dst_type, long, 32) +OP_ATTR_WITH_VALUE(activation_perchannel, bool, false) +OP_ATTR_WITH_VALUE(prefer_axis, long, 0) +OP_ATTR_WITH_VALUE(transpose, bool, false) OP_SCHEMA_DEF_END(DynamicQuant) OP_SCHEMA_DEF(LSTMGradData) diff --git a/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc b/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc index b7e62e6c..fe8a939e 100644 --- a/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc +++ b/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc @@ -38,6 +38,9 @@ OpParameter *PopulateDynamicQuantParameter(const void *prim) { param->op_parameter_.type_ = primitive->value_type(); param->dst_type_ = value->dst_type(); param->symmetric_ = value->symmetric(); + param->activation_perchannel_ = value->activation_perchannel(); + param->prefer_axis_ = value->prefer_axis(); + param->transpose_ = value->transpose(); return reinterpret_cast(param); } REG_POPULATE(PrimitiveType_DynamicQuant, PopulateDynamicQuantParameter, SCHEMA_CUR); diff --git a/mindspore/lite/src/common/primitive_t_utils.cc b/mindspore/lite/src/common/primitive_t_utils.cc index ad406562..db7c7ef0 100644 --- a/mindspore/lite/src/common/primitive_t_utils.cc +++ b/mindspore/lite/src/common/primitive_t_utils.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace lite { constexpr size_t INITIAL_SIZE = 1024; -const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) { +const schema::Primitive *ConvertToPrimitive(const schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb) { if (primitive_t == nullptr || fbb == nullptr) { MS_LOG(ERROR) << "primitiveT or fbb is nullptr."; return nullptr; @@ -71,6 +71,18 @@ std::unique_ptr GetPrimitiveT(const std::shared_ptr(SCHEMA_VERSION::SCHEMA_CUR)); + return static_cast(prim_type); +} } // namespace lite } // namespace mindspore #endif diff --git a/mindspore/lite/src/common/primitive_t_utils.h b/mindspore/lite/src/common/primitive_t_utils.h index 7fe3e781..dba02777 100644 --- a/mindspore/lite/src/common/primitive_t_utils.h +++ b/mindspore/lite/src/common/primitive_t_utils.h @@ -24,9 +24,10 @@ namespace mindspore { namespace lite { -const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb); +const schema::Primitive *ConvertToPrimitive(const schema::PrimitiveT *primitive_t, 
flatbuffers::FlatBufferBuilder *fbb); OpParameter *GetOpParameter(schema::PrimitiveT *primitive_t); std::unique_ptr GetPrimitiveT(const std::shared_ptr &op); +schema::PrimitiveType GetSchemaPrimType(const schema::PrimitiveT *primitive_t); } // namespace lite } // namespace mindspore #endif diff --git a/mindspore/lite/src/runtime/inner_context.h b/mindspore/lite/src/runtime/inner_context.h index ff58995f..52537e93 100644 --- a/mindspore/lite/src/runtime/inner_context.h +++ b/mindspore/lite/src/runtime/inner_context.h @@ -32,6 +32,13 @@ #endif namespace mindspore::lite { +typedef struct InstructionsContext { + // Instructions should be checked in the beginning. + bool support_fp16 = false; + bool support_sdot = false; + bool support_sse = false; + bool support_avx512 = false; +} InstructionsContext; #ifdef ENABLE_MINDRT constexpr int kDefaultParallelNum = 2; #endif @@ -77,6 +84,8 @@ struct MS_API InnerContext : public Context { void ReplaceLinkInfoSenderWithNewOne(void *new_sender, void *old_sender); + InstructionsContext instructions_ctx_; + private: bool IsAllDeviceTypeValid() const; diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.cc index 41b8b58b..f25bf288 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.cc @@ -47,6 +47,13 @@ int DynamicQuantCPUKernel::Prepare() { src_dtype_ = in_tensor->data_type(); dst_dtype_ = param->dst_type_; symmetric_ = param->symmetric_; + activation_perchannel_ = param->activation_perchannel_; + prefer_axis_ = param->prefer_axis_; + transpose_ = param->transpose_; + // shape_size_ = in_tensor->shape().size(); + // for (int i = 0; i < shape_size_; i++) { + // input_shape_[i] = in_tensor->shape().at(i); + // } if (out_tensor->data_type() != dst_dtype_) { MS_LOG(ERROR) << "param data type and tensor data type do not match."; return RET_ERROR; @@ -68,10 +75,33 @@ int DynamicQuantCPUKernel::ReSize() { // Limit for 8 thread thread_n_num_ = MSMIN(thread_n_num_, kBucketNums); } - for (int i = 0; i < kBucketNums; ++i) { - real_min_array_[i] = FLT_MAX; - real_max_array_[i] = FLT_MIN; + + int min_max_array_size = 0; + if (activation_perchannel_) { + auto dims = in_tensor->shape(); + if (prefer_axis_ < 0) { + prefer_axis_ += dims.size(); + } + channel_num_ = dims[prefer_axis_]; + MS_CHECK_GT(channel_num_, 0, RET_ERROR); + channel_length_ = num_unit_ / channel_num_; + thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); + if (channel_length_ > thread_n_stride_) { + thread_n_num_ = 1; + } + min_max_array_size = channel_num_; + } else { + min_max_array_size = kBucketNums; } + real_min_ = (float *)malloc(min_max_array_size * sizeof(float)); + real_max_ = (float *)malloc(min_max_array_size * sizeof(float)); + if (real_min_ == nullptr || real_max_ == nullptr) { + return RET_NULL_PTR; + } + for (int i = 0; i < min_max_array_size; ++i) { + real_min_[i] = FLT_MAX; + real_max_[i] = -FLT_MAX; + } MS_CHECK_GT(thread_n_num_, 0, RET_ERROR); thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_); return RET_OK; @@ -84,8 +114,20 @@ int DynamicQuantCPUKernel::CalculateMinMax(int task_id) { } int thread_offset = task_id * thread_n_stride_; float *data = float32_ptr_ + thread_offset; - - CalculateMinMaxFp32(data, num_unit_thread, &real_min_array_[task_id], &real_max_array_[task_id]); + if (activation_perchannel_) { + int channel_offset = task_id * thread_n_stride_ / channel_length_; + float *real_min = real_min_ + 
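CalculateAllChannelMinMax is the per-channel counterpart of CalculateMinMaxFp32. An illustrative re-implementation of what it computes (not the nnacl code itself), assuming each channel occupies channel_length consecutive elements and the caller pre-fills the arrays with FLT_MAX / -FLT_MAX as the kernel does:

#include <stddef.h>

// Illustrative: data is a run of `count` floats laid out channel after channel;
// element i belongs to channel i / channel_length and updates that channel's range.
void ChannelMinMax(const float *data, size_t count, float *real_min, float *real_max,
                   size_t channel_length) {
  for (size_t i = 0; i < count; ++i) {
    size_t c = i / channel_length;
    if (data[i] < real_min[c]) real_min[c] = data[i];
    if (data[i] > real_max[c]) real_max[c] = data[i];
  }
}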
channel_offset; + float *real_max = real_max_ + channel_offset; + if (!transpose_) { + CalculateAllChannelMinMax(data, num_unit_thread, real_min, real_max, channel_length_); + } else { + MS_LOG(ERROR) << "Matrix a transpose not supported."; + } + } else { + float *real_min = real_min_ + task_id; + float *real_max = real_max_ + task_id; + CalculateMinMaxFp32(data, num_unit_thread, real_min, real_max); + } return RET_OK; } @@ -100,34 +142,33 @@ int CalculateMinMaxRun(void *cdata, int task_id, float, float) { return RET_OK; } -void DynamicQuantCPUKernel::ReduceMinMaxFp32() { +void DynamicQuantCPUKernel::CalculatePerlayerScaleZp() { + float real_min = FLT_MAX; + float real_max = -FLT_MAX; for (int i = 0; i < kBucketNums; i++) { - if (real_min_array_[i] < real_min_) { - real_min_ = real_min_array_[i]; - } - if (real_max_array_[i] > real_max_) { - real_max_ = real_max_array_[i]; - } + if (real_min_[i] < real_min) { + real_min = real_min_[i]; + } + if (real_max_[i] > real_max) { + real_max = real_max_[i]; + } } - return; -} -void DynamicQuantCPUKernel::CalculateScaleZp() { lite::LiteQuantParam quant_parm; double scale; int zp = 0; constexpr int kQSymmetricRange = 255; constexpr int kQAsymmetricRange = 254; if (!symmetric_) { - auto range = real_max_ - real_min_; + auto range = real_max - real_min; if (range <= 0) { range = kDefaultRange; MS_LOG(WARNING) << name_ << " range is 0 and set the range to 0.01."; } scale = range / kQSymmetricRange; // -128 ~ 127 - zp = static_cast(std::round(INT8_MIN - real_min_ / scale)); + zp = static_cast(std::round(INT8_MIN - real_min / scale)); } else { - auto max = std::max(abs(real_max_), abs(real_min_)); + auto max = std::max(abs(real_max), abs(real_min)); scale = 2 * max / kQAsymmetricRange; // -127 ~ 127 } quant_parm.scale = scale; @@ -138,27 +179,87 @@ void DynamicQuantCPUKernel::CalculateScaleZp() { return; } +void DynamicQuantCPUKernel::CalculatePerChannelScaleZp() { + std::vector quant_params; + for (int i = 0; i < channel_num_; ++i) { + float real_min = real_min_[i]; + float real_max = real_max_[i]; + + lite::LiteQuantParam quant_parm; + double scale; + int zp = 0; + constexpr int kQSymmetricRange = 255; + constexpr int kQAsymmetricRange = 254; + if (!symmetric_) { + auto range = real_max - real_min; + if (range <= 0) { + range = kDefaultRange; + MS_LOG(WARNING) << name_ << " range is 0 and set the range to 0.01."; + } + scale = range / kQSymmetricRange; // -128 ~ 127 + zp = static_cast(std::round(INT8_MIN - real_min / scale)); + } else { + auto max = std::max(abs(real_max), abs(real_min)); + scale = 2 * max / kQAsymmetricRange; // -127 ~ 127 + } + quant_parm.scale = scale; + quant_parm.zeroPoint = zp; + quant_parm.bitNum = k8Bit; + quant_parm.inited = true; + quant_params.push_back(quant_parm); + } + this->out_tensors_.front()->set_quant_params(quant_params); + return; +} + int DynamicQuantCPUKernel::QuantData(int task_id) { int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_); if (num_unit_thread <= 0) { return RET_OK; } - int thread_offset = task_id * thread_n_stride_; - auto quant_arg = out_tensors_.front()->quant_params().front(); - int ret; TypeId data_type = out_tensors_.front()->data_type(); - if (data_type == TypeId::kNumberTypeInt8) { - ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, - quant_arg.zeroPoint, num_unit_thread, (int32_t)INT8_MIN, (int32_t)INT8_MAX); - } else { + if (data_type != TypeId::kNumberTypeInt8) { MS_LOG(ERROR) << "Data type not supported:" 
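The per-layer and per-channel paths above share the same derivation for each (min, max) pair; a compact restatement of that arithmetic, mirroring the asymmetric and symmetric branches and the 0.01 fallback range:

#include <algorithm>
#include <cmath>
#include <cstdint>

struct ScaleZp {
  double scale;
  int zp;
};

// Mirror of the per-(min, max) derivation used by both scale/zero-point paths:
// the asymmetric branch maps onto -128..127, the symmetric branch onto -127..127.
ScaleZp DeriveQuantParam(float real_min, float real_max, bool symmetric) {
  if (!symmetric) {
    double range = real_max - real_min;
    if (range <= 0) {
      range = 0.01;  // kDefaultRange fallback when the tensor is constant
    }
    double scale = range / 255;
    int zp = static_cast<int>(std::round(INT8_MIN - real_min / scale));
    return {scale, zp};
  }
  double max_abs = std::max(std::fabs(static_cast<double>(real_max)),
                            std::fabs(static_cast<double>(real_min)));
  return {2 * max_abs / 254, 0};
}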
<< data_type; return RET_PARAM_INVALID; } - if (ret != RET_OK) { - MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]"; - return RET_ERROR; + int thread_offset = task_id * thread_n_stride_; + int ret; + if (activation_perchannel_) { + if (out_tensors_.front()->quant_params().size() != static_cast(channel_num_)) { + return RET_ERROR; + } + float *scale = (float *)malloc(channel_num_ * sizeof(float)); + int32_t *zero_point = (int32_t *)malloc(channel_num_ * sizeof(int32_t)); + for (int i = 0; i < channel_num_; i++) { + auto quant_arg = out_tensors_.front()->quant_params().at(i); + scale[i] = quant_arg.scale; + zero_point[i] = quant_arg.zeroPoint; + } + if (transpose_) { + MS_LOG(ERROR) << "Matrix a transpose not supported."; + free(scale); + free(zero_point); + return RET_ERROR; + } else { + ret = DoPerchannelQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, scale, + zero_point, num_unit_thread, channel_length_, (int32_t)INT8_MIN, (int32_t)INT8_MAX); + free(scale); + free(zero_point); + if (ret != RET_OK) { + MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + } + } else { + auto quant_arg = out_tensors_.front()->quant_params().front(); + ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale, + quant_arg.zeroPoint, num_unit_thread, (int32_t)INT8_MIN, (int32_t)INT8_MAX); + if (ret != RET_OK) { + MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } } - return RET_OK; + return RET_OK; } int QuantDataRun(void *cdata, int task_id, float, float) { @@ -182,8 +283,11 @@ int DynamicQuantCPUKernel::Run() { MS_LOG(ERROR) << "Run error error_code[" << ret << "]"; return RET_ERROR; } - ReduceMinMaxFp32(); - CalculateScaleZp(); + if (activation_perchannel_) { + CalculatePerChannelScaleZp(); + } else { + CalculatePerlayerScaleZp(); + } ret = ParallelLaunch(this->ms_context_, QuantDataRun, this, thread_n_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "Run error error_code[" << ret << "]"; diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.h b/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.h index 6acb0d8d..e44c7643 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.h +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/dynamic_quant.h @@ -19,6 +19,7 @@ #include #include +#include #include "src/runtime/lite_kernel.h" namespace mindspore::kernel { @@ -27,7 +28,10 @@ class DynamicQuantCPUKernel : public LiteKernel { DynamicQuantCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx) : LiteKernel(parameter, inputs, outputs, ctx), thread_num_(ctx->thread_num_) {} - ~DynamicQuantCPUKernel() override = default; + ~DynamicQuantCPUKernel() override { + free(real_min_); + free(real_max_); + }; int Prepare() override; int ReSize() override; @@ -37,8 +41,8 @@ class DynamicQuantCPUKernel : public LiteKernel { int CalculateMinMax(int task_id); private: - void ReduceMinMaxFp32(); - void CalculateScaleZp(); + void CalculatePerlayerScaleZp(); + void CalculatePerChannelScaleZp(); private: int thread_num_; @@ -47,14 +51,19 @@ class DynamicQuantCPUKernel : public LiteKernel { int num_unit_{0}; int8_t *int8_ptr_ = nullptr; float *float32_ptr_ = nullptr; + float *real_min_ = nullptr; + float *real_max_ = nullptr; - float real_min_array_[8]; - float real_max_array_[8]; - float 
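The per-channel branch of QuantData hands contiguous channel slices to DoPerchannelQuantizeFp32ToInt8; an illustrative equivalent of the arithmetic that call is expected to perform (not the nnacl implementation):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative per-channel fp32 -> int8 quantization with contiguous channels:
// q = clamp(round(x / scale[c]) + zp[c], qmin, qmax), where c = i / channel_length.
void PerChannelQuantize(const float *src, int8_t *dst, const float *scale, const int32_t *zp,
                        int count, int channel_length, int32_t qmin, int32_t qmax) {
  for (int i = 0; i < count; ++i) {
    int c = i / channel_length;
    int32_t q = static_cast<int32_t>(std::round(src[i] / scale[c])) + zp[c];
    dst[i] = static_cast<int8_t>(std::min(std::max(q, qmin), qmax));
  }
}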
real_min_ = FLT_MAX; - float real_max_ = FLT_MIN; int32_t src_dtype_{0}; int32_t dst_dtype_{0}; bool symmetric_ = false; + bool activation_perchannel_ = false; + bool transpose_ = false; + int32_t prefer_axis_{-1}; + // int32_t input_shape_[8]; + // int32_t shape_size_{0}; + int32_t channel_num_{0}; + int32_t channel_length_{0}; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_base_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_base_int8.h index a9383eac..5e360789 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_base_int8.h +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_base_int8.h @@ -36,6 +36,7 @@ class MatmulBaseInt8CPUKernel : public LiteKernel { const std::vector &outputs, const lite::InnerContext *ctx) : LiteKernel(parameter, inputs, outputs, ctx) { param_ = reinterpret_cast(op_parameter_); + param_->matmul_type_ = MatmulType::kNotImplemented; } ~MatmulBaseInt8CPUKernel() override; int Prepare() override; diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.cc index 166510ec..c51d7cc5 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022 Huawei Technologies Co., Ltd + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,9 @@ #include "src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h" #include "nnacl/int8/dynamic_matmul_int8.h" +using mindspore::lite::kCHWDimNumber; +using mindspore::lite::kHWDimNumber; +using mindspore::lite::kNCHWDimNumber; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_MEMORY_FAILED; using mindspore::lite::RET_OK; @@ -79,20 +82,20 @@ int MatmulDynamicBaseInt8CPUKernel::InitFilterQuantParam() { } int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1]; filter_per_channel_ = (weight_quant_params.size() > 1); - channel_num_ = filter_per_channel_ ? col : 1; - if (static_cast(weight_quant_params.size()) != channel_num_) { + auto channel_num = filter_per_channel_ ? 
col : 1; + if (static_cast(weight_quant_params.size()) != channel_num) { MS_LOG(ERROR) << weight_tensor->tensor_name() << " quant params size:" << weight_quant_params.size() - << " != channel_num_:" << channel_num_; + << " != channel_num:" << channel_num; return RET_ERROR; } - quant_param_->filter_scale_ = reinterpret_cast(malloc(channel_num_ * sizeof(float))); + quant_param_->filter_scale_ = reinterpret_cast(malloc(channel_num * sizeof(float))); CHECK_NULL_RETURN(quant_param_->filter_scale_); - memset(quant_param_->filter_scale_, 0, sizeof(channel_num_)); - quant_param_->filter_zp_ = reinterpret_cast(malloc(channel_num_ * sizeof(int32_t))); + memset(quant_param_->filter_scale_, 0, sizeof(channel_num)); + quant_param_->filter_zp_ = reinterpret_cast(malloc(channel_num * sizeof(int32_t))); CHECK_NULL_RETURN(quant_param_->filter_zp_); - memset(quant_param_->filter_zp_, 0, sizeof(channel_num_)); + memset(quant_param_->filter_zp_, 0, sizeof(channel_num)); - for (int i = 0; i < channel_num_; i++) { + for (int i = 0; i < channel_num; i++) { quant_param_->filter_scale_[i] = static_cast(weight_quant_params[i].scale); quant_param_->filter_zp_[i] = weight_quant_params[i].zeroPoint; } @@ -105,57 +108,68 @@ void MatmulDynamicBaseInt8CPUKernel::ResizeMatrixBParameter() { for (size_t i = 0; i < w_shape.size() - kSize2; ++i) { batch *= w_shape[i]; } - param_->batch = batch; + b_batch_ = batch; param_->col_ = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1]; param_->deep_ = param_->b_transpose_ ? w_shape[w_shape.size() - kSize1] : w_shape[w_shape.size() - kSize2]; param_->col_align_ = UP_ROUND(param_->col_, col_tile_); param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_); - thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_)); - thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_); + thread_num_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_)); + thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_num_); return; } void MatmulDynamicBaseInt8CPUKernel::FreeTmpBuffer() { - if (pack_a_ptr_ != nullptr) { - free(pack_a_ptr_); - pack_a_ptr_ = nullptr; - } - if (pack_b_ptr_ != nullptr) { + FreeMatrixABuffer(); + if (pack_b_ptr_ != nullptr && !weight_is_packed_) { free(pack_b_ptr_); pack_b_ptr_ = nullptr; } - if (input_sums_ != nullptr) { - free(input_sums_); - input_sums_ = nullptr; - } - if (weight_sums_ != nullptr) { + if (weight_sums_ != nullptr && !weight_is_packed_) { free(weight_sums_); weight_sums_ = nullptr; } - if (fp32_bias_ptr_ != nullptr) { - free(fp32_bias_ptr_); - fp32_bias_ptr_ = nullptr; + if (bias_ptr_ != nullptr) { + free(bias_ptr_); + bias_ptr_ = nullptr; } - return; } -int MatmulDynamicBaseInt8CPUKernel::InitInputQuantParam() { +int MatmulDynamicBaseInt8CPUKernel::InitInputQuantParam(std::vector *scales, std::vector *zp) { auto in_quant_params = in_tensors_.at(kInputIndex)->quant_params(); if (in_quant_params.empty()) { MS_LOG(ERROR) << "invalid in quant param"; return RET_ERROR; } - quant_param_->input_zp_ = in_quant_params.front().zeroPoint; - quant_param_->input_scale_ = static_cast(in_quant_params.front().scale); + input_per_channel_ = (in_quant_params.size() > 1); + auto channel_num = input_per_channel_ ? 
param_->row_ : 1; + if (static_cast(in_quant_params.size()) != channel_num) { + MS_LOG(ERROR) << in_tensors_.at(kInputIndex)->tensor_name() << " quant params size:" << in_quant_params.size() + << " != channel_num:" << channel_num; + return RET_ERROR; + } + scales->resize(channel_num); + zp->resize(channel_num); + for (int i = 0; i < channel_num; ++i) { + (*scales)[i] = in_quant_params[i].scale; + (*zp)[i] = in_quant_params[i].zeroPoint; + } + quant_param_->input_zp_ = zp->data(); + quant_param_->input_scale_ = scales->data(); return RET_OK; } int MatmulDynamicBaseInt8CPUKernel::TransferB() { + if (weight_is_packed_) { + CHECK_NULL_RETURN(weight_sums_tensor_); + pack_b_ptr_ = static_cast(in_tensors_.at(kWeightIndex)->data()); + weight_sums_ = static_cast(weight_sums_tensor_->data()); + return RET_OK; + } auto weight_data = reinterpret_cast(in_tensors_.at(kWeightIndex)->data()); CHECK_NULL_RETURN(weight_data); - for (int i = 0; i < param_->batch; i++) { + for (int i = 0; i < b_batch_; i++) { auto current_weight = weight_data + i * param_->deep_ * param_->col_; auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_; auto current_sums = weight_sums_ + i * param_->col_align_; @@ -168,40 +182,51 @@ int MatmulDynamicBaseInt8CPUKernel::TransferB() { CalcWeightSums(current_weight, param_->deep_, param_->col_, current_sums, RowMajor); } } + return RET_OK; } int MatmulDynamicBaseInt8CPUKernel::InitMatrixABuffer() { - if (pack_a_ptr_ != nullptr) { - free(pack_a_ptr_); - pack_a_ptr_ = nullptr; + size_t pack_a_size = param_->row_align_ * param_->deep_align_ * sizeof(int8_t); + size_t sum_a_size = param_->row_align_ * sizeof(int); + if (ms_context_ != nullptr && ms_context_->allocator != nullptr) { + pack_a_ptr_ = reinterpret_cast(ms_context_->allocator->Malloc(pack_a_size + sum_a_size)); + } else { + pack_a_ptr_ = reinterpret_cast(malloc(pack_a_size + sum_a_size)); } - pack_a_ptr_ = reinterpret_cast(malloc(param_->row_align_ * param_->deep_align_ * sizeof(int8_t))); if (pack_a_ptr_ == nullptr) { - FreeTmpBuffer(); - return RET_ERROR; + MS_LOG(ERROR) << "alloc run-buffer for matrix-a failed."; + return lite::RET_NULL_PTR; } - if (input_sums_ != nullptr) { - free(input_sums_); - input_sums_ = nullptr; + input_sums_ = reinterpret_cast(pack_a_ptr_ + pack_a_size); + memset(pack_a_ptr_, 0, pack_a_size + sum_a_size); + return RET_OK; +} + +void MatmulDynamicBaseInt8CPUKernel::FreeMatrixABuffer() { + if (pack_a_ptr_ == nullptr) { + return; } - input_sums_ = reinterpret_cast(malloc(param_->row_align_ * sizeof(int))); - if (input_sums_ == nullptr) { - FreeTmpBuffer(); - return RET_ERROR; + if (ms_context_ != nullptr && ms_context_->allocator != nullptr) { + ms_context_->allocator->Free(pack_a_ptr_); + } else { + free(pack_a_ptr_); } - memset(pack_a_ptr_, 0, param_->row_align_ * param_->deep_align_ * sizeof(int8_t)); - memset(input_sums_, 0, param_->row_align_ * sizeof(int)); - return RET_OK; + pack_a_ptr_ = nullptr; + input_sums_ = nullptr; } int MatmulDynamicBaseInt8CPUKernel::InitMatrixBBuffer() { + if (weight_is_packed_) { + return RET_OK; + } + if (pack_b_ptr_ != nullptr) { free(pack_b_ptr_); pack_b_ptr_ = nullptr; } pack_b_ptr_ = - reinterpret_cast(malloc(param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t))); + reinterpret_cast(malloc(b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t))); if (pack_b_ptr_ == nullptr) { FreeTmpBuffer(); return RET_ERROR; @@ -210,28 +235,32 @@ int MatmulDynamicBaseInt8CPUKernel::InitMatrixBBuffer() { 
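InitMatrixABuffer now draws one block per Run from the context allocator (falling back to malloc) and places the int32 row sums directly behind the packed A data. A sketch of that layout under the same assumptions:

#include <cstdint>
#include <cstdlib>
#include <cstring>

// Single-allocation layout for the run-time matrix-A workspace:
// [ packed int8 A : row_align * deep_align bytes ][ int32 row sums : row_align entries ].
struct MatrixAWorkspace {
  int8_t *pack_a = nullptr;
  int32_t *input_sums = nullptr;
};

bool AllocMatrixAWorkspace(MatrixAWorkspace *ws, size_t row_align, size_t deep_align) {
  size_t pack_a_size = row_align * deep_align * sizeof(int8_t);
  size_t sum_a_size = row_align * sizeof(int32_t);
  auto *mem = static_cast<int8_t *>(malloc(pack_a_size + sum_a_size));
  if (mem == nullptr) {
    return false;
  }
  memset(mem, 0, pack_a_size + sum_a_size);
  ws->pack_a = mem;
  ws->input_sums = reinterpret_cast<int32_t *>(mem + pack_a_size);  // sums share the tail of the block
  return true;
}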
free(weight_sums_); weight_sums_ = nullptr; } - weight_sums_ = reinterpret_cast(malloc(param_->batch * param_->col_align_ * sizeof(int))); + weight_sums_ = reinterpret_cast(malloc(b_batch_ * param_->col_align_ * sizeof(int))); if (weight_sums_ == nullptr) { FreeTmpBuffer(); return RET_ERROR; } - memset(pack_b_ptr_, 0, param_->batch * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)); - memset(weight_sums_, 0, param_->batch * param_->col_align_ * sizeof(int)); + memset(pack_b_ptr_, 0, b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t)); + memset(weight_sums_, 0, b_batch_ * param_->col_align_ * sizeof(int)); return RET_OK; } int MatmulDynamicBaseInt8CPUKernel::CopyBias() { if (in_tensors_.size() == kHasBiasSize) { + CHECK_NULL_RETURN(in_tensors_[kBiasIndex]); auto bias_tensor = in_tensors_[kBiasIndex]; - fp32_bias_ptr_ = static_cast(malloc(bias_tensor->Size())); - if (fp32_bias_ptr_ == nullptr) { + auto bias_shape = bias_tensor->shape(); + MS_CHECK_TRUE_MSG(bias_shape.size() == 1, lite::RET_INPUT_TENSOR_ERROR, "bias is not 1D."); + size_t bias_pack_size = UP_ROUND(bias_shape.back(), col_tile_) * lite::DataTypeSize(bias_tensor->data_type()); + bias_ptr_ = malloc(bias_pack_size); + if (bias_ptr_ == nullptr) { MS_LOG(ERROR) << "Memory allocation failed"; FreeTmpBuffer(); return RET_MEMORY_FAILED; } - memcpy(fp32_bias_ptr_, bias_tensor->data(), bias_tensor->ElementsNum() * sizeof(float)); + memcpy(bias_ptr_, bias_tensor->data(), bias_tensor->Size()); } else { - fp32_bias_ptr_ = nullptr; + bias_ptr_ = nullptr; } return RET_OK; } @@ -239,6 +268,18 @@ int MatmulDynamicBaseInt8CPUKernel::CopyBias() { int MatmulDynamicBaseInt8CPUKernel::Prepare() { CHECK_LESS_RETURN(in_tensors_.size(), kMinInputSize); CHECK_LESS_RETURN(out_tensors_.size(), kOutputSize); + CHECK_NULL_RETURN(in_tensors_[0]); + CHECK_NULL_RETURN(in_tensors_[1]); + CHECK_NULL_RETURN(out_tensors_[0]); + if (in_tensors_[0]->data_type() != mindspore::kNumberTypeInt8 || + in_tensors_[1]->data_type() != mindspore::kNumberTypeInt8) { + MS_LOG(ERROR) << "Datatype error, input0 data_type is " << in_tensors_[0]->data_type() << ", input1 data_type is " + << in_tensors_[1]->data_type(); + return RET_ERROR; + } +#ifdef ENABLE_FP16 + enable_fp16_ = ms_context_->device_list_[0].device_info_.cpu_device_info_.enable_float16_; +#endif InitParameter(); auto ret = MallocQuantParam(); if (ret != RET_OK) { @@ -277,18 +318,24 @@ int MatmulDynamicBaseInt8CPUKernel::Prepare() { } int MatmulDynamicBaseInt8CPUKernel::ReSize() { + // In the framework, the out_tensors data_type is forced to kNumberTypeFloat32 + if (enable_fp16_) { + out_tensors_[0]->set_data_type(kNumberTypeFloat16); + } auto x_shape = in_tensors_.at(0)->shape(); auto o_shape = out_tensors_.at(0)->shape(); MS_ASSERT(o_shape.size() >= kSize2); + param_->row_ = o_shape[o_shape.size() - kSize2]; param_->row_align_ = UP_ROUND(param_->row_, row_tile_); param_->deep_ = param_->a_transpose_ ? 
x_shape[x_shape.size() - kSize2] : x_shape[x_shape.size() - kSize1]; param_->deep_align_ = UP_ROUND(param_->deep_, deep_tile_); - auto ret = InitMatrixABuffer(); + auto ret = InitBroadcastParams(in_tensors_[kInputIndex]->shape(), in_tensors_[kWeightIndex]->shape(), param_, + &a_offset_, &b_offset_); if (ret != RET_OK) { - FreeQuantParam(); - return ret; + MS_LOG(ERROR) << "InitBroadcastParams failed."; + return RET_ERROR; } if (!param_->b_const_) { @@ -301,4 +348,80 @@ int MatmulDynamicBaseInt8CPUKernel::ReSize() { } return RET_OK; } + +int MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(const std::vector &a_shape_const, + const std::vector &b_shape_const, MatMulParameter *params, + std::vector *a_offsets, std::vector *b_offsets) { + std::vector a_shape = a_shape_const; + if (a_shape.size() < kNCHWDimNumber) { + size_t add_nums = kNCHWDimNumber - a_shape.size(); + for (size_t i = 0; i < add_nums; ++i) { + (void)a_shape.insert(a_shape.begin(), 1); + } + } + std::vector b_shape = b_shape_const; + if (b_shape.size() < kNCHWDimNumber) { + size_t add_nums = kNCHWDimNumber - b_shape.size(); + for (size_t i = 0; i < add_nums; ++i) { + (void)b_shape.insert(b_shape.begin(), 1); + } + } + + int batch_sizes[MAX_SHAPE_SIZE] = {0}; + int a_batch_sizes[MAX_SHAPE_SIZE] = {0}; + int b_batch_sizes[MAX_SHAPE_SIZE] = {0}; + for (int i = a_shape.size() - kCHWDimNumber; i >= 0; --i) { + if (static_cast(a_shape.size() - kCHWDimNumber) == i) { + batch_sizes[i] = std::max(a_shape[i], b_shape[i]); + a_batch_sizes[i] = a_shape[i]; + b_batch_sizes[i] = b_shape[i]; + } else { + batch_sizes[i] = batch_sizes[i + 1] * std::max(a_shape[i], b_shape[i]); + a_batch_sizes[i] = a_batch_sizes[i + 1] * a_shape[i]; + b_batch_sizes[i] = b_batch_sizes[i + 1] * b_shape[i]; + } + } + + int out_batch = 1; + for (size_t i = 0; i < a_shape.size() - kHWDimNumber; ++i) { + int max_v = MSMAX(a_shape[i], b_shape[i]); + int min_v = MSMIN(a_shape[i], b_shape[i]) > 0 ? 
MSMIN(a_shape[i], b_shape[i]) : 1; + out_batch *= max_v; + if (max_v != min_v && max_v % min_v != 0) { + MS_LOG(ERROR) << "matmul don't support broadcast for dimension " << a_shape << " and " << b_shape; + return RET_ERROR; + } + } + params->batch = out_batch; + + a_offsets->resize(params->batch, 0); + b_offsets->resize(params->batch, 0); + for (int i = 0; i < params->batch; ++i) { + int64_t delta = i; + int a_offset = 0; + int b_offset = 0; + for (size_t j = 0; j < a_shape.size() - kHWDimNumber; ++j) { + if (j > 0) { + delta = delta % batch_sizes[j]; + } + if (j < (a_shape.size() - kCHWDimNumber)) { + a_offset += (delta / batch_sizes[j + 1] * a_shape[j] / std::max(a_shape[j], b_shape[j])) * a_batch_sizes[j + 1]; + b_offset += (delta / batch_sizes[j + 1] * b_shape[j] / std::max(a_shape[j], b_shape[j])) * b_batch_sizes[j + 1]; + } else { + a_offset += (delta * a_shape[j] / std::max(a_shape[j], b_shape[j])); + b_offset += (delta * b_shape[j] / std::max(a_shape[j], b_shape[j])); + } + } + (*a_offsets)[i] = a_offset; + (*b_offsets)[i] = b_offset; + } + + return RET_OK; +} + +int MatmulDynamicBaseInt8CPUKernel::PreparePackedWeight(const lite::Tensor *tensor) { + weight_is_packed_ = true; + weight_sums_tensor_ = tensor; + return RET_OK; +} } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h index 68b664af..6f86c07a 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h @@ -1,5 +1,5 @@ /** - * Copyright 2022 Huawei Technologies Co., Ltd + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
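InitBroadcastParams pads both shapes to 4-D, multiplies the maxima of the leading dimensions into params->batch, and records, for every output batch, which A batch and which B batch it reads from. A small worked example under those rules (the offsets were derived by hand from the algorithm above, so treat them as the expected outcome rather than a tested result):

#include <vector>
#include "src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h"

// A[2, 1, 16, 32] x B[1, 3, 32, 8]: the leading dims broadcast to 2 x 3 = 6 batches.
// Expected mapping: a_offsets = {0, 0, 0, 1, 1, 1}, b_offsets = {0, 1, 2, 0, 1, 2}.
void ResolveBroadcastExample(MatMulParameter *params) {
  std::vector<int> a_shape = {2, 1, 16, 32};  // batch dims, then M = 16, K = 32
  std::vector<int> b_shape = {1, 3, 32, 8};   // batch dims, then K = 32, N = 8
  std::vector<int> a_offsets;
  std::vector<int> b_offsets;
  (void)mindspore::kernel::MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(
      a_shape, b_shape, params, &a_offsets, &b_offsets);  // fills params->batch with 6
}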
@@ -18,13 +18,14 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_INT8_MATMUL_DYNAMIC_BASE_INT8_H_ #include +#include #include "include/errorcode.h" -#include "include/context.h" #include "src/runtime/lite_kernel.h" #include "nnacl/matmul_parameter.h" #include "nnacl/common_func.h" #include "nnacl/int8/quantize.h" #include "nnacl/int8/common_func_int8.h" +#include "src/common/common.h" namespace mindspore::kernel { class MatmulDynamicBaseInt8CPUKernel : public LiteKernel { @@ -37,44 +38,60 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel { ~MatmulDynamicBaseInt8CPUKernel() override; int Prepare() override; int ReSize() override; + static int InitBroadcastParams(const std::vector &a_shape_const, const std::vector &b_shape_const, + MatMulParameter *params, std::vector *a_offsets, std::vector *b_offsets); + + const int8_t *GetPackBPtr() const { return pack_b_ptr_; } + const int *GetWeightSums() const { return weight_sums_; } + const int GetBBatch() const { return b_batch_; } + int PreparePackedWeight(const lite::Tensor *tensor) override; private: void ResizeMatrixBParameter(); int CopyBias(); - int InitMatrixABuffer(); int InitMatrixBBuffer(); int MallocQuantParam(); protected: + int a_batch_ = 1; + int b_batch_ = 1; + std::vector a_offset_; + std::vector b_offset_; typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col); virtual void InitParameter() = 0; int TransferA(); - int InitInputQuantParam(); + int InitInputQuantParam(std::vector *scales, std::vector *zp); int InitFilterQuantParam(); int TransferB(); void FreeTmpBuffer(); void FreeQuantParam(); + int InitMatrixABuffer(); + void FreeMatrixABuffer(); protected: MatMulParameter *param_ = nullptr; MatmulDynamicQuantParameter *quant_param_ = nullptr; int8_t *pack_a_ptr_ = nullptr; int8_t *pack_b_ptr_ = nullptr; - float *fp32_bias_ptr_ = nullptr; + + bool input_per_channel_ = false; bool filter_per_channel_ = true; int8_t *batch_input_ptr_ = nullptr; int8_t *batch_weight_ptr_ = nullptr; + int8_t *batch_a_ptr_ = nullptr; int8_t *batch_b_ptr_ = nullptr; - float *batch_c_ptr_ = nullptr; + void *bias_ptr_ = nullptr; + void *batch_c_ptr_ = nullptr; int *input_sums_ = nullptr; int *weight_sums_ = nullptr; int row_tile_ = C4NUM; int col_tile_ = C4NUM; int deep_tile_ = C16NUM; - int channel_num_ = 0; - int thread_count_ = 1; int thread_stride_ = 0; + bool enable_fp16_ = false; PackFunc b_pack_func_ = nullptr; + bool weight_is_packed_ = false; + const lite::Tensor *weight_sums_tensor_ = nullptr; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.cc index 766b7bb2..69c1baae 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022 Huawei Technologies Co., Ltd + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -45,8 +45,8 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) { if (cur_oc <= 0) { return RET_OK; } - float *bias_ptr = fp32_bias_ptr_; - if (fp32_bias_ptr_ != nullptr) { + float *bias_ptr = static_cast(bias_ptr_); + if (bias_ptr_ != nullptr) { bias_ptr += cur_stride; } float *filter_scale = quant_param_->filter_scale_; @@ -54,10 +54,12 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) { if (filter_per_channel_) { filter_scale += cur_stride; } - DynamicMatmul4x16x4AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr, - batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_, param_->deep_align_, - param_->col_, quant_param_->input_zp_, quant_param_->input_scale_, filter_scale, filter_zp, - filter_per_channel_); + int64_t act_type = static_cast(param_->act_type_); + + DynamicMatmul4x16x4AIWI(batch_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, bias_ptr, + static_cast(batch_c_ptr_) + cur_stride, param_->row_, cur_oc, param_->deep_, + param_->deep_align_, param_->col_, *quant_param_->input_zp_, quant_param_->input_scale_, + filter_scale, filter_zp, input_per_channel_, filter_per_channel_, act_type); return RET_OK; } @@ -81,11 +83,18 @@ void MatmulDynamicInt8CPUKernel::InitParameter() { } int MatmulDynamicInt8CPUKernel::Run() { - auto ret = InitInputQuantParam(); + std::vector input_scales; + std::vector input_zp; + auto ret = InitInputQuantParam(&input_scales, &input_zp); if (ret != RET_OK) { MS_LOG(ERROR) << "Init input quant param failed."; return ret; } + ret = InitMatrixABuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << " failed."; + return ret; + } if (!param_->b_const_) { ret = InitFilterQuantParam(); if (ret != RET_OK) { @@ -104,8 +113,8 @@ int MatmulDynamicInt8CPUKernel::Run() { CHECK_NULL_RETURN(a_ptr); CHECK_NULL_RETURN(c_ptr); for (int i = 0; i < param_->batch; i++) { - memset(pack_a_ptr_, quant_param_->input_zp_, param_->row_align_ * param_->deep_align_ * sizeof(int8_t)); - auto current_src_a = a_ptr + i * param_->row_ * param_->deep_; + memset(pack_a_ptr_, *(quant_param_->input_zp_), param_->row_align_ * param_->deep_align_ * sizeof(int8_t)); + auto current_src_a = a_ptr + a_offset_[i] * param_->row_ * param_->deep_; if (param_->a_transpose_) { MS_CHECK_TRUE_RET(a_pack_func_ != nullptr, RET_ERROR); a_pack_func_(current_src_a, pack_a_ptr_, param_->deep_, param_->row_); @@ -114,15 +123,17 @@ int MatmulDynamicInt8CPUKernel::Run() { a_pack_func_(current_src_a, pack_a_ptr_, param_->row_, param_->deep_); } - batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_; + batch_a_ptr_ = pack_a_ptr_; + batch_b_ptr_ = pack_b_ptr_ + b_offset_[i] * param_->col_align_ * param_->deep_align_; batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_; - ret = ParallelLaunch(this->ms_context_, MatmulDynamicInt8Run, this, thread_count_); + ret = ParallelLaunch(this->ms_context_, MatmulDynamicInt8Run, this, thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]"; return ret; } } + FreeMatrixABuffer(); return RET_OK; } } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.h index 71869275..86b2c009 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.h +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_int8.h @@ -25,7 +25,9 @@ class MatmulDynamicInt8CPUKernel : public MatmulDynamicBaseInt8CPUKernel { public: 
MatmulDynamicInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx) - : MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {} + : MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) { + param_->matmul_type_ = MatmulType::kMatmulDynamicInt8Cpu; + } ~MatmulDynamicInt8CPUKernel() override = default; int Run() override; diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc index 132fa5d7..755b81e9 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.cc @@ -1,5 +1,5 @@ /** - * Copyright 2022 Huawei Technologies Co., Ltd + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,6 +81,41 @@ int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotPre(int task_id) { return RET_OK; } +void MatMulDynamicSdotInt8Kernel::ComputeMultiScaleAhead(std::vector *multi_scale, int col_start, + size_t col_num) { + auto &scales = *multi_scale; + if (!input_per_channel_) { + if (!filter_per_channel_) { + scales.resize(1); + scales[0] = quant_param_->input_scale_[0] * quant_param_->filter_scale_[0]; + } else { + scales.resize(UP_ROUND(col_num, col_tile_)); + float *filter_scales = quant_param_->filter_scale_ + col_start; + for (size_t i = 0; i < col_num; ++i) { + scales[i] = quant_param_->input_scale_[0] * filter_scales[i]; + } + } + } else if (!filter_per_channel_) { + scales.resize(param_->row_align_); + for (int i = 0; i < param_->row_; ++i) { + scales[i] = quant_param_->input_scale_[i] * quant_param_->filter_scale_[0]; + } + } +} + +void MatMulDynamicSdotInt8Kernel::ComputeMultiScaleChannelByChannel(std::vector *multi_scale, int row_start, + size_t row_num, int col_start, size_t col_num) { + auto &scales = *multi_scale; + scales.resize(row_tile_ * col_tile_, 0); + float *in_scales = quant_param_->input_scale_ + row_start; + float *filter_scales = quant_param_->filter_scale_ + col_start; + for (size_t i = 0; i < row_num; ++i) { + for (size_t j = 0; j < col_num; ++j) { + scales[i * col_tile_ + j] = in_scales[i] * filter_scales[j]; + } + } +} + int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotImpl(int task_id) { // Multi-thread split by col. int stride = thread_stride_ * col_tile_; @@ -104,15 +139,13 @@ int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotImpl(int task_id) { } } - std::vector multi_scale(cur_oc); - for (int i = 0; i < cur_oc; ++i) { - if (!param_->b_const_) { - multi_scale[i] = quant_param_->input_scale_ * quant_param_->filter_scale_[0]; - } else { - multi_scale[i] = quant_param_->input_scale_ * quant_param_->filter_scale_[cur_stride + i]; - } - } - auto out_stride = param_->col_ * sizeof(float); + std::vector multi_scale; + ComputeMultiScaleAhead(&multi_scale, cur_stride, cur_oc); + int64_t mode = input_per_channel_ * C2NUM + filter_per_channel_; + + size_t data_type_size = enable_fp16_ ? 
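ComputeMultiScaleAhead precomputes the product of activation and weight scales whenever at least one side is per-tensor; only the both-per-channel case (mode 3) is deferred to ComputeMultiScaleChannelByChannel inside the tile loop. The underlying arithmetic, written out naively for clarity (illustrative, not the kernel code):

#include <vector>

// Effective dequant scale of output element (r, c) is always
// input_scale[input_per_channel ? r : 0] * filter_scale[filter_per_channel ? c : 0];
// mode = input_per_channel * 2 + filter_per_channel selects which factors vary.
std::vector<float> CombineScales(const std::vector<float> &input_scale,
                                 const std::vector<float> &filter_scale,
                                 bool input_per_channel, bool filter_per_channel,
                                 int rows, int cols) {
  std::vector<float> combined(static_cast<size_t>(rows) * cols);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      combined[static_cast<size_t>(r) * cols + c] =
          input_scale[input_per_channel ? r : 0] * filter_scale[filter_per_channel ? c : 0];
    }
  }
  return combined;
}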
sizeof(uint16_t) : sizeof(float); + auto out_stride = param_->col_ * data_type_size; + int64_t act_type = static_cast(param_->act_type_); for (int r = 0; r < param_->row_; r += C4NUM) { size_t row = MSMIN(C4NUM, param_->row_ - r); auto a_ptr = pack_a_ptr_ + r * param_->deep_align_; @@ -122,21 +155,30 @@ int MatMulDynamicSdotInt8Kernel::MatMulDynamicArm64SdotImpl(int task_id) { auto col_offset = cur_stride + c; auto b_ptr = batch_b_ptr_ + col_offset * param_->deep_align_; int *weight_sums_ptr = current_sums + c; - auto out_ptr = batch_c_ptr_ + r * param_->col_ + col_offset; - auto bias = fp32_bias_ptr_; - if (bias != nullptr) { - bias += col_offset; - } -#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && !defined(SUPPORT_34XX) && (!defined(MACHINE_LINUX_ARM64)) - DynamicMatmulSdot4x4x16AIWI(a_ptr, b_ptr, out_ptr, param_->deep_align_, multi_scale.data() + c, bias, row, col, - out_stride, input_sums_ptr, weight_sums_ptr, quant_param_->input_zp_, - quant_param_->filter_zp_[0] * param_->deep_); -#else - DynamicMatmul4x4x16AIWI(a_ptr, b_ptr, out_ptr, param_->deep_align_, multi_scale.data() + c, bias, row, col, - out_stride, input_sums_ptr, weight_sums_ptr, quant_param_->input_zp_, - quant_param_->filter_zp_[0] * param_->deep_); + void *out_ptr = static_cast(batch_c_ptr_) + (r * param_->col_ + col_offset) * data_type_size; + auto bias = bias_ptr_; + if (bias_ptr_ != nullptr) { + bias = static_cast(bias) + col_offset * data_type_size; + } + if (mode == C3NUM) { + ComputeMultiScaleChannelByChannel(&multi_scale, r, row, col_offset, col); + } + int multi_scale_offset = + (input_per_channel_ == filter_per_channel_ ? 0 : input_per_channel_ * r + filter_per_channel_ * c); + if (!enable_fp16_) { + dynamic_matmul_compute_fp32(a_ptr, b_ptr, reinterpret_cast(out_ptr), param_->deep_align_, + multi_scale.data() + multi_scale_offset, reinterpret_cast(bias), row, col, + out_stride, input_sums_ptr, weight_sums_ptr, quant_param_->input_zp_[0], + quant_param_->filter_zp_[0] * param_->deep_, act_type, mode); + } else { +#ifdef ENABLE_FP16 + dynamic_matmul_compute_fp16(a_ptr, b_ptr, reinterpret_cast(out_ptr), param_->deep_align_, + multi_scale.data() + multi_scale_offset, reinterpret_cast(bias), row, + col, out_stride, input_sums_ptr, weight_sums_ptr, quant_param_->input_zp_[0], + quant_param_->filter_zp_[0] * param_->deep_, act_type, mode); #endif + } } } return RET_OK; @@ -155,31 +197,44 @@ void MatMulDynamicSdotInt8Kernel::InitParameter() { } else { b_pack_func_ = RowMajor2Col4x16MajorInt8; } - return; +#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && !defined(SUPPORT_34XX) && (!defined(MACHINE_LINUX_ARM64)) && \ + !defined(USE_AOS_GCC_TOOLCHAIN) + dynamic_matmul_compute_fp32 = DynamicMatmulSdot4x4x16AIWI; +#else + dynamic_matmul_compute_fp32 = DynamicMatmul4x4x16AIWI; +#endif +#ifdef ENABLE_FP16 +#if defined(ENABLE_ARM64) && !defined(SUPPORT_NNIE) && !defined(SUPPORT_34XX) && (!defined(MACHINE_LINUX_ARM64)) && \ + !defined(USE_AOS_GCC_TOOLCHAIN) + dynamic_matmul_compute_fp16 = DynamicMatmulSdot4x4x16AIWIForFp16; +#else + dynamic_matmul_compute_fp16 = DynamicMatmul4x4x16AIWIForFp16; +#endif +#endif } int MatMulDynamicSdotInt8Kernel::MatMulDynamicRunArm64Sdot() { int8_t *a_ptr = reinterpret_cast(in_tensors_.at(0)->data()); int8_t *b_ptr = reinterpret_cast(in_tensors_.at(1)->data()); - float *c_ptr = reinterpret_cast(out_tensors_.at(0)->data()); + void *c_ptr = out_tensors_.at(0)->data(); CHECK_NULL_RETURN(a_ptr); CHECK_NULL_RETURN(b_ptr); CHECK_NULL_RETURN(c_ptr); + size_t data_type_size = enable_fp16_ ? 
sizeof(uint16_t) : sizeof(float); for (int i = 0; i < param_->batch; i++) { - batch_input_ptr_ = a_ptr + i * param_->row_ * param_->deep_; + batch_input_ptr_ = a_ptr + a_offset_[i] * param_->row_ * param_->deep_; auto ret = ParallelLaunch(this->ms_context_, Arm64SdotPreRun, this, op_parameter_->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "Arm64SdotPreRun error: [" << ret << "]"; return ret; } - batch_weight_ptr_ = b_ptr + i * param_->col_ * param_->deep_; - batch_sums_ = weight_sums_ + i * param_->col_align_; - batch_b_ptr_ = pack_b_ptr_ + i * param_->col_align_ * param_->deep_align_; - batch_c_ptr_ = c_ptr + i * param_->row_ * param_->col_; - - ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_count_); + batch_weight_ptr_ = b_ptr + b_offset_[i] * param_->col_ * param_->deep_; + batch_sums_ = weight_sums_ + b_offset_[i] * param_->col_align_; + batch_b_ptr_ = pack_b_ptr_ + b_offset_[i] * param_->col_align_ * param_->deep_align_; + batch_c_ptr_ = static_cast(c_ptr) + i * param_->row_ * param_->col_ * data_type_size; + ret = ParallelLaunch(this->ms_context_, Arm64SdotRun, this, thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "Arm64SdotRun error: [" << ret << "]"; return ret; @@ -189,11 +244,18 @@ int MatMulDynamicSdotInt8Kernel::MatMulDynamicRunArm64Sdot() { } int MatMulDynamicSdotInt8Kernel::Run() { - auto ret = InitInputQuantParam(); + std::vector input_scales; + std::vector input_zp; + auto ret = InitInputQuantParam(&input_scales, &input_zp); if (ret != RET_OK) { MS_LOG(ERROR) << "Init input quant param failed."; return ret; } + ret = InitMatrixABuffer(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Alloc run-buffer for matrix-a failed."; + return ret; + } if (!param_->b_const_) { ret = InitFilterQuantParam(); if (ret != RET_OK) { @@ -202,6 +264,8 @@ int MatMulDynamicSdotInt8Kernel::Run() { return ret; } } - return MatMulDynamicRunArm64Sdot(); + ret = MatMulDynamicRunArm64Sdot(); + FreeMatrixABuffer(); + return ret; } } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.h index fc6832bc..131af45b 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.h +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_dynamic_sdot_int8.h @@ -1,5 +1,5 @@ /** - * Copyright 2022 Huawei Technologies Co., Ltd + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -25,20 +25,35 @@ class MatMulDynamicSdotInt8Kernel : public MatmulDynamicBaseInt8CPUKernel { public: MatMulDynamicSdotInt8Kernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx) - : MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {} + : MatmulDynamicBaseInt8CPUKernel(parameter, inputs, outputs, ctx) { + param_->matmul_type_ = MatmulType::kMatmulDynamicSdotInt8Cpu; + } ~MatMulDynamicSdotInt8Kernel() override = default; int Run() override; public: - int MatMulDynamicRunArm64Sdot(); int MatMulDynamicArm64SdotPre(int task_id); int MatMulDynamicArm64SdotImpl(int task_id); - private: + protected: void InitParameter() override; private: + template + using DynamicMatmulComputer = void (*)(const int8_t *a, const int8_t *b, T *out, size_t deep4, + const float *multi_scles, const T *bias, size_t row, size_t col, size_t stride, + const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, + int64_t act_type, int64_t mode); + + int MatMulDynamicRunArm64Sdot(); + void ComputeMultiScaleAhead(std::vector *multi_scale, int col_start, size_t col_num); + void ComputeMultiScaleChannelByChannel(std::vector *multi_scale, int row_start, size_t row_num, int col_start, + size_t col_num); int *batch_sums_ = nullptr; + DynamicMatmulComputer dynamic_matmul_compute_fp32{nullptr}; +#ifdef ENABLE_FP16 + DynamicMatmulComputer dynamic_matmul_compute_fp16{nullptr}; +#endif }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.cc index b539224f..5ad3fd8a 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.cc @@ -92,7 +92,7 @@ kernel::LiteKernel *MatmulInt8CPUKernelCreator(const std::vector MS_LOG(ERROR) << "kernel: " << parameter->name_ << " is unsupported A is const."; return nullptr; } - if (lite::IsSupportSDot()) { + if (lite::IsSupportSDot() || static_cast(ctx)->instructions_ctx_.support_sdot) { kernel = new (std::nothrow) MatMulDynamicSdotInt8Kernel(parameter, inputs, outputs, static_cast(ctx)); } else { diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.h b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.h index 4e9e4e42..d711f727 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.h +++ b/mindspore/lite/src/runtime/kernel/cpu/int8/matmul_int8.h @@ -29,7 +29,9 @@ class MatmulInt8CPUKernel : public MatmulBaseInt8CPUKernel { public: MatmulInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx) - : MatmulBaseInt8CPUKernel(parameter, inputs, outputs, ctx) {} + : MatmulBaseInt8CPUKernel(parameter, inputs, outputs, ctx) { + param_->matmul_type_ = MatmulType::kMatmulInt8Cpu; + } ~MatmulInt8CPUKernel() override = default; int Prepare() override; int ReSize() override; diff --git a/mindspore/lite/src/runtime/kernel_registry.h b/mindspore/lite/src/runtime/kernel_registry.h index f563d82d..82f55c81 100644 --- a/mindspore/lite/src/runtime/kernel_registry.h +++ b/mindspore/lite/src/runtime/kernel_registry.h @@ -46,13 +46,13 @@ class MS_API KernelRegistry { const InnerContext *ctx, const mindspore::Context *ms_ctx, const kernel::KernelKey &key, OpParameter *op_parameter, kernel::KernelExec **kernel, const void *primitive = nullptr); int ReplaceKernelExec(kernel::KernelExec *kernel, const kernel::KernelKey &key); + kernel::LiteKernel 
*GetLiteKernel(const std::vector &in_tensors, const std::vector &out_tensors, + const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *parameter); protected: int GetCustomKernel(const std::vector &in_tensors, const std::vector &out_tensors, const mindspore::Context *ctx, const kernel::KernelKey &key, kernel::KernelExec **kernel, const void *primitive = nullptr); - kernel::LiteKernel *GetLiteKernel(const std::vector &in_tensors, const std::vector &out_tensors, - const InnerContext *ctx, const kernel::KernelKey &key, OpParameter *parameter); static const int device_type_length_{kKernelArch_MAX - kKernelArch_MIN + 1}; static const int data_type_length_{kNumberTypeEnd - kNumberTypeBegin + 1}; static const int op_type_length_{PrimitiveType_MAX - PrimitiveType_MIN + 1}; diff --git a/mindspore/lite/src/runtime/lite_kernel.h b/mindspore/lite/src/runtime/lite_kernel.h index a27f77d8..c85278f4 100644 --- a/mindspore/lite/src/runtime/lite_kernel.h +++ b/mindspore/lite/src/runtime/lite_kernel.h @@ -193,6 +193,8 @@ class MS_API LiteKernel : public Abstractkernel { const lite::Context *context() const { return this->ms_context_; } bool ws_allocated_ = false; + virtual int PreparePackedWeight(const lite::Tensor *tensor) { return mindspore::lite::RET_OK; } + protected: OpParameter *op_parameter_ = nullptr; // tensor will free in ~lite_session() diff --git a/mindspore/lite/src/runtime/lite_model.cc b/mindspore/lite/src/runtime/lite_model.cc index cd8e68d1..662a0856 100644 --- a/mindspore/lite/src/runtime/lite_model.cc +++ b/mindspore/lite/src/runtime/lite_model.cc @@ -29,6 +29,7 @@ #include "src/common/file_utils.h" #include "src/tensor.h" #include "extendrt/mindir_loader/model_loader.h" +#include "src/common/mmap_utils.h" namespace mindspore::lite { namespace { @@ -36,7 +37,11 @@ constexpr size_t kMaxModelBufferSize = static_cast(1024) * 1024 * 1024 * } void LiteModel::Free() { - if (this->buf != nullptr) { + if (this->model_buf_by_mmap_) { + UnmapMmapBuffer(static_cast(this->buf), this->buf_size_); + this->buf = nullptr; + } + if (this->buf != nullptr && !this->model_buf_by_mmap_) { delete[](this->buf); this->buf = nullptr; } diff --git a/mindspore/lite/src/runtime/lite_model.h b/mindspore/lite/src/runtime/lite_model.h index af62cb91..d18ae051 100644 --- a/mindspore/lite/src/runtime/lite_model.h +++ b/mindspore/lite/src/runtime/lite_model.h @@ -310,6 +310,7 @@ class MS_API LiteModel : public Model { public: std::vector node_bufs_; + bool model_buf_by_mmap_ = false; protected: std::vector attr_tensor_bufs_; diff --git a/mindspore/lite/src/runtime/lite_session.cc b/mindspore/lite/src/runtime/lite_session.cc index 6661b410..eb1b5ef7 100644 --- a/mindspore/lite/src/runtime/lite_session.cc +++ b/mindspore/lite/src/runtime/lite_session.cc @@ -33,10 +33,12 @@ #include "src/common/graph_util.h" #include "src/common/tensor_util.h" #include "src/common/file_utils.h" +#include "src/common/mmap_utils.h" #include "src/runtime/lite_model.h" #include "src/runtime/weight_decoder.h" #include "src/runtime/runtime_allocator.h" #include "src/runtime/kernel_exec_util.h" +#include "src/runtime/runtime_packed_node_pass.h" #ifndef CUSTOM_KERNEL_REGISTRY_CLIP #include "src/registry/register_kernel_impl.h" #endif @@ -561,7 +563,7 @@ int LiteSession::CompileGraph(Model *model) { } InitGraphInputTensors(model); InitGraphOutputTensors(model); - + PackedNodePass::GetInstance().Run(model, tensors_); // scheduler kernels Scheduler scheduler(context_, ms_context_, model, &tensors_, &inputs_, &outputs_, is_train_session_, 
&is_infershape_, &is_control_flow_, execution_plan_, delegate_, delegate_device_type_); @@ -672,6 +674,11 @@ int LiteSession::PrepareKernels(const Model *model) { return RET_ERROR; } for (auto &node : subgraph_kernel->nodes()) { + ret = PackKernelExec(node, tensors_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Pack KernelExec failed."; + return ret; + } ret = node->Prepare(); if (ret != RET_OK) { MS_LOG(ERROR) << "node: " << node->name() << " prepare failed."; @@ -1707,9 +1714,14 @@ const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspor } const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size, - const std::shared_ptr &ms_context) { + const std::shared_ptr &ms_context, bool use_mmap) { size_t buf_size; - auto model_buf = lite::ReadFile(file.c_str(), &buf_size); + char *model_buf; + if (use_mmap) { + model_buf = reinterpret_cast(lite::ReadFileByMmap(file.c_str(), &buf_size)); + } else { + model_buf = lite::ReadFile(file.c_str(), &buf_size); + } if (model_buf == nullptr) { MS_LOG(ERROR) << "The model path is invalid"; return model_buf; @@ -1829,7 +1841,8 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type, const std::shared_ptr &ms_context) { size_t model_size; - auto model_buf = LoadModelByPath(model_path, model_type, &model_size, ms_context); + bool use_mmap = IsMmapEnable(); + auto model_buf = LoadModelByPath(model_path, model_type, &model_size, ms_context, use_mmap); if (model_buf == nullptr) { MS_LOG(ERROR) << "Read model file failed"; return RET_ERROR; @@ -1837,17 +1850,26 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, auto *model = lite::ImportFromBuffer(model_buf, model_size, true, model_type, model_path); if (model == nullptr) { MS_LOG(ERROR) << "Import model failed"; - delete[] model_buf; + if (use_mmap) { + lite::UnmapMmapBuffer(const_cast(static_cast(model_buf)), model_size); + } else { + delete[] model_buf; + } return RET_ERROR; } auto status = lite::PackWeightManager::GetInstance()->InitPackWeightByBuf(model_buf, model_size); MS_CHECK_FALSE_MSG(status != RET_OK, RET_ERROR, "InitPackWeightByBuf failed."); (reinterpret_cast(model))->set_keep_model_buf(true); + reinterpret_cast(model)->model_buf_by_mmap_ = use_mmap; auto ret = CompileGraph(model); if (ret != lite::RET_OK) { MS_LOG(ERROR) << "Compile model failed"; - delete[] model_buf; + if (use_mmap) { + lite::UnmapMmapBuffer(const_cast(static_cast(model_buf)), model_size); + } else { + delete[] model_buf; + } model->buf = nullptr; delete model; return RET_ERROR; @@ -1855,4 +1877,15 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, set_model(model); return RET_OK; } + +bool lite::LiteSession::IsMmapEnable() { +#if !defined(_WIN32) && !defined(_WIN64) + if (delegate_device_type_ == DT_NPU) { + return false; + } + return true; +#else + return false; +#endif +} } // namespace mindspore diff --git a/mindspore/lite/src/runtime/lite_session.h b/mindspore/lite/src/runtime/lite_session.h index d5a672bb..c9edf63e 100644 --- a/mindspore/lite/src/runtime/lite_session.h +++ b/mindspore/lite/src/runtime/lite_session.h @@ -60,7 +60,7 @@ class MS_API LiteSession { const std::shared_ptr &ms_context); static const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size); static const char *LoadModelByPath(const 
std::string &file, mindspore::ModelType model_type, size_t *size, - const std::shared_ptr &ms_context); + const std::shared_ptr &ms_context, bool use_mmap = false); virtual int Init(InnerContext *context); virtual void BindThread(bool if_bind); virtual int CompileGraph(Model *model); @@ -154,6 +154,7 @@ class MS_API LiteSession { static void FreePackOpWeight(const std::vector &kernels); std::string ParseWeightPath(); static void MarkSharedWeight(const std::vector &kernels); + bool IsMmapEnable(); private: int PreCheck(Model *model); diff --git a/mindspore/lite/src/runtime/runtime_packed_node_pass.cc b/mindspore/lite/src/runtime/runtime_packed_node_pass.cc new file mode 100644 index 00000000..81d50522 --- /dev/null +++ b/mindspore/lite/src/runtime/runtime_packed_node_pass.cc @@ -0,0 +1,358 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/runtime_packed_node_pass.h" +#include "nnacl/op_base.h" +#include "nnacl/matmul_parameter.h" + +using RecoveryWeightFunc = void (*)(void *, void *, int, int, bool); +namespace mindspore { +namespace { +constexpr size_t kFlatbuffersBuilderInitSize = 1024; +constexpr auto kActivationType = "activation_type"; +constexpr auto kTransposeA = "transpose_a"; +constexpr auto kTransposeB = "transpose_b"; +constexpr auto kArm64SimdDot = "ARM64SIMD_DOT"; +} // namespace + +namespace lite { +PackedNodePass::~PackedNodePass() { + for (auto &pack_info : node_pack_info_map_) { + delete pack_info.second; + } + node_pack_info_map_.clear(); +} + +void PackedNodePass::Run(Model *model, const std::vector &tensors) { + for (auto &node : model->graph_.all_nodes_) { + MS_ASSERT(node != nullptr); + if (node->node_type_ != schema::PrimitiveType_Custom) { + continue; + } + auto *primitive = reinterpret_cast(node->primitive_); + if (primitive == nullptr) { + MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!"; + return; + } + auto custom = primitive->value_as_Custom(); + if (custom == nullptr || custom->type() == nullptr) { + MS_LOG(ERROR) << "Custom node is nullptr"; + return; + } + auto custom_type = custom->type()->str(); + if (custom_type != "MatmulFusionPacked") { + continue; + } + flatbuffers::FlatBufferBuilder fbb(kFlatbuffersBuilderInitSize); + + auto custom_attr = custom->attr(); + std::map attr_map; + for (size_t i = 0; i < custom_attr->size(); ++i) { + auto attr = custom_attr->Get(i); + auto attr_key = attr->name()->str(); + auto data_bytes = attr->data(); + int data_size = static_cast(data_bytes->size()); + std::string attr_value; + for (int j = 0; j < data_size; j++) { + attr_value.push_back(static_cast(data_bytes->Get(j))); + } + attr_map[attr_key] = attr_value; + } + if (attr_map.find(kActivationType) == attr_map.end() || attr_map.find(kTransposeA) == attr_map.end() || + attr_map.find(kTransposeB) == attr_map.end()) { + MS_LOG(ERROR) << "Custom attr error."; + return; + } + auto val_offset = + schema::CreateMatMulFusion(fbb, std::stoi(attr_map[kTransposeA]), 
std::stoi(attr_map[kTransposeB]), + static_cast(std::stoi(attr_map[kActivationType]))); + auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_MatMulFusion, val_offset.o); + fbb.Finish(prim_offset); + void *prim = malloc(fbb.GetSize()); + if (prim == nullptr) { + MS_LOG(ERROR) << "malloc primitive failed."; + return; + } + memcpy(prim, fbb.GetBufferPointer(), fbb.GetSize()); + auto custom_primitive = flatbuffers::GetRoot(prim); + fbb.Clear(); + PackInfo *pack_info = new (std::nothrow) PackInfo(); + if (pack_info == nullptr) { + free(prim); + MS_LOG(ERROR) << "new PackInfo failed."; + return; + } + node->primitive_ = custom_primitive; + pack_info->is_packed_ = true; + pack_info->b_batch_ = std::stoi(attr_map["b_batch"]); + pack_info->col_ = std::stoi(attr_map["col"]); + pack_info->deep_ = std::stoi(attr_map["deep"]); + pack_info->col_align_ = std::stoi(attr_map["col_align"]); + pack_info->deep_align_ = std::stoi(attr_map["deep_align"]); + pack_info->b_transpose_ = std::stoi(attr_map[kTransposeB]); + pack_info->cpu_option_ = attr_map["cpu_option"]; + AddNodePackInfo(node->name_, pack_info); + if (node->quant_type_ == schema::QuantType_QUANT_DYNAMIC) { + pack_info->weight_sums_index_ = node->input_indices_.back(); + node->input_indices_.pop_back(); + if (!(reinterpret_cast(model)->keep_model_buf())) { + auto index = static_cast(pack_info->weight_sums_index_); + if (index > tensors.size()) { + MS_LOG(ERROR) << "weight sums tensor index is error."; + return; + } + auto tensor = tensors[index]; + CopyWeightBiasSumsTensor(tensor); + } + } + + node->node_type_ = schema::PrimitiveType_MatMulFusion; + } + need_store_weight_ = !(reinterpret_cast(model)->keep_model_buf()); +} + +void PackedNodePass::CopyWeightBiasSumsTensor(Tensor *tensor) { + if (!tensor->IsConst() && tensor->data() != nullptr) { + return; + } + if (!tensor->IsConst() || tensor->own_data()) { + return; + } + if (tensor->data_type() == kObjectTypeTensorType) { + MS_ASSERT(tensor->data() == nullptr); + } else { + auto copy_tensor = Tensor::CopyTensor(*tensor, true); + if (copy_tensor == nullptr) { + MS_LOG(ERROR) << "Copy tensor failed"; + return; + } + tensor->FreeData(); + tensor->set_data(copy_tensor->data()); + tensor->set_own_data(true); + copy_tensor->set_data(nullptr); + delete copy_tensor; + } +} + +int PackedNodePass::StoreWeightTensor(Tensor *tensor, size_t data_size) { + void *weight_data = malloc(data_size); + if (weight_data == nullptr) { + MS_LOG(ERROR) << "malloc weight tensor failed."; + return RET_NULL_PTR; + } + memcpy(weight_data, tensor->data(), data_size); + tensor->FreeData(); + tensor->set_data(weight_data); + tensor->IncRefCount(); + return RET_OK; +} + +void MatmulDynamicSdotInt8Unpack(void *src, void *dst, int row, int col, bool transpose) { + auto src_int8 = static_cast(src); + auto dst_int8 = static_cast(dst); + if (!transpose) { + // RowMajor2Col4x16MajorInt8 + int row_4 = UP_ROUND(row, C4NUM); + int stride = C16NUM * C4NUM; + for (int r = 0; r < row_4; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM; + if (r < row) { + int src_idx = r * col + c; + src_int8[src_idx] = dst_int8[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM]; + } + } + } + } else { + int temp = row; + row = col; + col = temp; + // RowMajor2Row4x16MajorInt8 + int col4 = UP_ROUND(col, C4NUM); + for (int r = 0; r < row; r++) { + int rd16 = r / C16NUM; + int rm16 = r % C16NUM; + for (int c = 0; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + int dst_index 
= rd16 * col4 * C16NUM + cd4 * C16NUM * C4NUM + rm16 * C4NUM + cm4; + int src_index = r * col + c; + src_int8[src_index] = dst_int8[dst_index]; + } + } + } +} + +void MatmulFp32BaseUnpack(void *src, void *dst, int row, int col, bool transpose) { + if (!transpose) { + // RowMajor2Row8MajorParallel + auto src_r = static_cast(src); + auto dst_r = static_cast(dst); + for (int r = 0; r < row; r++) { + float *src_c = src_r + r * col; + int c = 0; + for (; c < col; c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + src_c[c] = dst_r[cd8 * C8NUM * row + r * C8NUM + cm8]; + } + } + return; + } + // RowMajor2Col8MajorParallel + auto src_r = static_cast(src); + auto dst_r = static_cast(dst); + int row8 = row / C8NUM * C8NUM; + int col_skip = col / C4NUM * C4NUM; + int skip_size = C4NUM; + + int ri = 0; + for (; ri < row8; ri += C8NUM) { + int ci = 0; + for (; ci < col_skip; ci += skip_size) { + float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + src_c[tr * col + tc] = dst_c[tc * C8NUM + tr]; + } + } + } + for (; ci < col; ci++) { + float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + for (int i = 0; i < C8NUM; i++) { + src_c[i * col] = dst_c[i]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri < row; ri++, src_r += col, dst_r++) { + for (int i = 0; i < col; i++) { + src_r[i] = dst_r[i * C8NUM]; + } + } +} + +RecoveryWeightFunc GetRecoveryWeightFunc(const int quant_type, const TypeId data_type, const int node_type, + const std::string &cpu_option) { + if (cpu_option == kArm64SimdDot && node_type == schema::PrimitiveType_MatMulFusion && + quant_type == schema::QuantType_QUANT_DYNAMIC && data_type == kNumberTypeInt8) { + return MatmulDynamicSdotInt8Unpack; + } + + if (cpu_option == kArm64SimdDot && node_type == schema::PrimitiveType_MatMulFusion && + data_type == kNumberTypeFloat32) { + return MatmulFp32BaseUnpack; + } + return nullptr; +} + +size_t GetWeightTensorSize(MatmulType matmul_type, const PackInfo &pack_info) { + size_t data_size = 0; + if (matmul_type == kMatmulDynamicSdotInt8Cpu) { + data_size = pack_info.b_batch_ * pack_info.col_align_ * pack_info.deep_align_ * sizeof(int8_t); + } else { + data_size = pack_info.b_batch_ * pack_info.col_align_ * pack_info.deep_ * sizeof(float); + } + return data_size; +} + +int PackedMatmulKernelExec(kernel::KernelExec *kernel_exec, const std::vector &tensors) { + auto pack_info = PackedNodePass::GetInstance().GetNodePackInfo(kernel_exec->name()); + if (pack_info == nullptr) { + return RET_OK; + } + MS_CHECK_TRUE_MSG(kernel_exec->in_tensors().size() >= kInputSize1, lite::RET_ERROR, + "kernel doesn't have weight tensor."); + auto dst_tensor = kernel_exec->in_tensors()[SECOND_INPUT]; + auto kernel = kernel_exec->kernel(); + MS_CHECK_TRUE_MSG(kernel != nullptr, lite::RET_NULL_PTR, "kernel is nullptr."); + auto param = reinterpret_cast(kernel_exec->op_parameter()); + if (dst_tensor->data_type() == kNumberTypeFloat32) { + if (param->matmul_type_ == kNotImplemented) { + return RecoveryPackedWeight(dst_tensor, static_cast(kernel->quant_type()), dst_tensor->data_type(), + schema::PrimitiveType_MatMulFusion, pack_info); + } + } + + if (dst_tensor->data_type() == kNumberTypeInt8 && param->matmul_type_ != kMatmulDynamicSdotInt8Cpu && + pack_info->cpu_option_ == kArm64SimdDot) { + return RecoveryPackedWeight(dst_tensor, static_cast(kernel->quant_type()), dst_tensor->data_type(), + schema::PrimitiveType_MatMulFusion, pack_info); + } + + if 
(PackedNodePass::GetInstance().GetNeedStoreWeight()) { + size_t data_size = GetWeightTensorSize(param->matmul_type_, *pack_info); + int ret = PackedNodePass::GetInstance().StoreWeightTensor(dst_tensor, data_size); + if (ret != RET_OK) { + MS_LOG(ERROR) << "store weight tensor error."; + return ret; + } + } + auto lite_kernel = static_cast(kernel); + lite::Tensor *weight_sums = nullptr; + auto index = static_cast(pack_info->weight_sums_index_); + if (index < tensors.size()) { + weight_sums = tensors.at(index); + } + return lite_kernel->PreparePackedWeight(weight_sums); +} + +int RecoveryPackedWeight(Tensor *weight, const int quant_type, const TypeId data_type, const int node_type, + PackInfo *pack_info) { + auto recovery_func = GetRecoveryWeightFunc(quant_type, data_type, node_type, pack_info->cpu_option_); + if (recovery_func == nullptr) { + MS_LOG(ERROR) << "unsupported recovery func."; + return RET_NULL_PTR; + } + void *unpack_data = malloc(weight->Size()); + if (unpack_data == nullptr) { + MS_LOG(ERROR) << "malloc unpack_data failed."; + return RET_NULL_PTR; + } + void *pack_b_ptr = weight->data(); + for (int i = 0; i < pack_info->b_batch_; i++) { + void *current_weight; + void *current_b_pack; + if (weight->data_type() == kNumberTypeInt8) { + current_weight = static_cast(static_cast(unpack_data) + i * pack_info->deep_ * pack_info->col_); + current_b_pack = + static_cast(static_cast(pack_b_ptr) + i * pack_info->col_align_ * pack_info->deep_align_); + } else if (weight->data_type() == kNumberTypeFloat32) { + current_weight = static_cast(static_cast(unpack_data) + i * pack_info->deep_ * pack_info->col_); + current_b_pack = + static_cast(static_cast(pack_b_ptr) + i * pack_info->col_align_ * pack_info->deep_); + } else { + free(unpack_data); + MS_LOG(ERROR) << "unsupported data type."; + return RET_ERROR; + } + recovery_func(current_weight, current_b_pack, pack_info->deep_, pack_info->col_, pack_info->b_transpose_); + } + weight->FreeData(); + weight->set_data(unpack_data); + return RET_OK; +} + +int PackKernelExec(kernel::KernelExec *kernel_exec, const std::vector &tensors) { + if (kernel_exec->type() == schema::PrimitiveType_MatMulFusion) { + return PackedMatmulKernelExec(kernel_exec, tensors); + } + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/runtime/runtime_packed_node_pass.h b/mindspore/lite/src/runtime/runtime_packed_node_pass.h new file mode 100644 index 00000000..0ba18eb7 --- /dev/null +++ b/mindspore/lite/src/runtime/runtime_packed_node_pass.h @@ -0,0 +1,83 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
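The batch strides that PackedMatmulKernelExec and RecoveryPackedWeight above rely on can be summarized as follows. This is a sketch only, assuming the layouts used by GetWeightTensorSize and the recovery loop; the patch computes these sizes inline rather than through helpers like these.

#include <cstddef>

// Packed int8 weights occupy b_batch x col_align x deep_align elements,
// packed fp32 weights b_batch x col_align x deep elements, and the recovered
// (unpacked) weight is b_batch x deep x col elements in both cases.
static size_t PackedBatchStride(bool is_int8, const mindspore::lite::PackInfo &info) {
  return is_int8 ? static_cast<size_t>(info.col_align_) * info.deep_align_
                 : static_cast<size_t>(info.col_align_) * info.deep_;
}

static size_t RecoveredBatchStride(const mindspore::lite::PackInfo &info) {
  return static_cast<size_t>(info.deep_) * info.col_;
}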
+ */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_PACKED_NODE_PASS_ +#define MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_PACKED_NODE_PASS_ + +#include +#include +#include +#include "src/runtime/lite_model.h" +#include "src/tensor.h" +#include "src/runtime/kernel_exec.h" + +namespace mindspore { +namespace lite { +struct PackInfo { + bool is_packed_{false}; + int weight_sums_index_{-1}; + int b_batch_; + int deep_; + int col_; + int deep_align_; + int col_align_; + bool b_transpose_{false}; + std::string cpu_option_; +}; + +class PackedNodePass { + public: + static PackedNodePass &GetInstance() { + static PackedNodePass instance; + return instance; + } + + PackInfo *GetNodePackInfo(const std::string &node_name) { + if (this->node_pack_info_map_.find(node_name) == this->node_pack_info_map_.end()) { + return nullptr; + } + return this->node_pack_info_map_[node_name]; + } + void Run(Model *model, const std::vector &tensors); + void CopyWeightBiasSumsTensor(Tensor *tensor); + int StoreWeightTensor(Tensor *tensor, size_t data_size); + bool GetNeedStoreWeight() const { return need_store_weight_; } + + protected: + void AddNodePackInfo(const std::string &node_name, PackInfo *pack_info) { + if (this->node_pack_info_map_.find(node_name) != this->node_pack_info_map_.end()) { + MS_LOG(WARNING) << "Key conflict when add weight sums index."; + } + this->node_pack_info_map_[node_name] = pack_info; + } + + private: + PackedNodePass() = default; + ~PackedNodePass(); + + private: + std::map node_pack_info_map_; + bool need_store_weight_{false}; +}; + +int PackKernelExec(kernel::KernelExec *kernel_exec, const std::vector &tensors); + +// packed weight data -> unpack +int RecoveryPackedWeight(Tensor *weight, const int quant_type, const TypeId data_type, const int node_type, + PackInfo *packInfo); +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_RUNTIME_RUNTIME_PACKED_NODE_PASS_ diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc index bf5de821..887a78e3 100644 --- a/mindspore/lite/tools/common/graph_util.cc +++ b/mindspore/lite/tools/common/graph_util.cc @@ -36,7 +36,32 @@ namespace mindspore { namespace lite { namespace { const int kZeroPointGap = 128; +constexpr size_t kTupleGetItemInputSize = 3; +constexpr size_t kSecondIndex = 1; } // namespace +static STATUS GetAbstractfromTupleGetItem(const CNodePtr &cnode, AbstractBasePtr *abstract, size_t *idx) { + MS_CHECK_TRUE_MSG(abstract != nullptr, lite::RET_ERROR, "Abstract is nullptr."); + MS_CHECK_TRUE_MSG(idx != nullptr, lite::RET_ERROR, "idx is nullptr."); + auto tuple_inputs = cnode->inputs(); + MS_CHECK_TRUE_MSG(tuple_inputs.size() == kTupleGetItemInputSize, lite::RET_ERROR, "The node must have 3 inputs!"); + auto get_item_input_cnode = tuple_inputs.at(kSecondIndex); + MS_CHECK_TRUE_MSG(get_item_input_cnode != nullptr, lite::RET_ERROR, "input node is nullptr."); + *idx = opt::GetTupleGetItemOutIndex(cnode); + if (!mindspore::utils::isa(get_item_input_cnode->abstract())) { + MS_LOG(ERROR) << "TupleGetItem's abstract is not AbstractTuple, cnode name: " + << get_item_input_cnode->fullname_with_scope(); + return lite::RET_ERROR; + } + auto abstract_tuple = utils::cast(get_item_input_cnode->abstract()); + auto abstract_list = abstract_tuple->elements(); + if (abstract_list.size() <= *idx) { + MS_LOG(ERROR) << "AbstractTuple's size is smaller than expect"; + return lite::RET_ERROR; + } + *abstract = abstract_list[*idx]; + return lite::RET_OK; +} + int SetFuncGraphOutput(const FuncGraphPtr &graph, 
const std::vector &outputs) { if (graph == nullptr || outputs.empty()) { MS_LOG(DEBUG) << "Input graph is nullptr or outputs is empty"; @@ -483,5 +508,83 @@ int TransferMetaGraph(const schema::MetaGraphT &graph, void **model_buf, size_t (void)memcpy(*model_buf, content, *size); return RET_OK; } + +STATUS GetShapeVectorFromParameter(const mindspore::ParameterPtr ¶m_node, std::vector *shape_vector) { + MS_CHECK_TRUE_MSG(shape_vector != nullptr, RET_ERROR, "shape vector is nullptr."); + auto abstract_base = param_node->abstract(); + if (abstract_base == nullptr) { + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); + return RET_ERROR; + } + + if (!abstract_base->isa()) { + MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name(); + return lite::RET_ERROR; + } + auto abstract_tensor = abstract_base->cast(); + MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!"); + *shape_vector = abstract_tensor->shape()->shape(); + return lite::RET_OK; +} + + +STATUS GetShapeVectorAndIdxFromCNode(const CNodePtr &cnode, std::vector *shape_vector, size_t *idx) { + MS_CHECK_TRUE_MSG(shape_vector != nullptr, lite::RET_ERROR, "shape is nullptr"); + + AbstractBasePtr cnode_abstract = nullptr; + if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) { + // idx is only used when cnode is type of kPrimTupleGetItem. + MS_CHECK_TRUE_MSG(idx != nullptr, lite::RET_ERROR, "idx is nullptr"); + if (GetAbstractfromTupleGetItem(cnode, &cnode_abstract, idx) != lite::RET_OK) { + MS_LOG(ERROR) << "Get abstract from tuple get item failed."; + return lite::RET_ERROR; + } + } else { + cnode_abstract = cnode->abstract(); + } + // the control flow model may be nullptr + if (cnode_abstract == nullptr) { + *shape_vector = std::vector(); + return lite::RET_OK; + } + if (cnode_abstract->BuildShape() == mindspore::abstract::kNoShape) { + *shape_vector = std::vector(); + return lite::RET_OK; + } + if (!utils::isa(cnode_abstract)) { + MS_LOG(ERROR) << "Abstract is not abstract tensor. " << cnode->fullname_with_scope(); + return lite::RET_ERROR; + } + auto cnode_abstract_tensor = cnode_abstract->cast(); + CHECK_NULL_RETURN(cnode_abstract_tensor); + if (!utils::isa(cnode_abstract_tensor->BuildShape())) { + MS_LOG(ERROR) << "Shape of abstract tensor should be ShapePtr. 
" << cnode->fullname_with_scope(); + return lite::RET_ERROR; + } + auto shape_ptr = utils::cast(cnode_abstract_tensor->BuildShape()); + CHECK_NULL_RETURN(shape_ptr); + if (shape_ptr->shape().empty()) { + MS_LOG(WARNING) << "Shape is empty " << cnode->fullname_with_scope(); + } + *shape_vector = shape_ptr->shape(); + return lite::RET_OK; +} + +STATUS GetCNodeOrParameterShapeVec(const AnfNodePtr &anf_node, std::vector *shape) { + auto int64_t_to_int_func = [](int64_t x) -> int { return static_cast(x); }; + std::vector in_shape; + if (anf_node->isa()) { + GetShapeVectorAndIdxFromCNode(anf_node->cast(), &in_shape); + } else if (anf_node->isa()) { + auto param_node = anf_node->cast(); + GetShapeVectorFromParameter(param_node, &in_shape); + } else { + MS_LOG(ERROR) << "Node type is not recognized."; + return RET_ERROR; + } + shape->resize(in_shape.size()); + std::transform(in_shape.begin(), in_shape.end(), shape->begin(), int64_t_to_int_func); + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/common/graph_util.h b/mindspore/lite/tools/common/graph_util.h index 359af757..be239094 100644 --- a/mindspore/lite/tools/common/graph_util.h +++ b/mindspore/lite/tools/common/graph_util.h @@ -89,6 +89,12 @@ STATUS UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr &func_graph); STATUS UpdateGraphOutputName(schema::MetaGraphT *meta_graph); int TransferMetaGraph(const schema::MetaGraphT &graph, void **model_buf, size_t *size); + +STATUS GetShapeVectorAndIdxFromCNode(const CNodePtr &cnode, std::vector *shape_vector, size_t *idx = nullptr); + +STATUS GetShapeVectorFromParameter(const mindspore::ParameterPtr ¶m_node, std::vector *shape_vector); + +STATUS GetCNodeOrParameterShapeVec(const AnfNodePtr &anf_node, std::vector *shape); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 8ce0304e..9031cb96 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -41,6 +41,8 @@ include_directories(${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel) file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/converter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/offline_packing_optimizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/converter_packed_node.cc ${CMAKE_CURRENT_SOURCE_DIR}/anf_transform.cc ${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc ${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc @@ -125,6 +127,7 @@ set(LITE_SRC ${API_SRC} ${SRC_DIR}/common/ops/anf_utils.cc ${SRC_DIR}/common/utils.cc ${SRC_DIR}/common/file_utils.cc + ${SRC_DIR}/common/mmap_utils.cc ${SRC_DIR}/common/context_util.cc ${SRC_DIR}/common/graph_util.cc ${SRC_DIR}/common/string_util.cc @@ -152,6 +155,7 @@ set(LITE_SRC ${API_SRC} ${SRC_DIR}/runtime/sub_graph_kernel.cc ${SRC_DIR}/runtime/sub_graph_split.cc ${SRC_DIR}/runtime/lite_session.cc + ${SRC_DIR}/runtime/runtime_packed_node_pass.cc ${SRC_DIR}/runtime/executor.cc ${SRC_DIR}/runtime/lite_model.cc ${SRC_DIR}/errorcode.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 03cac4c0..a4274202 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -501,6 +501,14 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, return nullptr; } + if (!param->cpuOptionCfgParam.architecture.empty()) { 
+ // Do offline pack. + if (OfflinePackingOptimizer().Optimize(old_graph, "ANDROID_ARM_CPU") != RET_OK) { + MS_LOG(ERROR) << "Do offline packing failed."; + return nullptr; + } + } + return old_graph; } diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h index 42e26310..8d0f2f5d 100644 --- a/mindspore/lite/tools/converter/anf_transform.h +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -27,6 +27,7 @@ #include "ir/anf.h" #include "tools/converter/quantizer/quantizer.h" #include "tools/converter/converter_context.h" +#include "tools/converter/offline_packing_optimizer.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc index 03ca2ec4..595ce604 100644 --- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc +++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc @@ -31,6 +31,7 @@ constexpr auto kRegistry = "registry"; constexpr auto kAclOptionParam = "acl_option_cfg_param"; constexpr auto kMicroParam = "micro_param"; constexpr auto kThirdPartyModelParam = "third_party_model"; +constexpr auto kCpuOptionParam = "cpu_option_cfg_param"; } // namespace int ConfigFileParser::ParseConfigFile(const std::string &config_file_path) { std::map> maps; @@ -101,6 +102,13 @@ int ConfigFileParser::ParseConfigParam(std::maperase(kCpuOptionParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ParseCpuOptionCfgString failed."; + return ret; + } + for (const auto &config_info : *maps) { ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second); } @@ -152,6 +160,7 @@ int ConfigFileParser::ParseCommonQuantString(const std::map> &maps) { + if (maps.find(kCpuOptionParam) != maps.end()) { + const auto &map = maps.at(kCpuOptionParam); + std::map parse_map{{"architecture", cpu_option_cfg_string_.architecture}, + {"instruction", cpu_option_cfg_string_.instruction}}; + return SetMapData(map, parse_map, kCpuOptionParam); + } + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h index c407dcdd..36257b3a 100644 --- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h +++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h @@ -45,6 +45,7 @@ struct CommonQuantString { std::string min_quant_weight_channel; std::string skip_quant_node; std::string debug_info_save_path; + std::string dynamic_quant_strategy; }; struct MixedBitWeightQuantString { @@ -102,6 +103,11 @@ struct ThirdPartyModelString { std::string extended_parameters; // format: {key1:value1;ker2:value2} }; +struct CpuOptionCfgString { + std::string architecture; + std::string instruction; +}; + class ConfigFileParser { public: int ParseConfigFile(const std::string &config_file_path); @@ -115,6 +121,7 @@ class ConfigFileParser { AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; } MicroParamString GetMicroParamString() { return this->micro_param_string_; } lite::ThirdPartyModelString GetThirdPartyModelString() const { return this->third_party_model_string_; } + CpuOptionCfgString GetCpuOptionCfgString() { return this->cpu_option_cfg_string_; } private: int ParseDataPreProcessString(const std::map> &maps); @@ -127,6 +134,7 @@ class ConfigFileParser { const std::map &parse_map, 
const std::string §ion); int ParseMicroParamString(const std::map> &maps); int ParseThirdPartyParamString(const std::map> §ions); + int ParseCpuOptionCfgString(const std::map> §ions); private: DataPreProcessString data_pre_process_string_; @@ -137,6 +145,7 @@ class ConfigFileParser { AclOptionCfgString acl_option_cfg_string_; MicroParamString micro_param_string_; lite::ThirdPartyModelString third_party_model_string_; + CpuOptionCfgString cpu_option_cfg_string_; }; } // namespace lite diff --git a/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.cc b/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.cc new file mode 100644 index 00000000..41528773 --- /dev/null +++ b/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/config_parser/cpu_option_param_parser.h" +#include "common/log.h" +namespace mindspore { +namespace lite { +STATUS CpuOptionParamParser::ParseCpuOptionCfg(const CpuOptionCfgString &cpu_option_string, + CpuOptionCfg *cpu_option_cfg) { + if (cpu_option_string.architecture.empty() || cpu_option_string.instruction.empty()) { + return RET_OK; + } + + if (cpu_option_string.architecture != "ARM64") { + MS_LOG(ERROR) << "cpu instruction only supported ARM64. But get " << cpu_option_string.architecture; + return RET_INPUT_PARAM_INVALID; + } + + if (cpu_option_string.instruction != "SIMD_DOT") { + MS_LOG(ERROR) << "cpu instruction only supported SIMD_DOT. But get " << cpu_option_string.instruction; + return RET_INPUT_PARAM_INVALID; + } + cpu_option_cfg->instruction = cpu_option_string.instruction; + cpu_option_cfg->architecture = cpu_option_string.architecture; + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.h b/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.h new file mode 100644 index 00000000..c549477f --- /dev/null +++ b/mindspore/lite/tools/converter/config_parser/cpu_option_param_parser.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_CPU_OPTION_PARAM_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_CPU_OPTION_PARAM_PARSER_H_ +#include +#include "tools/converter/cxx_api/converter_para.h" +#include "tools/converter/config_parser/config_file_parser.h" +#include "include/errorcode.h" + +namespace mindspore { +namespace lite { +class CpuOptionParamParser { + public: + STATUS ParseCpuOptionCfg(const CpuOptionCfgString &cpu_option_string, CpuOptionCfg *cpu_option_cfg); +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_CPU_OPTION_PARAM_PARSER_H_ diff --git a/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc b/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc index cc9807e4..c0bd6219 100644 --- a/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc +++ b/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc @@ -111,6 +111,11 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str if (!common_quant->debug_info_save_path.empty()) { common_quant->is_debug = true; } + + if (!common_quant_string.dynamic_quant_strategy.empty() && + ParseDynamicQuantStrategy(common_quant_string.dynamic_quant_strategy, &common_quant->dynamic_strategy)) { + return RET_INPUT_PARAM_INVALID; + } return RET_OK; } @@ -210,5 +215,20 @@ int QuantParamParser::ParseActivationQuantizedMethod(const std::string &activati return RET_INPUT_PARAM_INVALID; } } + +int QuantParamParser::ParseDynamicQuantStrategy(const std::string &dynamic_quant_strategy_str, + quant::DynamicQuantStrategy *dynamic_strategy) { + if (dynamic_quant_strategy_str == "ACTIVATION_LAYER") { + (*dynamic_strategy) = quant::ACTIVATION_LAYER; + return RET_OK; + } else if (dynamic_quant_strategy_str == "ACTIVATION_CHANNEL") { + (*dynamic_strategy) = quant::ACTIVATION_CHANNEL; + return RET_OK; + } else { + MS_LOG(ERROR) << "INPUT ILLEGAL: dynamic_quant_strategy must be ACTIVATION_LAYER or ACTIVATION_CHANNEL."; + return RET_INPUT_PARAM_INVALID; + } + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/config_parser/quant_param_parser.h b/mindspore/lite/tools/converter/config_parser/quant_param_parser.h index 4f9e3816..bbf3950c 100644 --- a/mindspore/lite/tools/converter/config_parser/quant_param_parser.h +++ b/mindspore/lite/tools/converter/config_parser/quant_param_parser.h @@ -36,6 +36,7 @@ class QuantParamParser { quant::ActivationQuantizedMethod *activation_quant_method); static int ParseFilter(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant); static int ParseBitNum(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant); + static int ParseDynamicQuantStrategy(const std::string &dynamic_quant_strategy_str, quant::DynamicQuantStrategy *dynamic_strategy); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 6177d379..4ca303b5 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -47,6 +47,8 @@ #include "tools/converter/config_parser/third_party_param_parser.h" #include "tools/common/string_util.h" #include "src/common/file_utils.h" +#include "tools/converter/converter_packed_node.h" +#include "tools/converter/config_parser/cpu_option_param_parser.h" namespace mindspore { extern "C" { @@ -396,6 +398,13 @@ int 
ConverterImpl::InitConfigParam(const std::shared_ptr ¶m) MS_LOG(ERROR) << "Parse micro param failed."; return ret; } + + lite::CpuOptionParamParser cpu_param_parser; + ret = cpu_param_parser.ParseCpuOptionCfg(config_parser.GetCpuOptionCfgString(), ¶m->cpuOptionCfgParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse cpu option param failed."; + return ret; + } return RET_OK; } @@ -795,6 +804,16 @@ int RunConverter(const std::shared_ptr ¶m, void **model_data, status = RET_ERROR; return status; } + + if (!param->cpuOptionCfgParam.architecture.empty()) { + std::string cpu_option = param->cpuOptionCfgParam.architecture + param->cpuOptionCfgParam.instruction; + status = ConverterPackedNode(meta_graph, cpu_option); + if (status != RET_OK) { + MS_LOG(ERROR) << "save pack info failed."; + return status; + } + } + // save graph to file meta_graph->version = Version(); diff --git a/mindspore/lite/tools/converter/converter_packed_node.cc b/mindspore/lite/tools/converter/converter_packed_node.cc new file mode 100644 index 00000000..f632fec3 --- /dev/null +++ b/mindspore/lite/tools/converter/converter_packed_node.cc @@ -0,0 +1,179 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/converter_packed_node.h" +#include "tools/converter/offline_packing_optimizer.h" +#include "src/runtime/kernel/cpu/int8/matmul_dynamic_base_int8.h" +#include "mindspore/core/ops/op_name.h" +#include "src/runtime/kernel/cpu/fp32/matmul_fp32.h" + +namespace mindspore { +namespace { +constexpr auto kMatmulCustomType = "MatmulFusionPacked"; +} + +namespace lite { +void AddCustomAttr(std::vector> *attrs, const std::string &&key, + const std::string &&value) { + auto attr = std::make_unique(); + attr->name = key; + std::vector attr_data(value.begin(), value.end()); + attr->data = attr_data; + attrs->emplace_back(std::move(attr)); +} + +int AddWeightSumsToInputs(const mindspore::kernel::MatmulDynamicBaseInt8CPUKernel *matmul_kernel, + schema::MetaGraphT *meta_graph, const std::unique_ptr &cnode, + size_t weight_sum_size) { + auto weight_sums_tensor = std::make_unique(); + weight_sums_tensor->nodeType = lite::NodeType_ValueNode; + weight_sums_tensor->format = schema::Format_NHWC; + weight_sums_tensor->dataType = TypeId::kNumberTypeInt32; + weight_sums_tensor->dims = {}; + weight_sums_tensor->dims.emplace_back(weight_sum_size / sizeof(int)); + weight_sums_tensor->data.resize(weight_sum_size); + weight_sums_tensor->name = cnode->name + "_weight_sums"; + if (memcpy_s(weight_sums_tensor->data.data(), weight_sums_tensor->data.size(), matmul_kernel->GetWeightSums(), + weight_sum_size) != EOK) { + MS_LOG(ERROR) << "new CustomT error."; + return RET_ERROR; + } + cnode->inputIndex.emplace_back(meta_graph->allTensors.size()); + meta_graph->allTensors.emplace_back(std::move(weight_sums_tensor)); + return RET_OK; +} + +int ReplaceMatMulFusionToCustom(schema::MetaGraphT *meta_graph, const std::unique_ptr &cnode, + const 
std::unique_ptr &b_input, + const std::string &cpu_option) { + auto lite_kernel = PackDataWrapper::GetInstance().GetPackedKernel(cnode->name); + if (lite_kernel == nullptr) { + MS_LOG(ERROR) << "Get Packed Kernel error."; + return RET_ERROR; + } + auto param = lite_kernel->op_parameter(); + if (param == nullptr) { + MS_LOG(ERROR) << "param is nullptr."; + return RET_ERROR; + } + auto matmul_param = reinterpret_cast(param); + if (matmul_param->matmul_type_ == kNotImplemented) { + MS_LOG(WARNING) << "Unsupported matmul type, only support dynamic quant int8."; + return RET_OK; + } + cnode->primitive->value.type = schema::PrimitiveType_Custom; + auto primitive = new (std::nothrow) schema::CustomT; + if (primitive == nullptr) { + MS_LOG(ERROR) << "new CustomT error."; + return RET_NULL_PTR; + } + primitive->type = kMatmulCustomType; + + // activation_type + AddCustomAttr(&(primitive->attr), ops::kActivationType, std::to_string(matmul_param->act_type_)); + // transpose_a + AddCustomAttr(&(primitive->attr), ops::kTransposeA, std::to_string(matmul_param->a_transpose_)); + // transpose_b + AddCustomAttr(&(primitive->attr), ops::kTransposeB, std::to_string(matmul_param->b_transpose_)); + + int b_batch; + const void *pack_b_ptr = nullptr; + size_t pack_b_size; + if (matmul_param->matmul_type_ == kMatmulDynamicSdotInt8Cpu) { + // replace packed data + auto matmul_kernel = reinterpret_cast(lite_kernel); + b_batch = matmul_kernel->GetBBatch(); + pack_b_size = b_batch * matmul_param->col_align_ * matmul_param->deep_align_ * sizeof(int8_t); + pack_b_ptr = reinterpret_cast(matmul_kernel->GetPackBPtr()); + auto weight_sum_size = b_batch * matmul_param->col_align_ * sizeof(int); + int ret = AddWeightSumsToInputs(matmul_kernel, meta_graph, cnode, weight_sum_size); + if (ret != RET_OK) { + delete primitive; + MS_LOG(ERROR) << "add weight sums to inputs error."; + return ret; + } + } else { + MS_LOG(ERROR) << "matmul_type is error."; + return RET_ERROR; + } + + if (pack_b_ptr == nullptr) { + delete primitive; + MS_LOG(ERROR) << "pack_b_ptr is nullptr."; + return RET_NULL_PTR; + } + + // copy packed weight to meta graph + b_input->data.resize(pack_b_size); + if (memcpy_s(b_input->data.data(), b_input->data.size(), pack_b_ptr, pack_b_size) != EOK) { + delete primitive; + MS_LOG(ERROR) << "memcpy packed weight error."; + return RET_ERROR; + } + + // add scalar to attr + AddCustomAttr(&(primitive->attr), "b_batch", std::to_string(b_batch)); + AddCustomAttr(&(primitive->attr), "deep", std::to_string(matmul_param->deep_)); + AddCustomAttr(&(primitive->attr), "col", std::to_string(matmul_param->col_)); + AddCustomAttr(&(primitive->attr), "col_align", std::to_string(matmul_param->col_align_)); + AddCustomAttr(&(primitive->attr), "deep_align", std::to_string(matmul_param->deep_align_)); + + // add cpu option + std::string cpu_option_str = cpu_option; + AddCustomAttr(&(primitive->attr), "cpu_option", std::move(cpu_option_str)); + + cnode->primitive->value.value = primitive; + return RET_OK; +} + +int ConverterPackedNode(schema::MetaGraphT *meta_graph, const std::string &cpu_option) { + for (auto &dst_node : meta_graph->nodes) { + if (dst_node->primitive == nullptr || dst_node->primitive->value.type != schema::PrimitiveType_MatMulFusion) { + continue; + } + MS_CHECK_TRUE_MSG(dst_node->inputIndex.size() >= kInputSize1, RET_ERROR, "inputs size is wrong."); + auto a_index = dst_node->inputIndex[FIRST_INPUT]; + MS_CHECK_TRUE_MSG(meta_graph->allTensors.size() > a_index, RET_ERROR, "allTensors size is wrong."); + auto &a_input 
= meta_graph->allTensors.at(a_index); + CHECK_NULL_RETURN(a_input); + + auto b_index = dst_node->inputIndex[SECOND_INPUT]; + MS_CHECK_TRUE_MSG(meta_graph->allTensors.size() > b_index, RET_ERROR, "allTensors size is wrong."); + auto &b_input = meta_graph->allTensors.at(b_index); + CHECK_NULL_RETURN(b_input); + + if (a_input->dataType != b_input->dataType) { + MS_LOG(ERROR) << "inputs dataType is not same." << a_input->dataType << " " << b_input->dataType; + return RET_ERROR; + } + + if (b_input->data.empty()) { + continue; + } + auto ret = ReplaceMatMulFusionToCustom(meta_graph, dst_node, b_input, cpu_option); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ReplaceMatmulToCustom error."; + return ret; + } + } + + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter_packed_node.h b/mindspore/lite/tools/converter/converter_packed_node.h new file mode 100644 index 00000000..cee891fa --- /dev/null +++ b/mindspore/lite/tools/converter/converter_packed_node.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERT_PACKED_NODE_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERT_PACKED_NODE_H + +#include +#include "schema/inner/model_generated.h" + +namespace mindspore { +namespace lite { +int ConverterPackedNode(schema::MetaGraphT *meta_graph, const std::string &cpu_option); +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_CONVERT_PACKED_NODE_H diff --git a/mindspore/lite/tools/converter/cxx_api/converter_para.h b/mindspore/lite/tools/converter/cxx_api/converter_para.h index 00b7fa3c..1d7d1aec 100644 --- a/mindspore/lite/tools/converter/cxx_api/converter_para.h +++ b/mindspore/lite/tools/converter/cxx_api/converter_para.h @@ -48,6 +48,11 @@ struct ThirdPartyModelParam { std::map> extended_parameters; }; +struct CpuOptionCfg { + std::string architecture; + std::string instruction; +}; + struct ConverterPara { converter::FmkType fmk_type; std::string model_file; @@ -82,6 +87,7 @@ struct ConverterPara { lite::micro::MicroParam microParam; ParallelSplitConfig parallel_split_config; ThirdPartyModelParam thirdPartyModelParam; + CpuOptionCfg cpuOptionCfgParam; }; } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_ diff --git a/mindspore/lite/tools/converter/offline_packing_optimizer.cc b/mindspore/lite/tools/converter/offline_packing_optimizer.cc new file mode 100644 index 00000000..d9a62c15 --- /dev/null +++ b/mindspore/lite/tools/converter/offline_packing_optimizer.cc @@ -0,0 +1,307 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
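To make the converter/runtime hand-off above easier to follow: ConverterPackedNode stores every scalar of the packed matmul (activation_type, transpose_a, transpose_b, b_batch, deep, col, col_align, deep_align, cpu_option) as a string attribute on the Custom node and appends the weight-sums tensor as an extra input, while PackedNodePass in the runtime converts the strings back with std::stoi and pops that extra input off again. The snippet below only sketches that round trip; the helper names are not part of the patch.

#include <cstdint>
#include <string>
#include <vector>

// Converter side: scalar -> decimal string attribute, weight sums appended as the last input.
static std::string EncodePackedAttr(int value) { return std::to_string(value); }
static void AppendWeightSumsInput(std::vector<uint32_t> *input_index, uint32_t sums_tensor_index) {
  input_index->push_back(sums_tensor_index);
}

// Runtime side: decimal string -> scalar, weight-sums index popped back off the inputs.
static int DecodePackedAttr(const std::string &attr_bytes) { return std::stoi(attr_bytes); }
static uint32_t PopWeightSumsInput(std::vector<uint32_t> *input_index) {
  uint32_t sums_tensor_index = input_index->back();
  input_index->pop_back();
  return sums_tensor_index;
}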
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "tools/common/graph_util.h" +#include "tools/converter/offline_packing_optimizer.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "src/common/primitive_t_utils.h" +#include "src/common/ops/anf_utils.h" +#include "src/common/file_utils.h" +#include "nnacl/matmul_parameter.h" +#include "src/runtime//kernel/cpu/int8/matmul_dynamic_base_int8.h" +#include "tools/optimizer/common/gllo_utils.h" + +using mindspore::kernel::MatmulDynamicBaseInt8CPUKernel; + +namespace mindspore::lite { +namespace { +constexpr const int kPrimIndex = 0; +constexpr const int kSingleThread = 1; +const char kAndroidArmCpuBackendOption[] = "ANDROID_ARM_CPU"; +} // namespace + +mindspore::lite::InnerContext *InitInnerContextForAndroidArmCpu() { + // if the operation use thread_pool in inner context will throw exception. + auto inner_context = new (std::nothrow) lite::InnerContext(); + inner_context->Init(); + MS_CHECK_TRUE_MSG(inner_context != nullptr, nullptr, "Create InnerContext failed."); + inner_context->thread_num_ = kSingleThread; + inner_context->instructions_ctx_.support_sdot = true; + return inner_context; +} + +schema::PrimitiveType GetSchemaPrimitiveType(const AnfNodePtr &node) { + auto primitive_t = GetPrimitiveT(node); + if (primitive_t == nullptr) { + MS_LOG(ERROR) << "Failed to generate PrimitiveT."; + return schema::PrimitiveType::PrimitiveType_NONE; + } + return GetSchemaPrimType(primitive_t.get()); +} + +STATUS CreateMatmulPackDataIntoTable(const std::vector &in_tensors, const std::vector &out_tensors, + OpParameter *op_parameter, const kernel::KernelKey &desc, + const mindspore::lite::InnerContext *ctx) { + if (!KernelRegistry::GetInstance()->SupportKernel(desc)) { + MS_LOG(ERROR) << op_parameter->name_ << " is not supported."; + return RET_ERROR; + } + + kernel::LiteKernel *kernel = + KernelRegistry::GetInstance()->GetLiteKernel(in_tensors, out_tensors, ctx, desc, op_parameter); + if (kernel == nullptr) { + MS_LOG(ERROR) << "Anf node cannot be nullptr."; + return RET_ERROR; + } + kernel->set_name(op_parameter->name_); + + if (kernel->Prepare() != RET_OK) { + MS_LOG(ERROR) << "Failed to generate pack data for " << op_parameter->name_ << "."; + return RET_ERROR; + } + + PackDataWrapper::GetInstance().AddPackedKernel(op_parameter->name_, kernel); + return RET_OK; +} + +schema::QuantType GetQuantType(const CNodePtr &cnode) { + MS_CHECK_TRUE_MSG(cnode != nullptr, schema::QuantType::QuantType_QUANT_NONE, "cnode cannot be nullptr."); + auto primitive = GetValueNode(cnode->input(0)); + if (primitive == nullptr) { + MS_LOG(INFO) << "primitive is nullptr"; + return schema::QuantType::QuantType_QUANT_NONE; + } + auto quant_param_holder = quant::GetCNodeQuantHolder(primitive); + if (quant_param_holder != nullptr) { + return quant_param_holder->quant_type(); + } + return schema::QuantType::QuantType_QUANT_NONE; +} + +TypeId GetDataType(const CNodePtr &cnode, const std::vector &in_tensors, + const std::vector &out_tensors) { + if (in_tensors.empty()) { + MS_LOG(ERROR) << "in tensor is empty."; + return 
kTypeUnknown; + } + + // Currently, fp16 is not a supported option. + TypeId data_type = + in_tensors[0]->data_type() == kObjectTypeTensorType ? kNumberTypeFloat32 : in_tensors[0]->data_type(); + // How to judge quant type? + auto quant_type = GetQuantType(cnode); + if (quant_type == schema::QuantType_QUANT_WEIGHT) { + data_type = + in_tensors.front()->data_type() == kNumberTypeBool ? TypeId::kNumberTypeBool : TypeId::kNumberTypeFloat32; + } + return data_type; +} + +void QuantParamTToQuantParam(const schema::QuantParamT &quant_param_t, lite::LiteQuantParam *quant_param) { + quant_param->inited = true; + quant_param->bitNum = quant_param_t.numBits; + quant_param->scale = quant_param_t.scale; + quant_param->zeroPoint = quant_param_t.zeroPoint; + quant_param->var_corr = quant_param_t.varCorr; + quant_param->mean_corr = quant_param_t.meanCorr; + quant_param->roundType = quant_param_t.roundType; + quant_param->multiplier = quant_param_t.multiplier; + quant_param->dstDtype = quant_param_t.dstDtype; + quant_param->min = quant_param_t.min; + quant_param->max = quant_param_t.max; +} + +void AddQuantParams(Tensor *in_tensor, const std::vector &quant_param_t) { + std::vector lite_quant_params(quant_param_t.size()); + for (size_t i = 0; i < lite_quant_params.size(); i++) { + QuantParamTToQuantParam(quant_param_t[i], &lite_quant_params[i]); + } + in_tensor->set_quant_params(lite_quant_params); +} + +STATUS CreateLiteTensor(const CNodePtr &cnode, std::vector *in_tensors, std::vector *out_tensors) { + std::vector shape(0); + mindspore::TypeId type_id = TypeId::kTypeUnknown; + auto primitive = GetValueNode(cnode->input(0)); + if (primitive == nullptr) { + MS_LOG(INFO) << "primitive is nullptr"; + return RET_ERROR; + } + auto quant_param_holder = quant::GetCNodeQuantHolder(primitive); + std::vector> input_quant_params_vec; + std::vector> output_quant_params_vec; + if (quant_param_holder != nullptr) { + input_quant_params_vec = quant_param_holder->get_input_quant_params(); + output_quant_params_vec = quant_param_holder->get_output_quant_params(); + } + + // Generate input tensor. + for (size_t i = kPrimIndex + 1; i < cnode->inputs().size(); i++) { + if (opt::GetDataTypeFromAnfNode(cnode->input(i), &type_id) != RET_OK) { + MS_LOG(ERROR) << "Cannot get data type from " << cnode->input(i)->fullname_with_scope(); + return RET_ERROR; + } + void *tensor_data = nullptr; + Category category = cnode->input(i)->isa() ? lite::Category::CONST_TENSOR : lite::Category::VAR; + + MS_CHECK_TRUE_MSG(GetCNodeOrParameterShapeVec(cnode->input(i), &shape) == RET_OK, RET_ERROR, + "Infer shape must be done when using offline packing."); + MS_CHECK_TRUE_MSG(!shape.empty(), RET_ERROR, "Infer shape must be done when using offline packing."); + // Get tensor data from parameter node. + if (cnode->input(i)->isa() && cnode->input(i)->cast()->has_default()) { + auto param_node = cnode->input(i)->cast(); + if (param_node->has_default()) { + auto tensor_info = std::static_pointer_cast(param_node->default_param()); + tensor_data = tensor_info->data().data(); + } + } + auto in_tensor = new (std::nothrow) Tensor(type_id, shape); + MS_CHECK_TRUE_MSG(in_tensor != nullptr, RET_ERROR, "Create input tensor failed."); + in_tensor->set_category(category); + // Tensor data is managed by funcGraph. + in_tensor->set_data(tensor_data); + in_tensor->set_own_data(false); + // Setup quant params. 
+    if (type_id == TypeId::kNumberTypeInt8 && !input_quant_params_vec.empty()) {
+      AddQuantParams(in_tensor, input_quant_params_vec.front());
+      input_quant_params_vec.erase(input_quant_params_vec.begin());
+    }
+    in_tensors->emplace_back(in_tensor);
+    shape.clear();
+    type_id = TypeId::kTypeUnknown;
+  }
+
+  if (!input_quant_params_vec.empty()) {
+    MS_LOG(WARNING) << cnode->fullname_with_scope() << " quant params' count is not equal to inputs' size";
+  }
+
+  // Generate the output tensor.
+  MS_CHECK_TRUE_MSG(GetCNodeOrParameterShapeVec(cnode, &shape) == RET_OK, RET_ERROR,
+                    "Infer shape must be done when using offline packing.");
+  MS_CHECK_TRUE_MSG(!shape.empty(), RET_ERROR, "Infer shape must be done when using offline packing.");
+  if (opt::GetDataTypeFromAnfNode(cnode, &type_id) != RET_OK) {
+    MS_LOG(ERROR) << "Cannot get data type from " + cnode->fullname_with_scope() + ".";
+    return RET_ERROR;
+  }
+  auto out_tensor = new (std::nothrow) Tensor(type_id, shape);
+  MS_CHECK_TRUE_MSG(out_tensor != nullptr, RET_ERROR, "Create output tensor failed.");
+  if (type_id == TypeId::kNumberTypeInt8 && !output_quant_params_vec.empty()) {
+    AddQuantParams(out_tensor, output_quant_params_vec.front());
+    output_quant_params_vec.erase(output_quant_params_vec.begin());
+  }
+  out_tensors->emplace_back(out_tensor);
+
+  if (in_tensors->size() != cnode->inputs().size() - 1 || out_tensors->empty()) {
+    MS_LOG(ERROR) << "Failed to populate input/output tensors for " << cnode->fullname_with_scope() << ".";
+    return RET_ERROR;
+  }
+
+  return RET_OK;
+}
+
+STATUS MatmulPacking(const mindspore::CNodePtr &cnode_ptr, const FuncGraphPtr &funcGraphPtr,
+                     const lite::InnerContext *ctx) {
+  if (cnode_ptr == nullptr) {
+    MS_LOG(ERROR) << "Matmul node cannot be nullptr.";
+    return RET_ERROR;
+  }
+  auto primT = mindspore::lite::GetPrimitiveT(cnode_ptr->input(kPrimIndex));
+  if (primT == nullptr) {
+    MS_LOG(ERROR) << "Failed to generate PrimitiveT for " << cnode_ptr->fullname_with_scope() << ".";
+    return RET_ERROR;
+  }
+  OpParameter *op_parameter = GetOpParameter(primT.get());
+  if (op_parameter == nullptr) {
+    MS_LOG(ERROR) << "Failed to generate op parameter for " << cnode_ptr->fullname_with_scope() << ".";
+    return RET_ERROR;
+  }
+  op_parameter->thread_num_ = kSingleThread;
+  op_parameter->quant_type_ = GetQuantType(cnode_ptr);
+
+  constexpr size_t max_name_len = 100;
+  if (memcpy_s(op_parameter->name_, max_name_len, cnode_ptr->fullname_with_scope().c_str(),
+               cnode_ptr->fullname_with_scope().length()) != EOK) {
+    MS_LOG(ERROR) << "Set op parameter name failed.";
+    return RET_ERROR;
+  }
+
+  std::vector<Tensor *> in_tensors;
+  std::vector<Tensor *> out_tensors;
+  if (CreateLiteTensor(cnode_ptr, &in_tensors, &out_tensors) != RET_OK) {
+    MS_LOG(ERROR) << "Failed to populate input tensors for " << cnode_ptr->fullname_with_scope() << ".";
+    return RET_ERROR;
+  }
+
+  TypeId data_type = GetDataType(cnode_ptr, in_tensors, out_tensors);
+  MS_CHECK_TRUE_MSG(data_type != TypeId::kTypeUnknown, RET_ERROR,
+                    "Can't get data type from " + cnode_ptr->fullname_with_scope() + ".");
+  kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, NHWC, op_parameter->type_};
+
+  return CreateMatmulPackDataIntoTable(in_tensors, out_tensors, op_parameter, desc, ctx);
+}
+
+BackendType FindBackend(const std::string &target_backend) {
+  if (target_backend == std::string(kAndroidArmCpuBackendOption)) {
+    return BackendType::kAndroidArmCpuBackend;
+  }
+  return BackendType::kUnknownBackend;
+}
+
+STATUS OfflinePackingOptimizer::Optimize(const FuncGraphPtr &func_graph, const std::string &target_backend) {
+  BackendType backend = FindBackend(target_backend);
+  if (backend == BackendType::kUnknownBackend ||
+      this->packing_strategies_selector_.find(backend) == this->packing_strategies_selector_.end() ||
+      this->ctx_creator_selector_.find(backend) == this->ctx_creator_selector_.end()) {
+    MS_LOG(ERROR) << target_backend << " is not supported to do offline packing.";
+    return RET_ERROR;
+  }
+
+  // Get the built-in backend optimizer.
+  std::map<schema::PrimitiveType, OfflinePackingFunc> selected_backend_op_cvt =
+    this->packing_strategies_selector_[backend];
+  mindspore::lite::InnerContext *inner_context = this->ctx_creator_selector_[backend]();
+  MS_CHECK_TRUE_MSG(inner_context != nullptr, RET_ERROR, "Failed to initialize runtime context.");
+
+  auto anf_nodes = mindspore::TopoSort(func_graph->get_return());
+  for (auto &anf_node : anf_nodes) {
+    if (!utils::isa<CNodePtr>(anf_node)) {
+      continue;
+    }
+    if (mindspore::opt::CheckPrimitiveType(anf_node, prim::kPrimReturn) ||
+        mindspore::opt::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple) ||
+        mindspore::opt::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
+      continue;
+    }
+    auto cnode = anf_node->cast<CNodePtr>();
+    schema::PrimitiveType op_type = GetSchemaPrimitiveType(cnode->input(kPrimIndex));
+    if (selected_backend_op_cvt.find(op_type) != selected_backend_op_cvt.end()) {
+      OfflinePackingFunc packing_func = selected_backend_op_cvt[op_type];
+      if (packing_func(cnode, func_graph, inner_context) != RET_OK) {
+        MS_LOG(ERROR) << "Failed to pack for " << anf_node->fullname_with_scope();
+        delete inner_context;
+        return RET_ERROR;
+      }
+    }
+  }
+  delete inner_context;
+  return RET_OK;
+}
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/tools/converter/offline_packing_optimizer.h b/mindspore/lite/tools/converter/offline_packing_optimizer.h
new file mode 100644
index 00000000..2590f542
--- /dev/null
+++ b/mindspore/lite/tools/converter/offline_packing_optimizer.h
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LITE_OFFLINE_PACKING_OPTIMIZER_H
+#define LITE_OFFLINE_PACKING_OPTIMIZER_H
+#include
+#include
+#include "base/base.h"
+#include "ir/anf.h"
+#include "ops/core_ops.h"
+#include "runtime/lite_kernel.h"
+#include "runtime/kernel_registry.h"
+
+namespace mindspore::lite {
+using OfflinePackingFunc = STATUS (*)(const mindspore::CNodePtr &cnode_ptr, const FuncGraphPtr &funcGraphPtr,
+                                      const lite::InnerContext *ctx);
+using InnerContextCreatorFunc = mindspore::lite::InnerContext *(*)();
+
+STATUS MatmulPacking(const mindspore::CNodePtr &cnode_ptr, const FuncGraphPtr &funcGraphPtr,
+                     const lite::InnerContext *ctx);
+mindspore::lite::InnerContext *InitInnerContextForAndroidArmCpu();
+
+enum class BackendType : uint8_t {
+  kUnknownBackend = 0,
+  kAndroidArmCpuBackend,
+};
+
+class PackDataWrapper {
+ public:
+  static PackDataWrapper &GetInstance() {
+    static PackDataWrapper instance;
+    return instance;
+  }
+
+  const kernel::LiteKernel *GetPackedKernel(const std::string &node_name) {
+    if (this->pack_mapping_.find(node_name) == this->pack_mapping_.end()) {
+      return nullptr;
+    }
+    return this->pack_mapping_[node_name];
+  }
+
+  void AddPackedKernel(const std::string &node_name, const kernel::LiteKernel *data) {
+    if (this->pack_mapping_.find(node_name) != this->pack_mapping_.end()) {
+      MS_LOG(WARNING) << "Key conflict when adding a packed kernel.";
+    }
+    this->pack_mapping_[node_name] = data;
+  }
+
+ private:
+  PackDataWrapper() = default;
+  ~PackDataWrapper() = default;
+
+ private:
+  std::map<std::string, const kernel::LiteKernel *> pack_mapping_;
+};
+
+class OfflinePackingOptimizer {
+ public:
+  OfflinePackingOptimizer() {
+    this->packing_strategies_selector_[BackendType::kAndroidArmCpuBackend] =
+      std::map<schema::PrimitiveType, OfflinePackingFunc>{
+        {schema::PrimitiveType::PrimitiveType_MatMulFusion, MatmulPacking},
+      };
+    this->ctx_creator_selector_[BackendType::kAndroidArmCpuBackend] = InitInnerContextForAndroidArmCpu;
+  }
+
+  STATUS Optimize(const FuncGraphPtr &func_graph, const std::string &target_backend);
+
+ private:
+  std::map<BackendType, std::map<schema::PrimitiveType, OfflinePackingFunc>> packing_strategies_selector_;
+  std::map<BackendType, InnerContextCreatorFunc> ctx_creator_selector_;
+};
+}  // namespace mindspore::lite
+#endif  // LITE_OFFLINE_PACKING_OPTIMIZER_H
diff --git a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc
index 51d3d992..96eec450 100644
--- a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.cc
@@ -27,7 +27,15 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   auto quantizer = WeightQuantizer(param_);
   const std::set<PrimitivePtr> support_weight_quant_nodes = {prim::kPrimMatMulFusion, prim::kPrimGather};
   const std::set<PrimitivePtr> symmetric_nodes = {prim::kPrimMatMulFusion};
-  auto ret = quantizer.WeightQuant(func_graph, support_weight_quant_nodes, {}, symmetric_nodes);
+  int ret;
+  // When the activation is quantized per channel, quantize the weight per layer.
+  if (activation_perchannel_) {
+    const std::set<PrimitivePtr> support_per_layers_nodes = {prim::kPrimMatMulFusion};
+    ret =
+      quantizer.WeightQuant(func_graph, support_weight_quant_nodes, support_per_layers_nodes, symmetric_nodes);
+  } else {
+    ret = quantizer.WeightQuant(func_graph, support_weight_quant_nodes, {}, symmetric_nodes);
+  }
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Weight Quant failed.";
     return ret;
@@ -36,7 +44,8 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   const std::set<PrimitivePtr> support_dynamic_quant_ops = {
     prim::kPrimMatMulFusion,
   };
-  ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops, param_->commonQuantParam.skip_quant_node);
+  ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops, param_->commonQuantParam.skip_quant_node,
+                                       activation_perchannel_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Insert dynamic quant failed.";
     return ret;
diff --git a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h
index 8a172e7b..00ed204b 100644
--- a/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h
+++ b/mindspore/lite/tools/converter/quantizer/dynamic_quantizer.h
@@ -43,6 +43,7 @@ class DynamicQuantizer : public Quantizer {
  public:
   explicit DynamicQuantizer(const std::shared_ptr<ConverterPara> &param) : Quantizer(param) {
     bit_num_ = param->commonQuantParam.bit_num;
+    activation_perchannel_ = (param->commonQuantParam.dynamic_strategy == quant::ACTIVATION_CHANNEL);
   }
 
   ~DynamicQuantizer() = default;
@@ -53,6 +54,7 @@ class DynamicQuantizer : public Quantizer {
   int quant_max_{127};
   int quant_min_{-128};
   TypeId type_id_{kNumberTypeInt8};
+  bool activation_perchannel_ = false;
 };
 }  // namespace mindspore::lite::quant
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
diff --git a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc
index 2f42240f..c528ffbd 100644
--- a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc
+++ b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc
@@ -24,11 +24,15 @@
 #include "tools/optimizer/common/gllo_utils.h"
 #include "tools/optimizer/common/format_utils.h"
 #include "tools/common/node_util.h"
+#include "ops/op_name.h"
+#include "ops/fusion/mat_mul_fusion.h"
 
 namespace mindspore::lite::quant {
 namespace {
 constexpr size_t kMinSize3 = 3;
 constexpr size_t kPrimitiveCOffset = 1;
+constexpr int kLastFirstIndex = -1;
+constexpr int kLastSecondIndex = -2;
 }  // namespace
 ValueNodePtr InsertQuantNodeManager::NewQuantCastValueNode(int src_type, int dst_type,
                                                            const std::vector<schema::QuantParamT> &quant_params) {
@@ -166,11 +170,17 @@ int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph)
 }
 
 int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode,
-                                                        size_t index) {
+                                                        size_t index, bool activation_perchannel) {
   auto primitive = std::make_shared<ops::DynamicQuant>();
   auto primitive_c = primitive->GetPrim();
   primitive->set_dst_type(dst_type_);
-  primitive->set_symmetric(symmetric_);
+  bool symmetric = activation_perchannel;
+  primitive->set_symmetric(symmetric);
+  primitive->set_activation_perchannel(activation_perchannel);
+  if (activation_perchannel && SetPreferAxis(cnode, index, primitive) != RET_OK) {
+    MS_LOG(ERROR) << "Set prefer axis failed, " << cnode->fullname_with_scope();
+    return RET_ERROR;
+  }
   auto dynamic_quant_cnode = graph->NewCNode(primitive_c, {cnode->input(index)});
   auto name = cnode->fullname_with_scope() + "_dynamic_cast_node_" + std::to_string(index);
   dynamic_quant_cnode->set_fullname_with_scope(name);
@@ -181,7 +191,8 @@ int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &grap
     return RET_NULL_PTR;
   }
   dynamic_quant_cnode->set_abstract(abstract);
-  auto ret = UpdateDataType(cnode, dst_type_);
+  abstract->set_shape(cnode->input(index)->Shape());
+  auto ret = UpdateDataType(dynamic_quant_cnode, dst_type_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << cnode->fullname_with_scope() << " set new dtype failed.";
     return ret;
@@ -191,7 +202,39 @@ int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &grap
   return RET_OK;
 }
 
-int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
+int InsertQuantNodeManager::SetPreferAxis(const CNodePtr &cnode, size_t index,
+                                          const std::shared_ptr<ops::DynamicQuant> &dynamic_primitive) {
+  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
+  if (primitive->name() == ops::kNameMatMulFusion || primitive->name() == ops::kNameMatMul) {
+    auto matmul_prim = api::MakeShared<ops::MatMulFusion>(primitive);
+    CHECK_NULL_RETURN(matmul_prim);
+    // For MatMul input A
+    if (index == kInputIndex + kPrimOffset) {
+      if (matmul_prim->GetAttr(ops::kTransposeA) != nullptr && matmul_prim->get_transpose_a()) {
+        dynamic_primitive->set_prefer_axis(kLastFirstIndex);
+        dynamic_primitive->set_transpose(true);
+      } else {
+        dynamic_primitive->set_prefer_axis(kLastSecondIndex);
+        dynamic_primitive->set_transpose(false);
+      }
+    }
+    // For MatMul input B
+    if (index == kWeightIndex + kPrimOffset) {
+      if (matmul_prim->GetAttr(ops::kTransposeB) != nullptr && matmul_prim->get_transpose_b()) {
+        dynamic_primitive->set_prefer_axis(kLastSecondIndex);
+        dynamic_primitive->set_transpose(true);
+      } else {
+        dynamic_primitive->set_prefer_axis(kLastFirstIndex);
+        dynamic_primitive->set_transpose(false);
+      }
+    }
+  } else {
+    MS_LOG(WARNING) << "cnode doesn't need a prefer axis, cnode name: " << cnode->fullname_with_scope();
+  }
+  return RET_OK;
+}
+
+int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode, bool activation_perchannel) {
   auto op_name = cnode->fullname_with_scope();
   if (cnode->size() < kMinSize3) {
     MS_LOG(ERROR) << op_name << " cnode size:" << cnode->size() << " < 3.";
@@ -199,11 +242,11 @@ int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const
   }
   auto input = cnode->input(kInputIndex + kPrimitiveCOffset);
   if (input->isa<CNode>() || IsGraphInput(input)) {
-    InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimitiveCOffset);
+    InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimitiveCOffset, activation_perchannel);
   }
   auto weight = cnode->input(kWeightIndex + kPrimitiveCOffset);
   if (weight->isa<CNode>() || IsGraphInput(weight)) {
-    InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimitiveCOffset);
+    InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimitiveCOffset, activation_perchannel);
   }
   return RET_OK;
 }
@@ -222,7 +265,8 @@ int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
 
 int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
                                                    const std::set<PrimitivePtr> &support_dynamic_quant_ops,
-                                                   const std::set<std::string> &skip_quant_node) {
+                                                   const std::set<std::string> &skip_quant_node,
+                                                   bool activation_perchannel) {
   MS_ASSERT(graph != nullptr);
   auto cnodes = graph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
@@ -244,7 +288,7 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
       MS_LOG(INFO) << "node:" << op_name << " type:" << type << " will not quantify.";
       continue;
     }
-    ret = NewDynamicQuantNode(graph, cnode);
+    ret = NewDynamicQuantNode(graph, cnode, activation_perchannel);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "node:" << op_name << " new dynamic quant node failed.";
       return ret;
diff --git a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h
index 3555a42c..7c0410dd 100644
--- a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h
+++ b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h
@@ -36,7 +36,7 @@ class InsertQuantNodeManager {
   int InsertQuantDtypeCastNode(const FuncGraphPtr &graph);
 
   int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops,
-                             const std::set<std::string> &skip_quant_node);
+                             const std::set<std::string> &skip_quant_node, bool activation_perchannel = false);
 
  private:
   ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params);
@@ -45,15 +45,16 @@ class InsertQuantNodeManager {
 
   int CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id) const;
 
-  int NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode);
+  int NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode, bool activation_perchannel = false);
 
   int MarkDynamicQuantize(const CNodePtr &cnode);
 
-  int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index);
+  int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index,
+                                  bool activation_perchannel = false);
+  int SetPreferAxis(const CNodePtr &cnode, size_t index, const std::shared_ptr<ops::DynamicQuant> &dynamic_primitive);
 
  private:
   TypeId dst_type_ = kNumberTypeInt8;
-  bool symmetric_ = false;
 };
 }  // namespace mindspore::lite::quant
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H_
diff --git a/mindspore/lite/tools/converter/quantizer/quant_params.h b/mindspore/lite/tools/converter/quantizer/quant_params.h
index d7656802..e08b70cb 100644
--- a/mindspore/lite/tools/converter/quantizer/quant_params.h
+++ b/mindspore/lite/tools/converter/quantizer/quant_params.h
@@ -22,6 +22,7 @@
 #include
 #include "schema/inner/model_generated.h"
 namespace mindspore::lite::quant {
+constexpr int kPrimOffset = 1;
 enum ActivationQuantizedMethod {
   MAX_MIN = 0,
   KL = 1,
@@ -40,6 +41,11 @@ enum DebugMode {
   DETAIL,
 };
 
+enum DynamicQuantStrategy {
+  ACTIVATION_LAYER,
+  ACTIVATION_CHANNEL,
+};
+
 struct CommonQuantParam {
   schema::QuantType quant_type = schema::QuantType_QUANT_NONE;
   int bit_num = 8;
@@ -50,6 +56,7 @@ struct CommonQuantParam {
   DebugMode debug_mode = DETAIL;
   std::set<std::string> skip_quant_node;
   int thread_num = 4;
+  DynamicQuantStrategy dynamic_strategy = quant::ACTIVATION_LAYER;
 };
 
 struct MixedBitWeightQuantParam {
-- 
2.17.1
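
A minimal sketch of how the offline packing entry points added above are expected to be driven; the wrapper function, its call site, and the node name are assumptions for illustration and are not part of this patch:

// Hypothetical driver (not from this patch): packs supported MatMulFusion nodes for the
// ANDROID_ARM_CPU backend, then looks a packed kernel up by op name.
#include <string>
#include "include/errorcode.h"
#include "tools/converter/offline_packing_optimizer.h"

namespace mindspore::lite {
STATUS RunOfflinePacking(const FuncGraphPtr &func_graph, const std::string &matmul_node_name) {
  // Shapes must already be inferred on func_graph, otherwise CreateLiteTensor() reports an error.
  OfflinePackingOptimizer optimizer;
  if (optimizer.Optimize(func_graph, "ANDROID_ARM_CPU") != RET_OK) {
    return RET_ERROR;
  }
  // Packed kernels are cached per op name and can be fetched later, e.g. when serializing packed weights.
  const kernel::LiteKernel *packed = PackDataWrapper::GetInstance().GetPackedKernel(matmul_node_name);
  return packed == nullptr ? RET_ERROR : RET_OK;
}
}  // namespace mindspore::lite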