From 46c42f4387b8280a25b16592aa25179093d99fa2 Mon Sep 17 00:00:00 2001
From: z00574805 <z00574805@notesmail.huawei.com>
Date: Wed, 24 May 2023 11:10:00 +0800
Subject: [PATCH 3/5] xiaoyi-0003

---
 .jenkins/check/config/filter_cppcheck.txt | 1 +
 cmake/package_micro.cmake | 4 +-
 include/api/serialization.h | 26 +
 .../nnacl/base/minimal_filtering_generator.c | 2 +-
 .../nnacl/base/minimal_filtering_generator.h | 1 -
 .../cpu/kernel/nnacl/custom_gru_parameter.h | 31 +
 .../cpu/kernel/nnacl/fp16/custom_gru_fp16.c | 70 ++
 .../cpu/kernel/nnacl/fp16/custom_gru_fp16.h | 32 +
 .../cpu/kernel/nnacl/fp32/custom_gru_fp32.c | 71 ++
 .../cpu/kernel/nnacl/fp32/custom_gru_fp32.h | 32 +
 .../device/cpu/kernel/nnacl/fp32/lstm_fp32.h | 2 +-
 .../cpu/kernel/nnacl/infer/custom_gru_infer.c | 45 +
 .../cpu/kernel/nnacl/infer/custom_gru_infer.h | 30 +
 .../plugin/device/cpu/kernel/nnacl/op_base.h | 1 +
 .../lite/include/registry/converter_context.h | 3 +-
 mindspore/lite/src/CMakeLists.txt | 1 +
 mindspore/lite/src/common/graph_util.cc | 7 +-
 .../common/ops/populate/custom_populate.cc | 14 +
 .../lite/src/runtime/cxx_api/serialization.cc | 31 +
 .../kernel/cpu/base/group_convolution_base.cc | 34 +-
 .../cpu/base/group_convolution_creator.cc | 24 +-
 .../cpu/base/group_convolution_creator.h | 8 +-
 .../cpu/fp16/convolution_delegate_fp16.cc | 2 +-
 .../kernel/cpu/fp16/custom_gru_fp16.cc | 132 +++
 .../runtime/kernel/cpu/fp16/custom_gru_fp16.h | 40 +
 .../cpu/fp32/convolution_delegate_fp32.cc | 2 +-
 .../kernel/cpu/fp32/custom_gru_fp32.cc | 251 ++++++
 .../runtime/kernel/cpu/fp32/custom_gru_fp32.h | 51 ++
 .../cpu/int8/convolution_int8_creator.cc | 2 +-
 mindspore/lite/src/runtime/lite_session.h | 6 +
 mindspore/lite/src/train/graph_fusion.cc | 7 +
 .../train/optimizer/fusion/gru_fusion_pass.cc | 809 ++++++++++++++++++
 .../train/optimizer/fusion/gru_fusion_pass.h | 45 +
 mindspore/lite/src/train/static_allocator.h | 6 +-
 mindspore/lite/src/train/train_export.cc | 38 +
 mindspore/lite/src/train/train_export.h | 1 +
 mindspore/lite/src/train/train_session.cc | 37 +-
 mindspore/lite/src/train/train_session.h | 3 +
 .../test/config_level0/micro/micro_arm64.cfg | 7 +
 .../config_parser/config_file_parser.cc | 13 +-
 .../config_parser/config_file_parser.h | 2 +
 .../config_parser/micro_param_parser.cc | 33 +
 .../config_parser/micro_param_parser.h | 2 +
 mindspore/lite/tools/converter/converter.cc | 18 +-
 .../converter_lite/converter_flags.cc | 4 +-
 .../converter/micro/cmake/file_list.cmake | 3 +
 .../micro/coder/allocator/allocator.cc | 82 +-
 .../micro/coder/allocator/allocator.h | 13 +-
 .../lite/tools/converter/micro/coder/coder.cc | 51 +-
 .../lite/tools/converter/micro/coder/coder.h | 9 +-
 .../lite/tools/converter/micro/coder/config.h | 10 +
 .../tools/converter/micro/coder/context.h | 25 +-
 .../generator/component/common_component.cc | 5 +-
 .../generator/component/weight_component.cc | 301 +++++--
 .../generator/component/weight_component.h | 2 +-
 .../micro/coder/generator/generator.cc | 2 +-
 .../coder/opcoders/base/reshape_base_coder.cc | 5 +
 .../coder/opcoders/base/stack_base_coder.cc | 85 ++
 .../coder/opcoders/base/stack_base_coder.h | 42 +
 .../opcoders/base/strided_slice_base_coder.cc | 21 +
 .../nnacl/fp16/custom_gru_fp16_coder.cc | 34 +
 .../nnacl/fp16/custom_gru_fp16_coder.h | 44 +
 .../nnacl/fp16/matmul_fp16_base_coder.cc | 22 +-
 .../nnacl/fp16/matmul_fp16_base_coder.h | 4 +-
 .../opcoders/nnacl/fp16/matmul_fp16_coder.h | 4 +-
 .../fp32/convolution_depthwise_fp32_coder.cc | 69 +-
 .../fp32/convolution_depthwise_fp32_coder.h | 8 +-
 .../fp32/convolution_winograd_fp32_coder.cc | 98 ++-
 .../fp32/convolution_winograd_fp32_coder.h | 15 +-
 .../nnacl/fp32/custom_gru_fp32_coder.cc | 214 +++++
 .../nnacl/fp32/custom_gru_fp32_coder.h | 64 ++
 .../opcoders/nnacl/fp32/gather_fp32_coder.cc | 59 +-
 .../opcoders/nnacl/fp32/gather_fp32_coder.h | 2 +
 .../nnacl/fp32/matmul_fp32_base_coder.cc | 41 +-
 .../nnacl/fp32/matmul_fp32_base_coder.h | 2 +-
 .../micro/coder/opcoders/op_coder_builder.cc | 13 +
 .../micro/coder/opcoders/op_coder_builder.h | 4 +
 .../micro/coder/opcoders/op_coder_register.cc | 8 +-
 .../micro/coder/opcoders/op_coder_register.h | 23 +-
 .../nnacl_serializer/nnacl_fp32_serializer.cc | 5 +
 .../nnacl_serializer/nnacl_fp32_serializer.h | 2 +
 .../tools/converter/micro/coder/session.cc | 49 +-
 .../tools/converter/micro/coder/session.h | 2 +-
 .../converter/micro/coder/utils/type_cast.cc | 3 +-
 .../converter/micro/coder/utils/type_cast.h | 4 +-
 85 files changed, 3163 insertions(+), 267 deletions(-)
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c
 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h
 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h
 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc
 create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h
 create mode 100644 mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc
 create mode 100644 mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
 create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h

diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt
index 6aeb515a..6f6a08f5 100644
--- a/.jenkins/check/config/filter_cppcheck.txt
+++ b/.jenkins/check/config/filter_cppcheck.txt
@@ -56,6 +56,7 @@
 "mindspore/mindspore/lite/tools/converter/quantizer/quantize_util.cc" "useStlAlgorithm"
 "mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/" "unreadVariable"
 "mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/" "unreadVariable"
121+"mindspore/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc" "stlFindInsert" 122 "mindspore/mindspore/lite/examples/quick_start_micro/" "syntaxError" 123 "mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental" "unreadVariable" 124 "mindspore/mindspore/lite/python/src/pybind_module.cc" "syntaxError" 125diff --git a/cmake/package_micro.cmake b/cmake/package_micro.cmake 126index 3c6da3db..0481e3c3 100644 127--- a/cmake/package_micro.cmake 128+++ b/cmake/package_micro.cmake 129@@ -10,6 +10,8 @@ function(__install_micro_wrapper) 130 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") 131 install(DIRECTORY ${NNACL_DIR}/fp32 DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl 132 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") 133+ install(DIRECTORY ${NNACL_DIR}/fp16 DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl 134+ COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") 135 install(DIRECTORY ${NNACL_DIR}/kernel DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl 136 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") 137 install(DIRECTORY ${NNACL_DIR}/infer DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl 138@@ -34,4 +36,4 @@ function(__install_micro_codegen) 139 COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") 140 install(TARGETS cmsis_nn ARCHIVE DESTINATION ${CODEGEN_ROOT_DIR}/third_party/lib 141 COMPONENT ${RUNTIME_COMPONENT_NAME}) 142-endfunction() 143\ No newline at end of file 144+endfunction() 145diff --git a/include/api/serialization.h b/include/api/serialization.h 146index 1a0c1f57..76d5dbec 100644 147--- a/include/api/serialization.h 148+++ b/include/api/serialization.h 149@@ -105,6 +105,21 @@ class MS_API Serialization { 150 QuantizationType quantization_type = kNoQuant, bool export_inference_only = true, 151 std::vector<std::string> output_tensor_name = {}); 152 153+ /// \brief Export model's weights, which can be used in micro only. 154+ /// 155+ /// \param[in] model The model data. 156+ /// \param[in] model_type The model file type. 157+ /// \param[in] weight_file The path of exported weight file. 158+ /// \param[in] is_inference Whether to export weights from a reasoning model. Currently, only support this is `true`. 159+ /// \param[in] enable_fp16 Float-weight is whether to be saved in float16 format. 160+ /// \param[in] changeable_weights_name The set the name of these weight tensors, whose shape is changeable. 161+ /// 162+ /// \return Status. 
+  inline static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                         const std::string &weight_file, bool is_inference = true,
+                                                         bool enable_fp16 = false,
+                                                         const std::vector<std::string> &changeable_weights_name = {});
+
  private:
  static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
                     const std::vector<char> &dec_mode);
@@ -119,6 +134,10 @@ class MS_API Serialization {
   static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                             QuantizationType quantization_type, bool export_inference_only,
                             const std::vector<std::vector<char>> &output_tensor_name);
+  static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                  const std::vector<char> &weight_file, bool is_inference,
+                                                  bool enable_fp16,
+                                                  const std::vector<std::vector<char>> &changeable_weights_name);
 };
 
 Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
@@ -150,5 +169,12 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff
                                   VectorStringToChar(output_tensor_name));
 }
 
+Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                        const std::string &weight_file, bool is_inference,
+                                                        bool enable_fp16,
+                                                        const std::vector<std::string> &changeable_weights_name) {
+  return ExportWeightsCollaborateWithMicro(model, model_type, StringToChar(weight_file), is_inference, enable_fp16,
+                                           VectorStringToChar(changeable_weights_name));
+}
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_SERIALIZATION_H
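A minimal caller-side sketch of the new export API (not part of the patch; the model is assumed to be a trained model loaded elsewhere, and the file name "net.bin" and tensor name "fc.weight" are hypothetical):

```cpp
#include "include/api/model.h"
#include "include/api/serialization.h"

void ExportMicroWeights(const mindspore::Model &model) {
  // Export only the weights of the trained model for micro codegen, keeping
  // float weights in float32 and letting the hypothetical "fc.weight" tensor
  // vary in shape.
  auto status = mindspore::Serialization::ExportWeightsCollaborateWithMicro(
      model, mindspore::kMindIR, "net.bin",
      /*is_inference=*/true, /*enable_fp16=*/false, {"fc.weight"});
  if (status != mindspore::kSuccess) {
    // Handle the failure, e.g. log and abort the codegen flow.
  }
}
```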
(the "License"); 232+ * you may not use this file except in compliance with the License. 233+ * You may obtain a copy of the License at 234+ * 235+ * http://www.apache.org/licenses/LICENSE-2.0 236+ * 237+ * Unless required by applicable law or agreed to in writing, software 238+ * distributed under the License is distributed on an "AS IS" BASIS, 239+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 240+ * See the License for the specific language governing permissions and 241+ * limitations under the License. 242+ */ 243+#ifndef MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_ 244+#define MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_ 245+ 246+#include "nnacl/op_base.h" 247+ 248+typedef struct CustomGruParameter { 249+ // Primitive parameter 250+ OpParameter op_parameter_; 251+ // shape correlative 252+ int num_step; 253+ int batch_size; 254+ int input_size; 255+ int hidden_size; 256+} CustomGruParameter; 257+ 258+#endif // MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_ 259diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c 260new file mode 100644 261index 00000000..6e754569 262--- /dev/null 263+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c 264@@ -0,0 +1,70 @@ 265+#ifdef ENABLE_ARM64 266+/** 267+ * Copyright 2023 Huawei Technologies Co., Ltd 268+ * 269+ * Licensed under the Apache License, Version 2.0 (the "License"); 270+ * you may not use this file except in compliance with the License. 271+ * You may obtain a copy of the License at 272+ * 273+ * http://www.apache.org/licenses/LICENSE-2.0 274+ * 275+ * Unless required by applicable law or agreed to in writing, software 276+ * distributed under the License is distributed on an "AS IS" BASIS, 277+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 278+ * See the License for the specific language governing permissions and 279+ * limitations under the License. 
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c
new file mode 100644
index 00000000..6e754569
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c
@@ -0,0 +1,70 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/fp16/custom_gru_fp16.h"
+#include "nnacl/fp16/activation_fp16.h"
+#include "nnacl/fp16/arithmetic_fp16.h"
+#include "nnacl/fp16/matmul_fp16.h"
+
+void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input,
+                   const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden,
+                   const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param) {
+  int num_step = gru_param->num_step;
+  int batch_size = gru_param->batch_size;
+  int input_size = gru_param->input_size;
+  int hidden_size = gru_param->hidden_size;
+  int output_size = batch_size * hidden_size;
+  int double_output_size = output_size * C2NUM;
+  int col_align = UP_ROUND(hidden_size, C8NUM);
+  int weight_in_offset = col_align * input_size;
+  int weight_hidden_offset = col_align * hidden_size;
+  float16_t *input_gate = buffer[1];
+  float16_t *hidden_gate = buffer[C3NUM];
+  for (int i = 0; i < num_step; ++i) {
+    if (batch_size != 1) {
+      RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatmulBaseFp16Neon(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size,
+                           bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size,
+                           OutType_Nhwc);
+      }
+      RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatmulBaseFp16Neon(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                           bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size,
+                           OutType_Nhwc);
+      }
+    } else {
+      for (int j = 0; j < C3NUM; ++j) {
+        VecMatmulFp16(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size,
+                      bias_input + j * col_align, ActType_No, input_size, hidden_size);
+        VecMatmulFp16(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                      bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size);
+      }
+    }
+    ElementAddFp16(input_gate, hidden_gate, input_gate, double_output_size);
+    SigmoidFp16(input_gate, input_gate, double_output_size);
+    ElementMulFp16(input_gate, hidden_gate + double_output_size, input_gate, output_size);
+    ElementAddFp16(input_gate, input_gate + double_output_size, input_gate, output_size);
+    TanhFp16(input_gate, input_gate, output_size);
+    ElementSubFp16(init_h, input_gate, hidden_gate, output_size);
+    ElementMulFp16(input_gate + output_size, hidden_gate, hidden_gate, output_size);
+    ElementAddFp16(input_gate, hidden_gate, output, output_size);
+    init_h = output;
+    output += output_size;
+  }
+}
+#endif
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h
new file mode 100644
index 00000000..67008f03
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_
+#define MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_
+#ifdef ENABLE_ARM64
+#include "nnacl/custom_gru_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input,
+                   const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden,
+                   const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif
+#endif  // MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c
new file mode 100644
index 00000000..caeece4a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c
@@ -0,0 +1,71 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/fp32/custom_gru_fp32.h"
+#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl/fp32/pack_fp32.h"
+
+void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden,
+               const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4],
+               const CustomGruParameter *gru_param) {
+  int num_step = gru_param->num_step;
+  int batch_size = gru_param->batch_size;
+  int input_size = gru_param->input_size;
+  int hidden_size = gru_param->hidden_size;
+  int output_size = batch_size * hidden_size;
+  int double_output_size = output_size * C2NUM;
+  int col_align = UP_ROUND(hidden_size, C8NUM);
+  int weight_in_offset = col_align * input_size;
+  int weight_hidden_offset = col_align * hidden_size;
+  float *input_gate = buffer[1];
+  float *hidden_gate = buffer[C3NUM];
+  for (int i = 0; i < num_step; ++i) {
+    if (batch_size != 1) {
+      RowMajor2Col12Major(input + i * batch_size * input_size, buffer[0], batch_size, input_size);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatMulOpt(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size,
+                  bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size,
+                  OutType_Nhwc);
+      }
+      RowMajor2Col12Major(init_h, buffer[C2NUM], batch_size, hidden_size);
+      for (int j = 0; j < C3NUM; ++j) {
+        MatMulOpt(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                  bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size,
+                  OutType_Nhwc);
+      }
+    } else {
+      for (int j = 0; j < C3NUM; ++j) {
+        MatVecMulFp32Neon64(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size,
+                            bias_input + j * col_align, ActType_No, input_size, hidden_size, col_align);
+        MatVecMulFp32Neon64(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+                            bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size, col_align);
+      }
+    }
+    ElementAdd(input_gate, hidden_gate, input_gate, double_output_size);
+    Sigmoid(input_gate, double_output_size, input_gate);
+    ElementMul(input_gate, hidden_gate + double_output_size, input_gate, output_size);
+    ElementAdd(input_gate, input_gate + double_output_size, input_gate, output_size);
+    Tanh(input_gate, output_size, input_gate);
+    ElementSub(init_h, input_gate, hidden_gate, output_size);
+    ElementMul(input_gate + output_size, hidden_gate, hidden_gate, output_size);
+    ElementAdd(input_gate, hidden_gate, output, output_size);
+    init_h = output;
+    output += output_size;
+  }
+}
+#endif
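For reference, the per-step recurrence computed by the fused elementwise tail above, as a scalar sketch (not part of the patch; the reset/update/candidate gate order is inferred from the buffer offsets in the kernel):

```cpp
#include <cmath>

// xr/xz/xn: Wx·x + b for the reset/update/candidate gates; hr/hz/hn: Wh·h + b'.
inline float SigmoidScalar(float v) { return 1.0f / (1.0f + std::exp(-v)); }

float GruStep(float xr, float xz, float xn, float hr, float hz, float hn, float h) {
  float r = SigmoidScalar(xr + hr);    // reset gate
  float z = SigmoidScalar(xz + hz);    // update gate
  float n = std::tanh(xn + r * hn);    // candidate; reset gate scales the hidden term
  return n + z * (h - n);              // h' = z*h + (1-z)*n
}
```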
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h
new file mode 100644
index 00000000..576726c5
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_
+#define MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_
+#ifdef ENABLE_ARM64
+#include "nnacl/custom_gru_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden,
+               const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4],
+               const CustomGruParameter *gru_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif
+#endif  // MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h
index b608e1e0..8e217d02 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h
@@ -43,7 +43,7 @@ int ElementOptMulAcc(const float *input0, const float input1, float *output, con
 
 void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
                   const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state,
-                  float *buffer[6], const LstmParameter *lstm_param);
+                  float *buffer[C6NUM], const LstmParameter *lstm_param);
 
 void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias,
           const float *state_bias, float *hidden_state, float *cell_state, float *buffer[7],
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c
new file mode 100644
index 00000000..060d04cf
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl/infer/custom_gru_infer.h"
+#include "nnacl/infer/infer_register.h"
+
+int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                        OpParameter *parameter) {
+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C6NUM, 1);
+  if (check_ret != NNACL_OK) {
+    return check_ret;
+  }
+
+  const TensorC *input = inputs[0];
+  TensorC *output = outputs[0];
+  SetDataTypeFormat(output, input);
+  if (!InferFlag(inputs, inputs_size)) {
+    return NNACL_INFER_INVALID;
+  }
+  if (input->shape_size_ != C3NUM) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  SetShapeTensor(output, input);
+  const TensorC *weight_in = inputs[1];
+  if (weight_in->shape_size_ != C2NUM || weight_in->shape_[0] % C3NUM != 0) {
+    return NNACL_INPUT_TENSOR_ERROR;
+  }
+  output->shape_[C2NUM] = weight_in[0].shape_[0] / C3NUM;
+  return NNACL_OK;
+}
+
+REG_INFER(CustomGru, PrimType_Inner_CustomGru, CustomGruInferShape)
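The shape rule above, restated as a standalone sketch (it assumes the input layout [num_step, batch, input_size] implied by the kernels; not part of the patch):

```cpp
#include <cassert>

// Output keeps the first two dims of x; its last dim is weight_in rows / 3.
void CustomGruOutputShape(const int x_shape[3], const int weight_in_shape[2], int out_shape[3]) {
  assert(weight_in_shape[0] % 3 == 0);
  out_shape[0] = x_shape[0];              // num_step
  out_shape[1] = x_shape[1];              // batch
  out_shape[2] = weight_in_shape[0] / 3;  // hidden
}
```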
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h
new file mode 100644
index 00000000..830150d5
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_CUSTOM_GRU_INFER_H
+#define MINDSPORE_NNACL_CUSTOM_GRU_INFER_H
+#include "nnacl/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+                        OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // MINDSPORE_NNACL_CUSTOM_GRU_INFER_H
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
index 5876bdf6..8c219212 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
@@ -520,6 +520,7 @@ enum PrimType {
   PrimType_Inner_ShapeFusion = 10003,
   PrimType_Inner_GraphKernel = 10004,
   PrimType_Inner_ThirdPartyModel = 10005,
+  PrimType_Inner_CustomGru = 10006,
   PrimType_InnerOpMax,
   PrimType_InnerOpMin = PrimType_Inner_ToFormat
 };
diff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h
index dd6e6d08..0d2de256 100644
--- a/mindspore/lite/include/registry/converter_context.h
+++ b/mindspore/lite/include/registry/converter_context.h
@@ -34,7 +34,8 @@ enum MS_API FmkType : int {
   kFmkTypeTflite = 4,
   kFmkTypePytorch = 5,
   kFmkTypeThirdParty = 6,
-  kFmkTypeEnd = 7,  // For range check purpose, valid range: [0, kFmkTypeEnd)
+  kFmkTypeMsLite = 7,
+  kFmkTypeEnd = 8,  // For range check purpose, valid range: [0, kFmkTypeEnd)
 };
 
 /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index 48e0fe7c..a80656f2 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -380,6 +380,7 @@ set(TRAIN_SRC
         ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/common/fusion_utils.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/gru_fusion_pass.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_activation_fusion_pass.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc
diff --git a/mindspore/lite/src/common/graph_util.cc b/mindspore/lite/src/common/graph_util.cc
index 1688fb79..d5cb2a98 100644
--- a/mindspore/lite/src/common/graph_util.cc
+++ b/mindspore/lite/src/common/graph_util.cc
@@ -23,6 +23,7 @@
 #include "src/common/log_adapter.h"
 #include "src/common/version_manager.h"
 #include "include/errorcode.h"
+#include "nnacl/op_base.h"
 
 namespace mindspore {
 namespace lite {
@@ -86,9 +87,9 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, size_t tensor
 
 // only support op_type from current schema
 bool IsPackedOp(int op_type) {
-  static const std::vector<int> packed_ops = {schema::PrimitiveType_Conv2DFusion,
-                                              schema::PrimitiveType_Conv2dTransposeFusion,
-                                              schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMulFusion};
+  static const std::vector<int> packed_ops = {
+    schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion,
+    schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMulFusion, PrimType::PrimType_Inner_CustomGru};
   return IsContain(packed_ops, op_type);
 }
 
diff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc
index f1506ece..391a587b 100644
--- a/mindspore/lite/src/common/ops/populate/custom_populate.cc
+++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc
@@ -17,10 +17,22 @@
 #include "src/common/log_adapter.h"
 #include "src/tensor.h"
 #include "nnacl/custom_parameter.h"
+#include "nnacl/custom_gru_parameter.h"
 using mindspore::schema::PrimitiveType_Custom;
 
 namespace mindspore {
 namespace lite {
+OpParameter *CreateCustomGruParameter() {
+  auto *param = static_cast<CustomGruParameter *>(malloc(sizeof(CustomGruParameter)));
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "malloc CustomGruParameter failed.";
+    return nullptr;
+  }
+  memset(param, 0, sizeof(CustomGruParameter));
+  param->op_parameter_.type_ = PrimType_Inner_CustomGru;
+  return reinterpret_cast<OpParameter *>(param);
+}
+
 OpParameter *PopulateCustomParameter(const void *prim) {
   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
   auto primitive = static_cast<const schema::Primitive *>(prim);
@@ -62,6 +74,8 @@ OpParameter *PopulateCustomParameter(const void *prim) {
     // Just use the attr_data pointer to save the prim directly, the inner value is parsed as necessary.
     param->attr_data[0] = static_cast<char *>(const_cast<void *>(prim));
     return reinterpret_cast<OpParameter *>(param);
+  } else if (type == "CustomGRU") {
+    return CreateCustomGruParameter();
   } else {
     MS_LOG(ERROR) << "Unsupported custom type: " << type;
   }
diff --git a/mindspore/lite/src/runtime/cxx_api/serialization.cc b/mindspore/lite/src/runtime/cxx_api/serialization.cc
index 8405f4b2..ddf69d23 100644
--- a/mindspore/lite/src/runtime/cxx_api/serialization.cc
+++ b/mindspore/lite/src/runtime/cxx_api/serialization.cc
@@ -212,4 +212,35 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
 
   return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
 }
+
+Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                        const std::vector<char> &weight_file, bool is_inference,
+                                                        bool enable_fp16,
+                                                        const std::vector<std::vector<char>> &changeable_weights_name) {
+  if (model.impl_ == nullptr) {
+    MS_LOG(ERROR) << "Model implement is null.";
+    return kLiteUninitializedObj;
+  }
+  if (!model.impl_->IsTrainModel()) {
+    MS_LOG(ERROR) << "Model is not TrainModel.";
+    return kLiteError;
+  }
+  if (model_type != kMindIR && model_type != kMindIR_Lite) {
+    MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
+    return kLiteParamInvalid;
+  }
+  if (model.impl_->session_ == nullptr) {
+    MS_LOG(ERROR) << "Model session is nullptr.";
+    return kLiteError;
+  }
+  if (!is_inference) {
+    MS_LOG(ERROR) << "Currently, can only export inference-model's weights.";
+    return kLiteNotSupport;
+  }
+  auto ret = model.impl_->session_->ExportWeightsCollaborateWithMicro(CharToString(weight_file), lite::MT_INFERENCE,
+                                                                      lite::FT_FLATBUFFERS, enable_fp16,
+                                                                      VectorCharToString(changeable_weights_name));
+
+  return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
+}
 }  // namespace mindspore
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
index b5370ddd..0352ad19 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
@@ -35,6 +35,8 @@ int GroupConvolutionBaseCPUKernel::Prepare() {
       return ret;
     }
   }
+  conv_param_->input_channel_ *= group_num_;
+  conv_param_->output_channel_ *= group_num_;
   // if infer shape is done, resize func will be invoked in sub kernels
   return RET_OK;
 }
@@ -99,11 +101,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
       auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front();
       CHECK_NULL_RETURN(sub_kernel_in_tensor);
       sub_kernel_in_tensor->set_shape(in_shape);
-      ret = sub_kernel_in_tensor->MallocData();
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "sub kernel in tensor malloc data failed.";
-        return ret;
-      }
       // out
       auto out_tensor = out_tensors_.front();
       CHECK_NULL_RETURN(out_tensor);
@@ -113,11 +110,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
       for (auto tensor : sub_kernel_out_tensors) {
         CHECK_NULL_RETURN(tensor);
         tensor->set_shape(out_shape);
-        ret = tensor->MallocData();
-        if (ret != RET_OK) {
-          MS_LOG(ERROR) << "sub kernel out tensor malloc data failed.";
-          return ret;
-        }
       }
     }
     ret = ReSize();
@@ -177,7 +169,22 @@ int GroupConvolutionBaseCPUKernel::Run() {
   ori_out_data_ = out_tensors_[0]->data();
   CHECK_NULL_RETURN(ori_out_data_);
   for (int i = 0; i < group_num_; ++i) {
-    // first, separate group conv input into several parts. This step must be in runtime stage.
+    // first, malloc data for sub_kernel's tensors
+    auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front();
+    ret = sub_kernel_in_tensor->MallocData();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "sub kernel in tensor malloc data failed.";
+      return ret;
+    }
+    auto sub_kernel_out_tensors = group_convs_.at(i)->out_tensors();
+    for (auto tensor : sub_kernel_out_tensors) {
+      ret = tensor->MallocData();
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "sub kernel out tensor malloc data failed.";
+        return ret;
+      }
+    }
+    // second, separate group conv input into several parts. This step must be in runtime stage.
     ret = SeparateInput(i);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "Separate input failed.";
@@ -195,6 +202,11 @@ int GroupConvolutionBaseCPUKernel::Run() {
       MS_LOG(ERROR) << "Concat output failed.";
       return ret;
     }
+    // free data
+    sub_kernel_in_tensor->FreeData();
+    for (auto tensor : sub_kernel_out_tensors) {
+      tensor->FreeData();
+    }
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
index fc78a887..81b0aac2 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
@@ -53,15 +53,6 @@ void FreeCurrentConv(ConvParameter *conv_param, std::vector<lite::Tensor *> *new
   }
 }
 
-static inline lite::Tensor *TensorMalloc(lite::Tensor *tensor) {
-  if (tensor->MallocData() != lite::RET_OK) {
-    delete tensor;
-    MS_LOG(ERROR) << "malloc tensor data failed.";
-    return nullptr;
-  }
-  return tensor;
-}
-
 lite::Tensor *CreateConstTensor(const lite::Tensor *tensor, const std::vector<int> &shape, const int index) {
   auto new_tensor =
     new (std::nothrow) lite::Tensor(tensor->data_type(), shape, mindspore::NHWC, lite::Category::CONST_TENSOR);
@@ -87,7 +78,7 @@ lite::Tensor *CreateConstTensor(const lite::Tensor *tensor, const std::vector<in
   return new_tensor;
 }
 
-lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred) {
+lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info) {
   auto tensor = new (std::nothrow) lite::Tensor();
   if (tensor == nullptr) {
     MS_LOG(ERROR) << "new tensor failed.";
@@ -97,11 +88,7 @@ lite::Tensor *CreateVarTensor(const TensorInfo &tensor_info, bool inferred) {
   tensor->set_format(tensor_info.format_);
   tensor->set_category(tensor_info.tensor_type_);
   tensor->set_shape(tensor_info.shape_);
-
-  if (inferred) {
-    // set shape of out tensor
-    return TensorMalloc(tensor);
-  }
+  tensor->set_allocator(tensor_info.allocator_);
   return tensor;
 }
 
@@ -129,7 +116,8 @@ void GroupConvCreator::FreeGroupConvs() {
 }
 
 int GroupConvCreator::NewInputTensor(std::vector<lite::Tensor *> *tensors) {
-  auto in_tensor = CreateVarTensor({input_shape_, mindspore::NHWC, data_type_, lite::Category::VAR, true}, infered_);
+  auto allocator = ms_context_ != nullptr ? ms_context_->allocator : nullptr;
+  auto in_tensor = CreateVarTensor({input_shape_, allocator, mindspore::NHWC, data_type_, lite::Category::VAR, true});
   if (in_tensor == nullptr) {
     return lite::RET_ERROR;
   }
@@ -138,7 +126,9 @@ int GroupConvCreator::NewInputTensor(std::vector<lite::Tensor *> *tensors) {
 }
 
 int GroupConvCreator::NewOutputTensor(std::vector<lite::Tensor *> *tensors, const lite::Tensor *output) const {
-  auto out_tensor = CreateVarTensor({output_shape_, output->format(), data_type_, output->category(), false}, infered_);
+  auto allocator = ms_context_ != nullptr ? ms_context_->allocator : nullptr;
+  auto out_tensor =
+    CreateVarTensor({output_shape_, allocator, output->format(), data_type_, output->category(), false});
   if (out_tensor == nullptr) {
     return lite::RET_ERROR;
   }
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h
index b4a4f768..27aa0cc8 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h
@@ -22,10 +22,12 @@
 #include "src/runtime/lite_kernel.h"
 #include "nnacl/conv_parameter.h"
 #include "src/runtime/tensor_category.h"
+#include "include/api/allocator.h"
 
 namespace mindspore::kernel {
 struct TensorInfo {
   std::vector<int> shape_;
+  AllocatorPtr allocator_;
   mindspore::Format format_;
   TypeId data_type_;
   lite::Category tensor_type_;
@@ -35,8 +37,9 @@ struct TensorInfo {
 class GroupConvCreator {
  public:
   GroupConvCreator(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs, OpParameter *op_parameter,
-                   bool is_quant, TypeId data_type)
-      : origin_inputs_(std::move(inputs)),
+                   bool is_quant, TypeId data_type, const lite::InnerContext *ctx)
+      : ms_context_(ctx),
+        origin_inputs_(std::move(inputs)),
         origin_outputs_(std::move(outputs)),
         is_quant_(is_quant),
         data_type_(data_type) {
@@ -64,6 +67,7 @@ class GroupConvCreator {
   int NewOutputTensor(std::vector<lite::Tensor *> *tensors, const lite::Tensor *output) const;
 
  private:
+  const lite::InnerContext *ms_context_ = nullptr;
   std::vector<lite::Tensor *> origin_inputs_;
   std::vector<lite::Tensor *> origin_outputs_;
   std::vector<kernel::LiteKernel *> group_convs_;
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc
index 7aa823b0..17ba38ff 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc
@@ -202,7 +202,7 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor
                                                   const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                                   const InnerContext *ctx) {
   auto *group_conv_creator =
-    new (std::nothrow) GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat16);
+    new (std::nothrow) GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat16, ctx);
   if (group_conv_creator == nullptr) {
     MS_LOG(ERROR) << "new GroupConvCreator fail";
     free(op_parameter);
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc
new file mode 100644
index 00000000..7851eecb
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc
@@ -0,0 +1,132 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/runtime/kernel/cpu/fp16/custom_gru_fp16.h"
+#include <algorithm>
+#include "src/runtime/kernel_registry.h"
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#include "src/runtime/pack_weight_manager.h"
+#include "nnacl/custom_gru_parameter.h"
+#include "nnacl/fp16/custom_gru_fp16.h"
+#include "nnacl/fp16/matmul_fp16.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NOT_SUPPORT;
+using mindspore::lite::RET_OK;
+
+namespace mindspore::kernel {
+int CustomGruFp16CPUKernel::InitWeightAndBias() {
+  auto weight_shape = in_tensors_[1]->shape();
+  auto hidden_size = weight_shape[0] / C3NUM;
+  auto col_align = UP_ROUND(hidden_size, col_tile_);
+  auto weight_in_pack_size = static_cast<size_t>(col_align * weight_shape[1]) * sizeof(float16_t);
+  bool is_packed = false;
+  weight_in_ = lite::PackWeightManager::GetInstance()->GetPackData(
+    in_tensors_[SECOND_INPUT]->data(), static_cast<size_t>(weight_in_pack_size * C3NUM), &is_packed);
+  MS_CHECK_TRUE_MSG(weight_in_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-in failed.");
+  if (!is_packed) {
+    auto weight_in_src = static_cast<const float16_t *>(in_tensors_[SECOND_INPUT]->data());
+    for (int i = 0; i < C3NUM; ++i) {
+      RowMajor2Col8MajorFp16(weight_in_src + i * hidden_size * weight_shape[1],
+                             static_cast<float16_t *>(weight_in_) + i * col_align * weight_shape[1], hidden_size,
+                             weight_shape[1], false);
+    }
+  }
+  auto weight_hidden_pack_size = static_cast<size_t>(col_align * hidden_size) * sizeof(float16_t);
+  is_packed = false;
+  weight_hidden_ = lite::PackWeightManager::GetInstance()->GetPackData(
+    in_tensors_[THIRD_INPUT]->data(), static_cast<size_t>(weight_hidden_pack_size * C3NUM), &is_packed);
+  MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-hidden failed.");
+  if (!is_packed) {
+    auto weight_hidden_src = static_cast<const float16_t *>(in_tensors_[THIRD_INPUT]->data());
+    for (int i = 0; i < C3NUM; ++i) {
+      RowMajor2Col8MajorFp16(weight_hidden_src + i * hidden_size * hidden_size,
+                             static_cast<float16_t *>(weight_hidden_) + i * col_align * hidden_size, hidden_size,
+                             hidden_size, false);
+    }
+  }
+  auto bias_pack_size = static_cast<size_t>(col_align) * sizeof(float16_t);
+  auto bias = reinterpret_cast<float16_t *>(malloc(bias_pack_size * C6NUM));
+  if (bias == nullptr) {
+    MS_LOG(ERROR) << "malloc for packing bias failed.";
+    return lite::RET_NULL_PTR;
+  }
+  (void)memset(bias, 0, bias_pack_size * C6NUM);
+  bias_in_ = bias;
+  bias_hidden_ = bias + col_align * C3NUM;
+  auto bias_in_src = static_cast<const float16_t *>(in_tensors_[FOURTH_INPUT]->data());
+  for (int i = 0; i < C3NUM; ++i) {
+    (void)memcpy(bias + i * col_align, bias_in_src + i * hidden_size, hidden_size * sizeof(float16_t));
+  }
+  auto bias_hidden_src = static_cast<const float16_t *>(in_tensors_[FIFTH_INPUT]->data());
+  for (int i = 0; i < C3NUM; ++i) {
+    (void)memcpy(bias + (C3NUM + i) * col_align, bias_hidden_src + i * hidden_size, hidden_size * sizeof(float16_t));
+  }
+  if (in_tensors_[SIXTH_INPUT]->IsConst()) {
+    init_h_ = malloc(in_tensors_[SIXTH_INPUT]->Size());
+    MS_CHECK_TRUE_MSG(init_h_ != nullptr, lite::RET_NULL_PTR, "malloc for init-h failed.");
+    (void)memcpy(init_h_, in_tensors_[SIXTH_INPUT]->data(), in_tensors_[SIXTH_INPUT]->Size());
+  }
+  return RET_OK;
+}
+
+int CustomGruFp16CPUKernel::Run() {
+  auto input = reinterpret_cast<float16_t *>(in_tensors_[FIRST_INPUT]->data());
+  if (input == nullptr) {
+    MS_LOG(ERROR) << "Built-in CustomGru's first-input is nullptr." << name_;
+    return lite::RET_NULL_PTR;
+  }
+  if (!in_tensors_[SIXTH_INPUT]->IsConst()) {
+    init_h_ = in_tensors_[SIXTH_INPUT]->data();
+  }
+  if (init_h_ == nullptr) {
+    MS_LOG(ERROR) << "Built-in CustomGru's sixth-input is nullptr." << name_;
+    return lite::RET_NULL_PTR;
+  }
+  auto output = reinterpret_cast<float16_t *>(out_tensors_.front()->data());
+  if (output == nullptr) {
+    MS_LOG(ERROR) << "Built-in CustomGru's output is nullptr." << name_;
+    return lite::RET_NULL_PTR;
+  }
+  MallocRunBuffer(sizeof(float16_t));
+  if (run_buffer_ == nullptr) {
+    MS_LOG(ERROR) << "malloc running buffer failed." << name_;
+    return lite::RET_NULL_PTR;
+  }
+  auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
+  auto row_align = UP_ROUND(param->batch_size, row_tile_);
+  auto run_buffer = reinterpret_cast<float16_t *>(run_buffer_);
+  float16_t *buffer[C4NUM] = {
+    run_buffer, run_buffer + row_align * param->input_size,
+    run_buffer + row_align * param->input_size + param->batch_size * param->hidden_size * C3NUM,
+    run_buffer + row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C3NUM};
+  CustomGruFp16(output, input, static_cast<float16_t *>(weight_in_), static_cast<float16_t *>(weight_hidden_),
+                static_cast<float16_t *>(bias_in_), static_cast<float16_t *>(bias_hidden_),
+                static_cast<float16_t *>(init_h_), buffer, param);
+  if (ms_context_->allocator != nullptr) {
+    ms_context_->allocator->Free(run_buffer_);
+  } else {
+    free(run_buffer_);
+  }
+  run_buffer_ = nullptr;
+  return RET_OK;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimType_Inner_CustomGru, LiteKernelCreator<CustomGruFp16CPUKernel>)
+}  // namespace mindspore::kernel
+#endif
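A sketch (not part of the patch) of the run-buffer partition computed in Run() above; sizes are in elements and row_align = UP_ROUND(batch_size, row_tile_):

```cpp
// buffer[0]: packed input x        (row_align * input_size)
// buffer[1]: x-side gate buffer    (batch_size * hidden_size * 3)
// buffer[2]: packed hidden state h (row_align * hidden_size)
// buffer[3]: h-side gate buffer    (batch_size * hidden_size * 3)
void SplitGruRunBuffer(float *run_buffer, float *buffer[4], int row_align, int batch_size, int input_size,
                       int hidden_size) {
  buffer[0] = run_buffer;
  buffer[1] = buffer[0] + row_align * input_size;
  buffer[2] = buffer[1] + batch_size * hidden_size * 3;
  buffer[3] = buffer[2] + row_align * hidden_size;
}
```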
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h
new file mode 100644
index 00000000..d7ed313a
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_
+#ifdef ENABLE_ARM64
+#include <vector>
+#include "src/runtime/kernel/cpu/fp32/custom_gru_fp32.h"
+
+namespace mindspore::kernel {
+class CustomGruFp16CPUKernel : public CustomGruCPUKernel {
+ public:
+  CustomGruFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                         const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+      : CustomGruCPUKernel(parameter, inputs, outputs, ctx) {
+    row_tile_ = C4NUM;
+    col_tile_ = C8NUM;
+  }
+  ~CustomGruFp16CPUKernel() override = default;
+  int Run() override;
+
+ protected:
+  int InitWeightAndBias() override;
+};
+}  // namespace mindspore::kernel
+#endif
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc
index bbbf8488..3514e5b4 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc
@@ -359,7 +359,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
 kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                   const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                                   const lite::InnerContext *ctx) {
-  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat32);
+  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat32, ctx);
   auto group_kernel = new (std::nothrow) GroupConvolutionFp32CPUKernel(
     op_parameter, inputs, outputs, ctx, group_conv_creator, reinterpret_cast<ConvParameter *>(op_parameter)->group_);
   if (group_kernel == nullptr) {
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc
new file mode 100644
index 00000000..c85a1283
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc
@@ -0,0 +1,251 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/runtime/kernel/cpu/fp32/custom_gru_fp32.h"
+#include <algorithm>
+#include "src/runtime/kernel_registry.h"
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#include "src/runtime/pack_weight_manager.h"
+#include "nnacl/fp32/pack_fp32.h"
+#include "nnacl/custom_gru_parameter.h"
+#include "nnacl/fp32/custom_gru_fp32.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NOT_SUPPORT;
+using mindspore::lite::RET_OK;
+
+namespace mindspore::kernel {
+CustomGruCPUKernel::~CustomGruCPUKernel() {
+  if (weight_in_) {
+    lite::PackWeightManager::GetInstance()->Free(weight_in_);
+    weight_in_ = nullptr;
+  }
+  if (weight_hidden_) {
+    lite::PackWeightManager::GetInstance()->Free(weight_hidden_);
+    weight_hidden_ = nullptr;
+  }
+  if (bias_in_) {
+    free(bias_in_);
+    bias_in_ = nullptr;
+    bias_hidden_ = nullptr;
+  }
+  if (in_tensors_[SIXTH_INPUT]->IsConst() && init_h_) {
+    free(init_h_);
+    init_h_ = nullptr;
+  }
+}
+
+int CustomGruCPUKernel::Prepare() {
+  CHECK_LESS_RETURN(in_tensors_.size(), C6NUM);
+  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  if (in_tensors_[FIRST_INPUT]->IsConst()) {
+    MS_LOG(ERROR) << "Built-in CustomGru first-input must be a variable." << name_;
+    return RET_NOT_SUPPORT;
+  }
+  for (size_t i = 1; i < C5NUM; ++i) {
+    if (!in_tensors_[i]->IsConst()) {
+      MS_LOG(ERROR) << "Built-in CustomGru only supports variable first-input and last-input." << name_;
+      return RET_NOT_SUPPORT;
+    }
+  }
+  if (InitParamter() != RET_OK) {
+    MS_LOG(ERROR) << "Init Built-in CustomGru Parameter failed." << name_;
+    return RET_ERROR;
+  }
+  if (InitWeightAndBias() != RET_OK) {
+    MS_LOG(ERROR) << "Init Built-in CustomGru Weight and bias failed." << name_;
+    return RET_ERROR;
+  }
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+int CustomGruCPUKernel::InitParamter() {
+  auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
+  thread_num_ = 1;
+  auto weight_in_shape = in_tensors_[1]->shape();
+  auto weight_hidden_shape = in_tensors_[C2NUM]->shape();
+  if (weight_in_shape.size() != C2NUM || weight_hidden_shape.size() != C2NUM) {
+    MS_LOG(ERROR) << "Built-in CustomGru's weight must be 2D." << name_;
+    return RET_ERROR;
+  }
+  if (weight_in_shape[0] != weight_hidden_shape[0]) {
+    MS_LOG(ERROR) << "Built-in CustomGru's weight-in and weight-hidden first-dim must be the same." << name_;
+    return RET_ERROR;
+  }
+  if (weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) {
+    MS_LOG(ERROR) << "Built-in CustomGru's weight-hidden first-dim must be 3 * second-dim." << name_;
+    return RET_ERROR;
+  }
+  auto bias_in_shape = in_tensors_[C3NUM]->shape();
+  auto bias_hidden_shape = in_tensors_[C4NUM]->shape();
+  if (bias_in_shape.size() != 1) {
+    MS_LOG(ERROR) << "Built-in CustomGru's bias must be 1D." << name_;
+    return RET_ERROR;
+  }
+  if (bias_in_shape != bias_hidden_shape) {
+    MS_LOG(ERROR) << "Built-in CustomGru's bias-in and bias-hidden must have same shape." << name_;
<< name_; 1220+ return RET_ERROR; 1221+ } 1222+ if (bias_in_shape.back() != weight_in_shape.front()) { 1223+ MS_LOG(ERROR) << "Built-in CustomGru's bias-in shape doesn't match the first-dim of weight." << name_; 1224+ return RET_ERROR; 1225+ } 1226+ if (bias_in_shape.front() % C3NUM != 0) { 1227+ MS_LOG(ERROR) << "The first-dim of CustomGru's weight must be 3 * hidden."; 1228+ return RET_ERROR; 1229+ } 1230+ param->input_size = weight_in_shape.back(); 1231+ param->hidden_size = bias_in_shape.front() / C3NUM; 1232+ return RET_OK; 1233+} 1234+ 1235+int CustomGruCPUKernel::InitWeightAndBias() { 1236+ auto weight_shape = in_tensors_[1]->shape(); 1237+ auto hidden_size = weight_shape[0] / C3NUM; 1238+ auto col_align = UP_ROUND(hidden_size, col_tile_); 1239+ auto weight_in_pack_size = static_cast<size_t>(col_align * weight_shape[1]) * sizeof(float); 1240+ bool is_packed = false; 1241+ weight_in_ = lite::PackWeightManager::GetInstance()->GetPackData( 1242+ in_tensors_[SECOND_INPUT]->data(), static_cast<size_t>(weight_in_pack_size * C3NUM), &is_packed); 1243+ MS_CHECK_TRUE_MSG(weight_in_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-in failed."); 1244+ if (!is_packed) { 1245+ auto weight_in_src = static_cast<const float *>(in_tensors_[SECOND_INPUT]->data()); 1246+ for (int i = 0; i < C3NUM; ++i) { 1247+ RowMajor2Col8Major(weight_in_src + i * hidden_size * weight_shape[1], 1248+ static_cast<float *>(weight_in_) + i * col_align * weight_shape[1], hidden_size, 1249+ weight_shape[1]); 1250+ } 1251+ } 1252+ auto weight_hidden_pack_size = static_cast<size_t>(col_align * hidden_size) * sizeof(float); 1253+ is_packed = false; 1254+ weight_hidden_ = lite::PackWeightManager::GetInstance()->GetPackData( 1255+ in_tensors_[THIRD_INPUT]->data(), static_cast<size_t>(weight_hidden_pack_size * C3NUM), &is_packed); 1256+ MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-hidden failed."); 1257+ if (!is_packed) { 1258+ auto weight_hidden_src = static_cast<const float *>(in_tensors_[THIRD_INPUT]->data()); 1259+ for (int i = 0; i < C3NUM; ++i) { 1260+ RowMajor2Col8Major(weight_hidden_src + i * hidden_size * hidden_size, 1261+ static_cast<float *>(weight_hidden_) + i * col_align * hidden_size, hidden_size, 1262+ hidden_size); 1263+ } 1264+ } 1265+ auto bias_pack_size = static_cast<size_t>(col_align) * sizeof(float); 1266+ auto bias = reinterpret_cast<float *>(malloc(bias_pack_size * C6NUM)); 1267+ if (bias == nullptr) { 1268+ MS_LOG(ERROR) << "malloc for packing bias failed."; 1269+ return lite::RET_NULL_PTR; 1270+ } 1271+ (void)memset(bias, 0, bias_pack_size * C6NUM); 1272+ bias_in_ = bias; 1273+ bias_hidden_ = bias + col_align * C3NUM; 1274+ auto bias_in_src = static_cast<const float *>(in_tensors_[FOURTH_INPUT]->data()); 1275+ for (int i = 0; i < C3NUM; ++i) { 1276+ (void)memcpy(bias + i * col_align, bias_in_src + i * hidden_size, hidden_size * sizeof(float)); 1277+ } 1278+ auto bias_hidden_src = static_cast<const float *>(in_tensors_[FIFTH_INPUT]->data()); 1279+ for (int i = 0; i < C3NUM; ++i) { 1280+ (void)memcpy(bias + (C3NUM + i) * col_align, bias_hidden_src + i * hidden_size, hidden_size * sizeof(float)); 1281+ } 1282+ if (in_tensors_[SIXTH_INPUT]->IsConst()) { 1283+ init_h_ = malloc(in_tensors_[SIXTH_INPUT]->Size()); 1284+ MS_CHECK_TRUE_MSG(init_h_ != nullptr, lite::RET_NULL_PTR, "malloc for init-h failed."); 1285+ (void)memcpy(init_h_, in_tensors_[SIXTH_INPUT]->data(), in_tensors_[SIXTH_INPUT]->Size()); 1286+ } 1287+ return RET_OK; 1288+} 1289+
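+// Layout produced by InitWeightAndBias above: weight_in_ and weight_hidden_
+// each hold three gate blocks packed by RowMajor2Col8Major, every block padded
+// to col_align = UP_ROUND(hidden_size, col_tile_) rows, so block i starts at
+// the fixed offset i * col_align * block_width. The bias buffer is one
+// zero-filled allocation of 6 * col_align floats: stripes 0..2 (bias_in_)
+// carry the input-gate biases, stripes 3..5 (bias_hidden_) the hidden-gate
+// biases, each stripe holding hidden_size values followed by zero padding.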
1290+int CustomGruCPUKernel::ReSize() { 1291+ auto in_shape = in_tensors_.front()->shape(); 1292+ if (in_shape.size() != C3NUM) { 1293+ MS_LOG(ERROR) << "Built-in CustomGru's first-input must be 3D." << name_; 1294+ return RET_ERROR; 1295+ } 1296+ auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_); 1297+ param->num_step = in_shape[0]; 1298+ param->batch_size = in_shape[1]; 1299+ if (in_shape.back() != param->input_size) { 1300+ MS_LOG(ERROR) << "Built-in CustomGru's first-input doesn't match its weight." << name_; 1301+ return RET_ERROR; 1302+ } 1303+ return RET_OK; 1304+} 1305+ 1306+void CustomGruCPUKernel::MallocRunBuffer(size_t data_type_size) { 1307+ if (run_buffer_ != nullptr) { 1308+ return; 1309+ } 1310+ auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_); 1311+ auto row_align = UP_ROUND(param->batch_size, row_tile_); 1312+ auto run_buffer_size = 1313+ (row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C6NUM) * 1314+ data_type_size; 1315+ if (ms_context_->allocator != nullptr) { 1316+ run_buffer_ = ms_context_->allocator->Malloc(run_buffer_size); 1317+ } else { 1318+ run_buffer_ = malloc(run_buffer_size); 1319+ } 1320+} 1321+ 1322+int CustomGruCPUKernel::Run() { 1323+ auto input = reinterpret_cast<float *>(in_tensors_[FIRST_INPUT]->data()); 1324+ if (input == nullptr) { 1325+ MS_LOG(ERROR) << "Built-in CustomGru's first-input is nullptr." << name_; 1326+ return lite::RET_NULL_PTR; 1327+ } 1328+ if (!in_tensors_[SIXTH_INPUT]->IsConst()) { 1329+ init_h_ = in_tensors_[SIXTH_INPUT]->data(); 1330+ } 1331+ if (init_h_ == nullptr) { 1332+ MS_LOG(ERROR) << "Built-in CustomGru's sixth-input is nullptr." << name_; 1333+ return lite::RET_NULL_PTR; 1334+ } 1335+ auto output = reinterpret_cast<float *>(out_tensors_.front()->data()); 1336+ if (output == nullptr) { 1337+ MS_LOG(ERROR) << "Built-in CustomGru's output is nullptr." << name_; 1338+ return lite::RET_NULL_PTR; 1339+ } 1340+ MallocRunBuffer(sizeof(float)); 1341+ if (run_buffer_ == nullptr) { 1342+ MS_LOG(ERROR) << "malloc running buffer failed."
<< name_; 1343+ return lite::RET_NULL_PTR; 1344+ } 1345+ auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_); 1346+ auto row_align = UP_ROUND(param->batch_size, row_tile_); 1347+ auto run_buffer = reinterpret_cast<float *>(run_buffer_); 1348+ float *buffer[C4NUM] = { 1349+ run_buffer, run_buffer + row_align * param->input_size, 1350+ run_buffer + row_align * param->input_size + param->batch_size * param->hidden_size * C3NUM, 1351+ run_buffer + row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C3NUM}; 1352+ CustomGru(output, input, static_cast<float *>(weight_in_), static_cast<float *>(weight_hidden_), 1353+ static_cast<float *>(bias_in_), static_cast<float *>(bias_hidden_), static_cast<float *>(init_h_), buffer, 1354+ param); 1355+ if (ms_context_->allocator != nullptr) { 1356+ ms_context_->allocator->Free(run_buffer_); 1357+ } else { 1358+ free(run_buffer_); 1359+ } 1360+ run_buffer_ = nullptr; 1361+ return RET_OK; 1362+} 1363+ 1364+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomGru, LiteKernelCreator<CustomGruCPUKernel>) 1365+} // namespace mindspore::kernel 1366+#endif 1367diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h 1368new file mode 100644 1369index 00000000..e70e408f 1370--- /dev/null 1371+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h 1372@@ -0,0 +1,51 @@ 1373+/** 1374+ * Copyright 2023 Huawei Technologies Co., Ltd 1375+ * 1376+ * Licensed under the Apache License, Version 2.0 (the "License"); 1377+ * you may not use this file except in compliance with the License. 1378+ * You may obtain a copy of the License at 1379+ * 1380+ * http://www.apache.org/licenses/LICENSE-2.0 1381+ * 1382+ * Unless required by applicable law or agreed to in writing, software 1383+ * distributed under the License is distributed on an "AS IS" BASIS, 1384+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1385+ * See the License for the specific language governing permissions and 1386+ * limitations under the License. 
1387+ */ 1388+ 1389+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_ 1390+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_ 1391+#ifdef ENABLE_ARM64 1392+#include <vector> 1393+#include "src/runtime/lite_kernel.h" 1394+ 1395+namespace mindspore::kernel { 1396+class CustomGruCPUKernel : public LiteKernel { 1397+ public: 1398+ CustomGruCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 1399+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 1400+ : LiteKernel(parameter, inputs, outputs, ctx) {} 1401+ ~CustomGruCPUKernel() override; 1402+ int Prepare() override; 1403+ int ReSize() override; 1404+ int Run() override; 1405+ 1406+ private: 1407+ int InitParameter(); 1408+ 1409+ protected: 1410+ void MallocRunBuffer(size_t data_type_size); 1411+ virtual int InitWeightAndBias(); 1412+ int row_tile_{C12NUM}; 1413+ int col_tile_{C8NUM}; 1414+ void *weight_in_{nullptr}; 1415+ void *weight_hidden_{nullptr}; 1416+ void *bias_in_{nullptr}; 1417+ void *bias_hidden_{nullptr}; 1418+ void *init_h_{nullptr}; 1419+ void *run_buffer_{nullptr}; 1420+}; 1421+} // namespace mindspore::kernel 1422+#endif 1423+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_ 1424diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc 1425index ef03171a..0ff780c7 100644 1426--- a/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc 1427+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc 1428@@ -107,7 +107,7 @@ kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor 1429 << conv_param->input_channel_; 1430 return nullptr; 1431 } 1432- auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, true, kNumberTypeInt8); 1433+ auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, true, kNumberTypeInt8, ctx); 1434 if (group_conv_creator == nullptr) { 1435 MS_LOG(ERROR) << "group_conv_creator is nullptr."; 1436 return nullptr; 1437diff --git a/mindspore/lite/src/runtime/lite_session.h b/mindspore/lite/src/runtime/lite_session.h 1438index 2fdb1eb7..d5a672bb 100644 1439--- a/mindspore/lite/src/runtime/lite_session.h 1440+++ b/mindspore/lite/src/runtime/lite_session.h 1441@@ -106,6 +106,12 @@ class MS_API LiteSession { 1442 std::vector<std::string> out_put_tensor_name = {}) { 1443 return mindspore::lite::RET_ERROR; 1444 } 1445+ virtual int ExportWeightsCollaborateWithMicro(const std::string &file_name, 1446+ lite::ModelType model_type = lite::MT_TRAIN, 1447+ lite::FormatType = lite::FT_FLATBUFFERS, bool enable_fp16 = false, 1448+ const std::vector<std::string> &changeable_weights_name = {}) { 1449+ return mindspore::lite::RET_ERROR; 1450+ } 1451 virtual int UpdateWeights(std::vector<lite::Tensor *> new_weights) { return mindspore::lite::RET_ERROR; } 1452 virtual std::vector<lite::Tensor *> GetFeatureMaps() const { 1453 std::vector<lite::Tensor *> features; 1454diff --git a/mindspore/lite/src/train/graph_fusion.cc b/mindspore/lite/src/train/graph_fusion.cc 1455index 1af44e45..03f5675e 100644 1456--- a/mindspore/lite/src/train/graph_fusion.cc 1457+++ b/mindspore/lite/src/train/graph_fusion.cc 1458@@ -22,6 +22,7 @@ 1459 #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" 1460 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" 1461 #include
"tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" 1462+#include "src/train/optimizer/fusion/gru_fusion_pass.h" 1463 #include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h" 1464 #include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h" 1465 1466@@ -41,6 +42,12 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) { 1467 MS_LOG(ERROR) << "graph is nullptr."; 1468 return RET_ERROR; 1469 } 1470+ auto gru_fusion = std::make_shared<GruFusionPass>(); 1471+ MS_CHECK_TRUE_MSG(gru_fusion != nullptr, RET_NULL_PTR, "Create GruFusion object failed."); 1472+ if (gru_fusion->Run(graph) != RET_OK) { 1473+ MS_LOG(ERROR) << "Do GruFusion failed."; 1474+ return RET_ERROR; 1475+ } 1476 auto old_nodes = GetGraphNodes(*graph); 1477 Optimizer fusion_optimizer; 1478 fusion_optimizer.AddPass(new (std::nothrow) ReshapeGatherReshapeFusionPass()); 1479diff --git a/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc 1480new file mode 100644 1481index 00000000..435686e5 1482--- /dev/null 1483+++ b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc 1484@@ -0,0 +1,809 @@ 1485+/** 1486+ * Copyright 2023 Huawei Technologies Co., Ltd 1487+ * 1488+ * Licensed under the Apache License, Version 2.0 (the "License"); 1489+ * you may not use this file except in compliance with the License. 1490+ * You may obtain a copy of the License at 1491+ * 1492+ * http://www.apache.org/licenses/LICENSE-2.0 1493+ * 1494+ * Unless required by applicable law or agreed to in writing, software 1495+ * distributed under the License is distributed on an "AS IS" BASIS, 1496+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1497+ * See the License for the specific language governing permissions and 1498+ * limitations under the License. 
1499+ */ 1500+ 1501+#include "src/train/optimizer/fusion/gru_fusion_pass.h" 1502+#include <algorithm> 1503+#include <map> 1504+#include <memory> 1505+#include <set> 1506+#include <string> 1507+#include <utility> 1508+#include <vector> 1509+#include "src/common/log_adapter.h" 1510+#include "include/errorcode.h" 1511+#include "nnacl/op_base.h" 1512+ 1513+namespace mindspore { 1514+namespace lite { 1515+namespace { 1516+constexpr size_t kSplitOutSize = 3; 1517+constexpr uint32_t kAdd0 = 0; 1518+constexpr uint32_t kAdd1 = 1; 1519+constexpr uint32_t kAdd2 = 2; 1520+constexpr uint32_t kAdd3 = 3; 1521+constexpr uint32_t kAdd4 = 4; 1522+constexpr uint32_t kAdd5 = 5; 1523+constexpr uint32_t kSub = 6; 1524+constexpr uint32_t kMul0 = 7; 1525+constexpr uint32_t kMul1 = 8; 1526+constexpr uint32_t kTanh = 9; 1527+constexpr uint32_t kSigmoid0 = 10; 1528+constexpr uint32_t kSigmoid1 = 11; 1529+constexpr uint32_t kSplit0 = 12; 1530+constexpr uint32_t kSplit1 = 13; 1531+constexpr uint32_t kMatmul0 = 14; 1532+constexpr uint32_t kMatmul1 = 15; 1533+constexpr uint32_t kInputH = 16; 1534+constexpr uint32_t kInputI = 17; 1535+constexpr auto kCustomGRU = "CustomGRU"; 1536+ 1537+bool CheckCommon(schema::MetaGraphT *graph, uint32_t node_index, schema::PrimitiveType type, size_t in_nums, 1538+ size_t out_nums) { 1539+ if (graph->nodes.size() <= node_index) { 1540+ return false; 1541+ } 1542+ const auto &node = graph->nodes[node_index]; 1543+ if (node == nullptr || node->primitive == nullptr) { 1544+ return false; 1545+ } 1546+ const auto &value = node->primitive->value; 1547+ if (value.type != type) { 1548+ return false; 1549+ } 1550+ if (value.value == nullptr) { 1551+ return false; 1552+ } 1553+ if ((in_nums > 0 && node->inputIndex.size() != in_nums) || node->outputIndex.size() != out_nums) { 1554+ return false; 1555+ } 1556+ return std::all_of(node->inputIndex.begin(), node->inputIndex.end(), 1557+ [&graph](uint32_t tensor_index) { return graph->allTensors.size() > tensor_index; }) && 1558+ std::all_of(node->outputIndex.begin(), node->outputIndex.end(), 1559+ [&graph](uint32_t tensor_index) { return graph->allTensors.size() > tensor_index; }); 1560+} 1561+ 1562+template <schema::PrimitiveType T, typename P> 1563+bool CheckArithmetic(schema::MetaGraphT *graph, uint32_t node_index) { 1564+ if (!CheckCommon(graph, node_index, T, kInputSize1, 1)) { 1565+ return false; 1566+ } 1567+ const auto &node = graph->nodes[node_index]; 1568+ const auto &value = node->primitive->value; 1569+ const auto add_attr = static_cast<const P *>(value.value); 1570+ if (add_attr->activation_type != schema::ActivationType_NO_ACTIVATION) { 1571+ return false; 1572+ } 1573+ auto tensor_indexes = node->inputIndex; 1574+ (void)tensor_indexes.insert(tensor_indexes.end(), node->outputIndex.begin(), node->outputIndex.end()); 1575+ std::vector<int> shape; 1576+ for (size_t i = 0; i < tensor_indexes.size(); ++i) { 1577+ if (i == 0) { 1578+ shape = graph->allTensors[tensor_indexes[i]]->dims; 1579+ continue; 1580+ } 1581+ if (graph->allTensors[tensor_indexes[i]]->dims != shape) { 1582+ return false; 1583+ } 1584+ } 1585+ return true; 1586+} 1587+ 1588+template <schema::ActivationType T> 1589+bool CheckActivation(schema::MetaGraphT *graph, uint32_t node_index) { 1590+ if (!CheckCommon(graph, node_index, schema::PrimitiveType_Activation, 1, 1)) { 1591+ return false; 1592+ } 1593+ const auto &value = graph->nodes[node_index]->primitive->value; 1594+ const auto add_attr = static_cast<const schema::ActivationT *>(value.value); 1595+ if 
(add_attr->activation_type != T) { 1596+ return false; 1597+ } 1598+ return true; 1599+} 1600+ 1601+bool CheckBiasAdd(schema::MetaGraphT *graph, uint32_t node_index) { 1602+ if (!CheckCommon(graph, node_index, schema::PrimitiveType_AddFusion, kInputSize1, 1) && 1603+ !CheckCommon(graph, node_index, schema::PrimitiveType_BiasAdd, kInputSize1, 1)) { 1604+ return false; 1605+ } 1606+ const auto &node = graph->nodes[node_index]; 1607+ const auto &value = node->primitive->value; 1608+ if (value.type == schema::PrimitiveType_AddFusion) { 1609+ const auto add_attr = static_cast<const schema::AddFusionT *>(value.value); 1610+ if (add_attr->activation_type != schema::ActivationType_NO_ACTIVATION) { 1611+ return false; 1612+ } 1613+ } 1614+ auto in_shape0 = graph->allTensors[node->inputIndex[0]]->dims; 1615+ auto in_shape1 = graph->allTensors[node->inputIndex[1]]->dims; 1616+ if (in_shape1.size() != 1 || in_shape0.empty() || in_shape0.back() != in_shape1.back()) { 1617+ return false; 1618+ } 1619+ return true; 1620+} 1621+ 1622+bool CheckMatmul(schema::MetaGraphT *graph, uint32_t node_index) { 1623+ if (!CheckCommon(graph, node_index, schema::PrimitiveType_MatMulFusion, kInputSize1, 1)) { 1624+ return false; 1625+ } 1626+ const auto &node = graph->nodes[node_index]; 1627+ const auto &value = node->primitive->value; 1628+ const auto matmul_attr = static_cast<const schema::MatMulFusionT *>(value.value); 1629+ if (matmul_attr->activation_type != schema::ActivationType_NO_ACTIVATION) { 1630+ return false; 1631+ } 1632+ auto out_shape = graph->allTensors[node->outputIndex.front()]->dims; 1633+ return out_shape.size() == kInputSize1; 1634+} 1635+ 1636+bool CheckSplit(schema::MetaGraphT *graph, uint32_t node_index) { 1637+ if (!CheckCommon(graph, node_index, schema::PrimitiveType_Split, 1, kSplitOutSize)) { 1638+ return false; 1639+ } 1640+ const auto &node = graph->nodes[node_index]; 1641+ if (node->inputIndex.size() != 1 || node->outputIndex.size() != kSplitOutSize) { 1642+ return false; 1643+ } 1644+ auto in_shape0 = graph->allTensors[node->inputIndex[0]]->dims; 1645+ auto out_shape0 = graph->allTensors[node->outputIndex[0]]->dims; 1646+ auto out_shape1 = graph->allTensors[node->outputIndex[1]]->dims; 1647+ auto out_shape2 = graph->allTensors[node->outputIndex[kInputSize1]]->dims; 1648+ if (out_shape0 != out_shape1 || out_shape0 != out_shape2) { 1649+ return false; 1650+ } 1651+ if (in_shape0.empty() || out_shape0.empty()) { 1652+ return false; 1653+ } 1654+ if (in_shape0.back() != (out_shape0.back() + out_shape1.back() + out_shape2.back())) { 1655+ return false; 1656+ } 1657+ return true; 1658+} 1659+ 1660+bool CheckStack(schema::MetaGraphT *graph, uint32_t node_index) { 1661+ if (!CheckCommon(graph, node_index, schema::PrimitiveType_Stack, 0, 1)) { 1662+ return false; 1663+ } 1664+ const auto &node = graph->nodes[node_index]; 1665+ const auto &value = node->primitive->value; 1666+ const auto stack_attr = static_cast<const schema::StackT *>(value.value); 1667+ auto out_shape = graph->allTensors[node->outputIndex.front()]->dims; 1668+ if (out_shape.empty()) { 1669+ return false; 1670+ } 1671+ auto axis = stack_attr->axis; 1672+ if (axis < 0) { 1673+ axis += static_cast<int64_t>(out_shape.size()); 1674+ } 1675+ return axis == 0; 1676+} 1677+ 1678+bool CheckSqueeze(schema::MetaGraphT *graph, uint32_t node_index) { 1679+ if (!CheckCommon(graph, node_index, schema::PrimitiveType_Squeeze, 0, 1)) { 1680+ return false; 1681+ } 1682+ const auto &node = graph->nodes[node_index]; 1683+ if (node->inputIndex.size() 
!= 1 && node->inputIndex.size() != kInputSize1) { 1684+ return false; 1685+ } 1686+ int axis = 0; 1687+ if (node->inputIndex.size() == kInputSize1) { 1688+ const auto &data = graph->allTensors[node->inputIndex[1]]->data; 1689+ if (data.size() != sizeof(int)) { 1690+ return false; 1691+ } 1692+ axis = *(reinterpret_cast<const int *>(data.data())); 1693+ } else { 1694+ const auto &value = node->primitive->value; 1695+ const auto squeeze_attr = static_cast<const schema::SqueezeT *>(value.value); 1696+ if (squeeze_attr->axis.size() != 1) { 1697+ return false; 1698+ } 1699+ axis = squeeze_attr->axis.front(); 1700+ } 1701+ auto in_shape = graph->allTensors[node->inputIndex[0]]->dims; 1702+ if (in_shape.empty()) { 1703+ return false; 1704+ } 1705+ if (axis < 0) { 1706+ axis += static_cast<int>(in_shape.size()); 1707+ } 1708+ return axis == 0; 1709+} 1710+ 1711+std::vector<int> GetStridedSlicePoints(const schema::TensorT *tensor, int64_t mask) { 1712+ if (tensor->data.empty()) { 1713+ return {}; 1714+ } 1715+ auto origin_data = reinterpret_cast<const int *>(tensor->data.data()); 1716+ size_t num = tensor->data.size() / sizeof(int); 1717+ std::vector<int> data; 1718+ for (size_t i = 0; i < num; ++i) { 1719+ bool ineffective = (mask & (1 << i)); 1720+ int cur_point = ineffective ? 0 : origin_data[i]; 1721+ data.push_back(cur_point); 1722+ } 1723+ return data; 1724+} 1725+ 1726+bool CheckStridedSlice(schema::MetaGraphT *graph, uint32_t node_index, int batch_position) { 1727+ if (!CheckCommon(graph, node_index, schema::PrimitiveType_StridedSlice, C4NUM, 1)) { 1728+ return false; 1729+ } 1730+ const auto &node = graph->nodes[node_index]; 1731+ const auto &step_tensor = graph->allTensors[node->inputIndex.back()]; 1732+ if (!step_tensor->data.empty()) { 1733+ const auto data = reinterpret_cast<int *>(step_tensor->data.data()); 1734+ auto size = step_tensor->data.size() / sizeof(int); 1735+ if (std::any_of(data, data + size, [](int val) { return val != 1; })) { 1736+ return false; 1737+ } 1738+ } 1739+ auto in_shape = graph->allTensors[node->inputIndex.front()]->dims; 1740+ auto out_shape = graph->allTensors[node->outputIndex.back()]->dims; 1741+ if (in_shape.size() != out_shape.size() || in_shape.empty()) { 1742+ return false; 1743+ } 1744+ for (size_t i = 1; i < in_shape.size(); ++i) { 1745+ if (in_shape[i] != out_shape[i]) { 1746+ return false; 1747+ } 1748+ } 1749+ const auto &value = node->primitive->value; 1750+ const auto strided_slice_attr = static_cast<const schema::StridedSliceT *>(value.value); 1751+ if (strided_slice_attr->ellipsis_mask != 0 || strided_slice_attr->new_axis_mask != 0 || 1752+ strided_slice_attr->shrink_axis_mask != 0) { 1753+ return false; 1754+ } 1755+ auto begin = GetStridedSlicePoints(graph->allTensors[node->inputIndex[1]].get(), strided_slice_attr->begin_mask); 1756+ if (begin.empty()) { 1757+ return false; 1758+ } 1759+ return begin.front() == batch_position; 1760+} 1761+ 1762+bool CheckGruCell(schema::MetaGraphT *graph, uint32_t node_index) { 1763+ if (!CheckCommon(graph, node_index, schema::PrimitiveType_Custom, C6NUM, 1)) { 1764+ return false; 1765+ } 1766+ const auto &node = graph->nodes[node_index]; 1767+ const auto &value = node->primitive->value; 1768+ const auto gru_attr = static_cast<const schema::CustomT *>(value.value); 1769+ return gru_attr->type == kCustomGRU; 1770+} 1771+ 1772+std::unique_ptr<schema::CustomT> CreateCustom() { 1773+ auto ConvertToAttr = [](const std::string &key, const std::vector<uint8_t> &value) { 1774+ auto attr = 
std::make_unique<schema::AttributeT>(); 1775+ attr->name = key; 1776+ attr->data = value; 1777+ return attr; 1778+ }; 1779+ auto attrs = std::make_unique<schema::CustomT>(); 1780+ MS_CHECK_TRUE_MSG(attrs != nullptr, nullptr, "Create CustomT failed."); 1781+ attrs->type = kCustomGRU; 1782+ std::vector<uint8_t> transpose_a{false}; 1783+ std::vector<uint8_t> transpose_b{true}; 1784+ std::vector<uint8_t> built_in{true}; 1785+ 1786+ attrs->attr.push_back(ConvertToAttr("transpose_a", transpose_a)); 1787+ attrs->attr.push_back(ConvertToAttr("transpose_b", transpose_b)); 1788+ attrs->attr.push_back(ConvertToAttr("builtin", built_in)); 1789+ return attrs; 1790+} 1791+ 1792+struct InNodeInfo { 1793+ int node_index; 1794+ std::vector<uint32_t> in_indexes; 1795+}; 1796+ 1797+struct OutNodeInfo { 1798+ int node_index; 1799+ uint32_t out_index; 1800+}; 1801+ 1802+struct Cmp { 1803+ bool operator()(uint32_t left, uint32_t right) const { return left > right; } 1804+}; 1805+} // namespace 1806+ 1807+class LinkInfoManager { 1808+ public: 1809+ explicit LinkInfoManager(schema::MetaGraphT *graph) : graph_{graph} { 1810+ auto &all_nodes = graph->nodes; 1811+ for (int node_index = 0; node_index < static_cast<int>(all_nodes.size()); ++node_index) { 1812+ auto in_indexes = all_nodes[node_index]->inputIndex; 1813+ for (uint32_t index = 0; index < static_cast<uint32_t>(in_indexes.size()); ++index) { 1814+ if (link_info_manager_.find(in_indexes[index]) == link_info_manager_.end()) { 1815+ link_info_manager_[in_indexes[index]] = std::make_pair(std::vector<InNodeInfo>{}, OutNodeInfo{-1, 0}); 1816+ } 1817+ auto &in_infos = link_info_manager_[in_indexes[index]].first; 1818+ auto iter = in_infos.begin(); 1819+ for (; iter != in_infos.end(); ++iter) { 1820+ if (iter->node_index == node_index) { 1821+ break; 1822+ } 1823+ } 1824+ if (iter != in_infos.end()) { 1825+ iter->in_indexes.push_back(index); 1826+ } else { 1827+ in_infos.push_back({node_index, {index}}); 1828+ } 1829+ } 1830+ 1831+ auto out_indexes = all_nodes[node_index]->outputIndex; 1832+ for (uint32_t index = 0; index < static_cast<uint32_t>(out_indexes.size()); ++index) { 1833+ link_info_manager_[out_indexes[index]].second = OutNodeInfo{node_index, index}; 1834+ } 1835+ } 1836+ } 1837+ 1838+ const auto &GetLinkInfos() const { return link_info_manager_; } 1839+ 1840+ void Replace(uint32_t node_index, std::unique_ptr<CNodeT> node) { graph_->nodes[node_index].swap(node); } 1841+ 1842+ void AddDeleteNodes(const std::set<uint32_t> &node_indexes) { 1843+ delete_nodes_.insert(node_indexes.begin(), node_indexes.end()); 1844+ } 1845+ 1846+ void UpdateMetaGraph() { 1847+ auto &main_graph = graph_->subGraph.front(); 1848+ for (auto node_index : delete_nodes_) { 1849+ graph_->nodes.erase(graph_->nodes.begin() + node_index); 1850+ } 1851+ main_graph->nodeIndices.clear(); 1852+ for (uint32_t index = 0; index < static_cast<uint32_t>(graph_->nodes.size()); ++index) { 1853+ main_graph->nodeIndices.push_back(index); 1854+ } 1855+ std::map<uint32_t, uint32_t> tensor_maps; 1856+ BuildTensorMap(&tensor_maps); 1857+ auto UpdateTensorIndex = [&tensor_maps](std::vector<uint32_t> *origin) { 1858+ auto origin_indexes = *origin; 1859+ origin->clear(); 1860+ (void)std::transform(origin_indexes.begin(), origin_indexes.end(), std::back_inserter(*origin), 1861+ [&tensor_maps](uint32_t origin_index) { return tensor_maps[origin_index]; }); 1862+ }; 1863+ UpdateTensorIndex(&graph_->inputIndex); 1864+ for (auto &node : graph_->nodes) { 1865+ UpdateTensorIndex(&node->inputIndex); 1866+
UpdateTensorIndex(&node->outputIndex); 1867+ } 1868+ UpdateTensorIndex(&graph_->outputIndex); 1869+ main_graph->inputIndices = graph_->inputIndex; 1870+ main_graph->outputIndices = graph_->outputIndex; 1871+ main_graph->tensorIndices.clear(); 1872+ for (uint32_t index = 0; index < static_cast<uint32_t>(tensor_maps.size()); ++index) { 1873+ main_graph->tensorIndices.push_back(index); 1874+ } 1875+ std::vector<std::unique_ptr<TensorT>> tensors; 1876+ graph_->allTensors.swap(tensors); 1877+ graph_->allTensors.resize(tensor_maps.size()); 1878+ for (auto &tensor_map : tensor_maps) { 1879+ graph_->allTensors[tensor_map.second].swap(tensors[tensor_map.first]); 1880+ } 1881+ } 1882+ 1883+ private: 1884+ void BuildTensorMap(std::map<uint32_t, uint32_t> *tensor_maps) { 1885+ uint32_t new_index = 0; 1886+ auto InsertElements = [tensor_maps, &new_index](const std::vector<uint32_t> &indexes) mutable { 1887+ for (auto index : indexes) { 1888+ if (tensor_maps->find(index) != tensor_maps->end()) { 1889+ continue; 1890+ } 1891+ (*tensor_maps)[index] = new_index++; 1892+ } 1893+ }; 1894+ InsertElements(graph_->inputIndex); 1895+ for (auto &node : graph_->nodes) { 1896+ InsertElements(node->inputIndex); 1897+ InsertElements(node->outputIndex); 1898+ } 1899+ InsertElements(graph_->outputIndex); 1900+ } 1901+ 1902+ schema::MetaGraphT *graph_{nullptr}; 1903+ std::set<uint32_t, Cmp> delete_nodes_; 1904+ // tensor_index, <in_node_infos, out_node_info> 1905+ std::map<uint32_t, std::pair<std::vector<InNodeInfo>, OutNodeInfo>> link_info_manager_; 1906+}; 1907+ 1908+class GruCellFusion { 1909+ public: 1910+ GruCellFusion() = default; 1911+ ~GruCellFusion() = default; 1912+ STATUS Run(schema::MetaGraphT *graph) { 1913+ MS_ASSERT(graph != nullptr); 1914+ MS_ASSERT(graph->subGraph.size() == 1); 1915+ link_info_manager_ = std::make_shared<LinkInfoManager>(graph); 1916+ graph_ = graph; 1917+ DefinePattern(); 1918+ for (uint32_t node_index = 0; node_index < static_cast<uint32_t>(graph->nodes.size()); ++node_index) { 1919+ if (!MatchPattern(node_index)) { 1920+ continue; 1921+ } 1922+ if (CreateCustomGruCell() != RET_OK) { 1923+ MS_LOG(ERROR) << "Create Custom-Gru failed."; 1924+ return RET_ERROR; 1925+ } 1926+ } 1927+ link_info_manager_->UpdateMetaGraph(); 1928+ return RET_OK; 1929+ } 1930+ 1931+ private: 1932+ struct NodeInfo { 1933+ struct InTensorInfo { 1934+ bool is_const{false}; 1935+ uint32_t node_index_{0}; 1936+ uint32_t tensor_index_{0}; 1937+ }; 1938+ struct OutTensorInfo { 1939+ uint32_t node_index_{0}; 1940+ uint32_t tensor_index_{0}; 1941+ }; 1942+ bool (*checker)(schema::MetaGraphT *graph, uint32_t node_index); 1943+ std::vector<InTensorInfo> in_infos; 1944+ std::vector<OutTensorInfo> out_infos; 1945+ }; 1946+ 1947+ void DefinePattern() { 1948+ int match_order = 0; 1949+ pattern_[{match_order++, kAdd0}] = { 1950+ CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>, {{false, kTanh, 0}, {false, kMul0, 0}}, {}}; 1951+ pattern_[{match_order++, kTanh}] = { 1952+ CheckActivation<schema::ActivationType_TANH>, {{false, kAdd1, 0}}, {{kSub, 1}, {kAdd0, 0}}}; 1953+ pattern_[{match_order++, kMul0}] = {CheckArithmetic<schema::PrimitiveType_MulFusion, schema::MulFusionT>, 1954+ {{false, kSigmoid0, 0}, {false, kSub, 0}}, 1955+ {{kAdd0, 1}}}; 1956+ pattern_[{match_order++, kAdd1}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>, 1957+ {{false, kSplit0, 2}, {false, kMul1, 0}}, 1958+ {{kTanh, 0}}}; 1959+ pattern_[{match_order++, kSub}] =
{CheckArithmetic<schema::PrimitiveType_SubFusion, schema::SubFusionT>, 1960+ {{false, kInputH, 0}, {false, kTanh, 0}}, 1961+ {{kMul0, 1}}}; 1962+ pattern_[{match_order++, kSigmoid0}] = { 1963+ CheckActivation<schema::ActivationType_SIGMOID>, {{false, kAdd2, 0}}, {{kMul0, 0}}}; 1964+ pattern_[{match_order++, kSplit0}] = {CheckSplit, {{false, kAdd3, 0}}, {{kAdd4, 0}, {kAdd2, 0}, {kAdd1, 0}}}; 1965+ pattern_[{match_order++, kMul1}] = {CheckArithmetic<schema::PrimitiveType_MulFusion, schema::MulFusionT>, 1966+ {{false, kSigmoid1, 0}, {false, kSplit1, 2}}, 1967+ {{kAdd1, 1}}}; 1968+ pattern_[{match_order++, kAdd2}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>, 1969+ {{false, kSplit0, 1}, {false, kSplit1, 1}}, 1970+ {{kSigmoid0, 0}}}; 1971+ pattern_[{match_order++, kSigmoid1}] = { 1972+ CheckActivation<schema::ActivationType_SIGMOID>, {{false, kAdd4, 0}}, {{kMul1, 0}}}; 1973+ pattern_[{match_order++, kAdd3}] = {CheckBiasAdd, {{false, kMatmul0, 0}, {true}}, {{kSplit0, 0}}}; 1974+ pattern_[{match_order++, kSplit1}] = {CheckSplit, {{false, kAdd5, 0}}, {{kAdd4, 1}, {kAdd2, 1}, {kMul1, 1}}}; 1975+ pattern_[{match_order++, kAdd4}] = {CheckArithmetic<schema::PrimitiveType_AddFusion, schema::AddFusionT>, 1976+ {{false, kSplit0, 0}, {false, kSplit1, 0}}, 1977+ {{kSigmoid1, 0}}}; 1978+ pattern_[{match_order++, kAdd5}] = {CheckBiasAdd, {{false, kMatmul1, 0}, {true}}, {{kSplit1, 0}}}; 1979+ pattern_[{match_order++, kMatmul0}] = {CheckMatmul, {{false, kInputI, 0}, {true}}, {{kAdd3, 0}}}; 1980+ pattern_[{match_order++, kMatmul1}] = {CheckMatmul, {{false, kInputH, 0}, {true}}, {{kAdd5, 0}}}; 1981+ } 1982+ 1983+ bool FillRealPattern(uint32_t node_index, std::map<uint32_t, NodeInfo> *real_pattern) { 1984+ const auto &link_infos = link_info_manager_->GetLinkInfos(); 1985+ if (real_pattern->find(node_index) != real_pattern->end()) { 1986+ return false; 1987+ } 1988+ real_pattern->insert({node_index, {nullptr}}); 1989+ auto in_tensor_indexes = graph_->nodes[node_index]->inputIndex; 1990+ for (auto tensor_index : in_tensor_indexes) { 1991+ if (link_infos.find(tensor_index) == link_infos.end()) { 1992+ return false; 1993+ } 1994+ const auto &tensor_out_info = link_infos.at(tensor_index).second; 1995+ if (tensor_out_info.node_index < 0) { 1996+ real_pattern->at(node_index).in_infos.push_back({true}); 1997+ } else { 1998+ real_pattern->at(node_index) 1999+ .in_infos.push_back({false, static_cast<uint32_t>(tensor_out_info.node_index), tensor_out_info.out_index}); 2000+ } 2001+ } 2002+ auto out_tensor_indexes = graph_->nodes[node_index]->outputIndex; 2003+ for (auto tensor_index : out_tensor_indexes) { 2004+ if (link_infos.find(tensor_index) == link_infos.end()) { 2005+ return false; 2006+ } 2007+ const auto &in_tensor_out_info = link_infos.at(tensor_index).first; 2008+ for (const auto &in_node_info : in_tensor_out_info) { 2009+ for (auto index : in_node_info.in_indexes) { 2010+ real_pattern->at(node_index).out_infos.push_back({static_cast<uint32_t>(in_node_info.node_index), index}); 2011+ } 2012+ } 2013+ } 2014+ return true; 2015+ } 2016+ 2017+ bool CheckPattern(const std::map<uint32_t, NodeInfo> &real_pattern, 2018+ const std::pair<int, uint32_t> &pattern_node_index) { 2019+ const auto &real_in_infos = real_pattern.at(real_node_map_.at(pattern_node_index.second)).in_infos; 2020+ const auto &virtual_in_infos = pattern_.at(pattern_node_index).in_infos; 2021+ if (real_in_infos.size() != virtual_in_infos.size()) { 2022+ return false; 2023+ } 2024+ for (size_t i = 0; i < virtual_in_infos.size(); 
++i) { 2025+ if (virtual_in_infos[i].is_const) { 2026+ if (!real_in_infos[i].is_const) { 2027+ return false; 2028+ } 2029+ continue; 2030+ } 2031+ if (virtual_in_infos[i].tensor_index_ != real_in_infos[i].tensor_index_) { 2032+ return false; 2033+ } 2034+ if (real_node_map_.find(virtual_in_infos[i].node_index_) == real_node_map_.end()) { 2035+ real_node_map_.insert({virtual_in_infos[i].node_index_, real_in_infos[i].node_index_}); 2036+ } else if (real_node_map_.at(virtual_in_infos[i].node_index_) != real_in_infos[i].node_index_) { 2037+ return false; 2038+ } 2039+ } 2040+ const auto &real_out_infos = real_pattern.at(real_node_map_.at(pattern_node_index.second)).out_infos; 2041+ const auto &virtual_out_infos = pattern_.at(pattern_node_index).out_infos; 2042+ if (virtual_out_infos.empty()) { 2043+ return true; 2044+ } 2045+ if (real_out_infos.size() != virtual_out_infos.size()) { 2046+ return false; 2047+ } 2048+ for (size_t i = 0; i < virtual_out_infos.size(); ++i) { 2049+ if (virtual_out_infos[i].tensor_index_ != real_out_infos[i].tensor_index_) { 2050+ return false; 2051+ } 2052+ if (real_node_map_.find(virtual_out_infos[i].node_index_) == real_node_map_.end()) { 2053+ real_node_map_.insert({virtual_out_infos[i].node_index_, real_out_infos[i].node_index_}); 2054+ } else if (real_node_map_.at(virtual_out_infos[i].node_index_) != real_out_infos[i].node_index_) { 2055+ return false; 2056+ } 2057+ } 2058+ return true; 2059+ } 2060+ 2061+ bool CheckClosure(const std::map<uint32_t, uint32_t> &node_map) { 2062+ std::set<uint32_t> real_nodes; 2063+ (void)std::for_each(node_map.begin(), node_map.end(), 2064+ [&real_nodes](std::pair<uint32_t, uint32_t> pair) { real_nodes.insert(pair.second); }); 2065+ if (real_nodes.size() != node_map.size()) { 2066+ return false; 2067+ } 2068+ const auto &link_infos = link_info_manager_->GetLinkInfos(); 2069+ for (uint32_t start = kAdd1; start <= kMatmul1; ++start) { 2070+ if (node_map.find(start) == node_map.end()) { 2071+ return false; 2072+ } 2073+ const auto &node = graph_->nodes[node_map.at(start)]; 2074+ auto out_tensor_indexes = node->outputIndex; 2075+ for (auto out_index : out_tensor_indexes) { 2076+ if (link_infos.find(out_index) == link_infos.end()) { 2077+ return false; 2078+ } 2079+ for (const auto &in_node_info : link_infos.at(out_index).first) { 2080+ if (real_nodes.find(in_node_info.node_index) == real_nodes.end()) { 2081+ return false; 2082+ } 2083+ } 2084+ } 2085+ } 2086+ return true; 2087+ } 2088+ 2089+ bool MatchPattern(uint32_t add_index) { 2090+ real_node_map_.clear(); 2091+ real_node_map_[kAdd0] = add_index; 2092+ std::map<uint32_t, NodeInfo> real_pattern; 2093+ for (const auto &pair : pattern_) { 2094+ if (real_node_map_.find(pair.first.second) == real_node_map_.end()) { 2095+ return false; 2096+ } 2097+ auto node_index = real_node_map_[pair.first.second]; 2098+ if (!pair.second.checker(graph_, node_index)) { 2099+ return false; 2100+ } 2101+ if (!FillRealPattern(node_index, &real_pattern)) { 2102+ return false; 2103+ } 2104+ if (!CheckPattern(real_pattern, pair.first)) { 2105+ return false; 2106+ } 2107+ } 2108+ auto weight_hidden_index = graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[1]; 2109+ auto weight_hidden_shape = graph_->allTensors[weight_hidden_index]->dims; 2110+ if (weight_hidden_shape.size() != C2NUM || weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) { 2111+ return false; 2112+ } 2113+ return CheckClosure(real_node_map_); 2114+ } 2115+ 2116+ STATUS CreateCustomGruCell() { 2117+ std::vector<uint32_t> inputs; 2118+ 
inputs.push_back(graph_->nodes[real_node_map_[kMatmul0]]->inputIndex[0]); // x 2119+ inputs.push_back(graph_->nodes[real_node_map_[kMatmul0]]->inputIndex[1]); // weight_input 2120+ inputs.push_back(graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[1]); // weight_hidden 2121+ inputs.push_back(graph_->nodes[real_node_map_[kAdd3]]->inputIndex[1]); // bias_input 2122+ inputs.push_back(graph_->nodes[real_node_map_[kAdd5]]->inputIndex[1]); // bias_hidden 2123+ inputs.push_back(graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[0]); // init_h 2124+ auto outputs = graph_->nodes[real_node_map_[kAdd0]]->outputIndex; 2125+ auto attrs = CreateCustom(); 2126+ MS_CHECK_TRUE_RET(attrs != nullptr, RET_NULL_PTR); 2127+ auto prim_t = std::make_unique<schema::PrimitiveT>(); 2128+ MS_CHECK_TRUE_MSG(prim_t != nullptr, RET_ERROR, "Create PrimitiveT failed."); 2129+ prim_t->value.type = schema::PrimitiveType_Custom; 2130+ prim_t->value.value = attrs.release(); 2131+ auto custom_gru = std::make_unique<schema::CNodeT>(); 2132+ MS_CHECK_TRUE_MSG(custom_gru != nullptr, RET_ERROR, "Create Custom-Gru failed."); 2133+ custom_gru->name = graph_->nodes[real_node_map_[kAdd0]]->name; 2134+ custom_gru->inputIndex = inputs; 2135+ custom_gru->outputIndex = outputs; 2136+ custom_gru->primitive = std::move(prim_t); 2137+ link_info_manager_->Replace(real_node_map_[kAdd0], std::move(custom_gru)); 2138+ std::set<uint32_t> delete_nodes; 2139+ for (uint32_t i = kAdd1; i <= kMatmul1; ++i) { 2140+ delete_nodes.insert(real_node_map_[i]); 2141+ } 2142+ link_info_manager_->AddDeleteNodes(delete_nodes); 2143+ return RET_OK; 2144+ } 2145+ 2146+ std::map<std::pair<int, uint32_t>, NodeInfo> pattern_; 2147+ std::map<uint32_t, uint32_t> real_node_map_; 2148+ schema::MetaGraphT *graph_{nullptr}; 2149+ std::shared_ptr<LinkInfoManager> link_info_manager_{nullptr}; 2150+}; 2151+ 2152+STATUS GruFusionPass::Run(schema::MetaGraphT *graph) { 2153+#ifndef ENABLE_ARM64 2154+ return RET_OK; 2155+#endif 2156+ if (graph == nullptr) { 2157+ MS_LOG(ERROR) << "graph is a nullptr."; 2158+ return RET_NULL_PTR; 2159+ } 2160+ if (graph->subGraph.size() != 1) { 2161+ return RET_OK; 2162+ } 2163+ if (FuseToGruCell(graph) != RET_OK) { 2164+ return RET_ERROR; 2165+ } 2166+ return FuseGruCell(graph); 2167+} 2168+ 2169+STATUS GruFusionPass::FuseToGruCell(schema::MetaGraphT *graph) { 2170+ GruCellFusion gru_cell_fusion{}; 2171+ if (gru_cell_fusion.Run(graph) != RET_OK) { 2172+ MS_LOG(ERROR) << "Fuse GruCell failed."; 2173+ return RET_ERROR; 2174+ } 2175+ return RET_OK; 2176+} 2177+ 2178+STATUS GruFusionPass::FuseGruCell(schema::MetaGraphT *graph) { 2179+ link_info_manager_ = std::make_shared<LinkInfoManager>(graph); 2180+ for (uint32_t i = 0; i < static_cast<uint32_t>(graph->nodes.size()); ++i) { 2181+ if (!CheckStack(graph, i)) { 2182+ continue; 2183+ } 2184+ std::vector<uint32_t> strided_slices; 2185+ std::vector<uint32_t> squeezes; 2186+ std::vector<uint32_t> gru_cells; 2187+ if (!MatchPattern(graph, i, &strided_slices, &squeezes, &gru_cells)) { 2188+ continue; 2189+ } 2190+ if (CreateGru(graph, i, strided_slices, squeezes, gru_cells) != RET_OK) { 2191+ MS_LOG(ERROR) << "Create Gru failed."; 2192+ return RET_ERROR; 2193+ } 2194+ } 2195+ link_info_manager_->UpdateMetaGraph(); 2196+ link_info_manager_ = nullptr; 2197+ return RET_OK; 2198+} 2199+ 2200+bool GruFusionPass::MatchPattern(schema::MetaGraphT *graph, uint32_t stack_index, std::vector<uint32_t> *strided_slices, 2201+ std::vector<uint32_t> *squeezes, std::vector<uint32_t> *gru_cells) { 2202+ auto &link_infos
= link_info_manager_->GetLinkInfos(); 2203+ auto &stack_node = graph->nodes[stack_index]; 2204+ int batch_point = 0; 2205+ auto CommonCheck = [&link_infos](uint32_t tensor_index) { 2206+ if (link_infos.find(tensor_index) == link_infos.end()) { 2207+ return std::make_pair(false, 0); 2208+ } 2209+ const auto &in_node_info = link_infos.at(tensor_index).first; 2210+ if (in_node_info.size() != 1 || in_node_info.front().in_indexes.size() != 1) { 2211+ return std::make_pair(false, 0); 2212+ } 2213+ auto node_index = link_infos.at(tensor_index).second.node_index; 2214+ if (node_index < 0) { 2215+ return std::make_pair(false, 0); 2216+ } 2217+ return std::make_pair(true, node_index); 2218+ }; 2219+ for (auto tensor_index : stack_node->inputIndex) { 2220+ auto check_info = CommonCheck(tensor_index); 2221+ if (!check_info.first || !CheckGruCell(graph, check_info.second)) { 2222+ return false; 2223+ } 2224+ gru_cells->push_back(check_info.second); 2225+ auto &gru_cell_node = graph->nodes[check_info.second]; 2226+ check_info = CommonCheck(gru_cell_node->inputIndex.front()); 2227+ if (!check_info.first || !CheckSqueeze(graph, check_info.second)) { 2228+ return false; 2229+ } 2230+ squeezes->push_back(check_info.second); 2231+ auto &squeeze_node = graph->nodes[check_info.second]; 2232+ check_info = CommonCheck(squeeze_node->inputIndex.front()); 2233+ if (!check_info.first || !CheckStridedSlice(graph, check_info.second, batch_point)) { 2234+ return false; 2235+ } 2236+ strided_slices->push_back(check_info.second); 2237+ ++batch_point; 2238+ } 2239+ if (strided_slices->empty()) { 2240+ return false; 2241+ } 2242+ uint32_t input_index = graph->nodes[strided_slices->front()]->inputIndex.front(); 2243+ if (std::any_of(strided_slices->begin(), strided_slices->end(), [input_index, graph](uint32_t strided_slice) { 2244+ return graph->nodes[strided_slice]->inputIndex.front() != input_index; 2245+ })) { 2246+ return false; 2247+ } 2248+ auto in_shape = graph->allTensors[input_index]->dims; 2249+ if (in_shape.empty() || in_shape.front() != batch_point) { 2250+ return false; 2251+ } 2252+ return CheckGruCellConnection(graph, *gru_cells); 2253+} 2254+ 2255+bool GruFusionPass::CheckGruCellConnection(schema::MetaGraphT *graph, const std::vector<uint32_t> &gru_cells) { 2256+ auto &first_node = graph->nodes[gru_cells.front()]; 2257+ if (first_node->inputIndex.size() != C6NUM) { 2258+ return false; 2259+ } 2260+ auto init_h = first_node->outputIndex.front(); 2261+ for (size_t i = 1; i < gru_cells.size(); ++i) { 2262+ auto &node = graph->nodes[gru_cells[i]]; 2263+ if (node->inputIndex.size() != first_node->inputIndex.size()) { 2264+ return false; 2265+ } 2266+ for (size_t j = 1; j < C5NUM; ++j) { 2267+ if (node->inputIndex[j] != first_node->inputIndex[j]) { 2268+ return false; 2269+ } 2270+ } 2271+ if (node->inputIndex[C5NUM] != init_h) { 2272+ return false; 2273+ } 2274+ init_h = node->outputIndex.front(); 2275+ } 2276+ return true; 2277+} 2278+ 2279+STATUS GruFusionPass::CreateGru(schema::MetaGraphT *graph, uint32_t stack_index, 2280+ const std::vector<uint32_t> &strided_slices, const std::vector<uint32_t> &squeezes, 2281+ const std::vector<uint32_t> &gru_cells) { 2282+ auto &gru_cell_node = graph->nodes[gru_cells.front()]; 2283+ gru_cell_node->inputIndex[0] = graph->nodes[strided_slices.front()]->inputIndex[0]; 2284+ gru_cell_node->outputIndex[0] = graph->nodes[stack_index]->outputIndex[0]; 2285+ std::set<uint32_t> delete_node{stack_index}; 2286+ (void)delete_node.insert(strided_slices.begin(), strided_slices.end());
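+  // Only the first GruCell survives: it was rewired above to consume the whole
+  // sequence tensor and to produce the Stack's output, so the per-step
+  // StridedSlice and Squeeze nodes, the remaining cells and the Stack itself
+  // are dead and are all collected into delete_node for removal.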
2287+ (void)delete_node.insert(squeezes.begin(), squeezes.end()); 2288+ (void)delete_node.insert(gru_cells.begin() + 1, gru_cells.end()); 2289+ link_info_manager_->AddDeleteNodes(delete_node); 2290+ return RET_OK; 2291+} 2292+} // namespace lite 2293+} // namespace mindspore 2294diff --git a/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h 2295new file mode 100644 2296index 00000000..5e2b705d 2297--- /dev/null 2298+++ b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h 2299@@ -0,0 +1,45 @@ 2300+/** 2301+ * Copyright 2023 Huawei Technologies Co., Ltd 2302+ * 2303+ * Licensed under the Apache License, Version 2.0 (the "License"); 2304+ * you may not use this file except in compliance with the License. 2305+ * You may obtain a copy of the License at 2306+ * 2307+ * http://www.apache.org/licenses/LICENSE-2.0 2308+ * 2309+ * Unless required by applicable law or agreed to in writing, software 2310+ * distributed under the License is distributed on an "AS IS" BASIS, 2311+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2312+ * See the License for the specific language governing permissions and 2313+ * limitations under the License. 2314+ */ 2315+ 2316+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_GRU_FUSION_PASS_H_ 2317+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_GRU_FUSION_PASS_H_ 2318+ 2319+#include <memory> 2320+#include <vector> 2321+#include "tools/converter/optimizer.h" 2322+ 2323+namespace mindspore { 2324+namespace lite { 2325+class LinkInfoManager; 2326+class GruFusionPass : public GraphPass { 2327+ public: 2328+ GruFusionPass() = default; 2329+ ~GruFusionPass() override = default; 2330+ STATUS Run(schema::MetaGraphT *graph) override; 2331+ 2332+ private: 2333+ STATUS FuseToGruCell(schema::MetaGraphT *graph); 2334+ STATUS FuseGruCell(schema::MetaGraphT *graph); 2335+ bool MatchPattern(schema::MetaGraphT *graph, uint32_t stack_index, std::vector<uint32_t> *strided_slices, 2336+ std::vector<uint32_t> *squeezes, std::vector<uint32_t> *gru_cells); 2337+ bool CheckGruCellConnection(schema::MetaGraphT *graph, const std::vector<uint32_t> &gru_cells); 2338+ STATUS CreateGru(schema::MetaGraphT *graph, uint32_t stack_index, const std::vector<uint32_t> &strided_slices, 2339+ const std::vector<uint32_t> &squeezes, const std::vector<uint32_t> &gru_cells); 2340+ std::shared_ptr<LinkInfoManager> link_info_manager_{nullptr}; 2341+}; 2342+} // namespace lite 2343+} // namespace mindspore 2344+#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_GRU_FUSION_PASS_H_ 2345diff --git a/mindspore/lite/src/train/static_allocator.h b/mindspore/lite/src/train/static_allocator.h 2346index d78e13ba..bd80651d 100644 2347--- a/mindspore/lite/src/train/static_allocator.h 2348+++ b/mindspore/lite/src/train/static_allocator.h 2349@@ -40,12 +40,12 @@ class StaticAllocator : public Allocator { 2350 if (ptr == nullptr) return STATIC_ALLOCATION; 2351 char *ptrc = reinterpret_cast<char *>(ptr); 2352 char *bufc = reinterpret_cast<char *>(start_buf_); 2353- return ((ptrc < bufc) || (ptrc - bufc >= static_cast<ptrdiff_t>(size_)) ? 1 : 0); 2354+ return ((ptrc < bufc) || (ptrc >= bufc + size_)) ?
1 : 0; 2355 } 2356 2357 private: 2358- void *start_buf_; 2359- size_t size_; 2360+ void *start_buf_{nullptr}; 2361+ size_t size_{0}; 2362 size_t total_size_ = 0; 2363 }; 2364 }; // namespace mindspore 2365diff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc 2366index 7e504c4e..008de7c5 100644 2367--- a/mindspore/lite/src/train/train_export.cc 2368+++ b/mindspore/lite/src/train/train_export.cc 2369@@ -30,6 +30,10 @@ 2370 #include "src/train/graph_fusion.h" 2371 #include "src/train/graph_dropout.h" 2372 #include "src/runtime/weight_decoder.h" 2373+#include "src/runtime/kernel/cpu/fp16/fp16_op_handler.h" 2374+#ifndef ENABLE_ARM 2375+#include "base/float16.h" 2376+#endif 2377 2378 namespace mindspore { 2379 namespace lite { 2380@@ -645,6 +649,40 @@ int TrainExport::SaveToBuffer() { 2381 return RET_OK; 2382 } 2383 2384+int TrainExport::SaveWeightsToFile(bool enable_fp16, const std::vector<std::string> &changeable_weights_name) { 2385+ const auto &all_tensors = meta_graph_->allTensors; 2386+ std::ofstream weights(file_name_, std::ios::out | std::ios::trunc | std::ios::binary); 2387+ for (auto &tensor : all_tensors) { 2388+ MS_CHECK_TRUE_MSG(tensor != nullptr, RET_NULL_PTR, "Existing tensor is a nullptr."); 2389+ if (tensor->data.empty()) { 2390+ continue; 2391+ } 2392+ if (std::find(changeable_weights_name.begin(), changeable_weights_name.end(), tensor->name) != 2393+ changeable_weights_name.end()) { 2394+ auto shape = tensor->dims; 2395+ weights.write(reinterpret_cast<const char *>(shape.data()), shape.size() * sizeof(uint32_t)); 2396+ } 2397+ if (!enable_fp16 || tensor->dataType != kNumberTypeFloat32) { 2398+ weights.write(reinterpret_cast<const char *>(tensor->data.data()), tensor->data.size()); 2399+ } else { 2400+ std::vector<uint16_t> data_fp16(tensor->data.size() / sizeof(float)); 2401+#ifndef ENABLE_ARM 2402+ auto fp32_data = reinterpret_cast<const float *>(tensor->data.data()); 2403+ auto fp16_data = reinterpret_cast<float16 *>(data_fp16.data()); 2404+ CHECK_NULL_RETURN(fp32_data); 2405+ CHECK_NULL_RETURN(fp16_data); 2406+ for (size_t j = 0; j < data_fp16.size(); ++j) { 2407+ fp16_data[j] = float16(fp32_data[j]); 2408+ } 2409+#else 2410+ Float32ToFloat16_fp16_handler(tensor->data.data(), data_fp16.data(), data_fp16.size(), true); 2411+#endif 2412+ weights.write(reinterpret_cast<const char *>(data_fp16.data()), data_fp16.size() * sizeof(uint16_t)); 2413+ } 2414+ } 2415+ weights.close(); 2416+ return RET_OK; 2417+} 2418 2419 bool TrainExport::IsInputTensor(const schema::TensorT &t) { 2420 int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>()); 2421diff --git a/mindspore/lite/src/train/train_export.h b/mindspore/lite/src/train/train_export.h 2422index 8e802021..d6f81187 100644 2423--- a/mindspore/lite/src/train/train_export.h 2424+++ b/mindspore/lite/src/train/train_export.h 2425@@ -52,6 +52,7 @@ class TrainExport { 2426 int ExportInit(const std::string model_name, std::string version); 2427 int SaveToFile(); 2428 int SaveToBuffer(); 2429+ int SaveWeightsToFile(bool enable_fp16 = false, const std::vector<std::string> &changeable_weights_name = {}); 2430 void set_connect(const std::unordered_map<size_t, size_t> &map) { connect_ = map; } 2431 int LoadModel(void *buf, size_t buf_size); 2432 int AddTransformNode(); 2433diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc 2434index b40ff8c2..2f9aa99b 100644 2435--- a/mindspore/lite/src/train/train_session.cc 2436+++
b/mindspore/lite/src/train/train_session.cc 2437@@ -1233,10 +1233,45 @@ int TrainSession::Export(Buffer *model_buffer, ModelType model_type, Quantizatio 2438 return ExportInner<Buffer *>(model_buffer, model_type, quant_type, format, out_put_tensor_name); 2439 } 2440 2441+int TrainSession::ExportWeightsCollaborateWithMicro(const std::string &file_name, lite::ModelType model_type, 2442+ FormatType format, bool enable_fp16, 2443+ const std::vector<std::string> &changeable_weights_name) { 2444+ MS_CHECK_FALSE_MSG(file_name.empty(), RET_ERROR, "File name cannot be empty"); 2445+ struct stat path_type; 2446+ if (stat(file_name.c_str(), &path_type) == RET_OK) { 2447+ if (path_type.st_mode & S_IFDIR) { 2448+ MS_LOG(ERROR) << "Destination must be a file path, but a directory was given"; 2449+ return RET_ERROR; 2450+ } 2451+ } 2452+ MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "Format must be `FT_FLATBUFFERS`"); 2453+ MS_CHECK_FALSE_MSG(model_type != mindspore::lite::MT_INFERENCE, RET_ERROR, 2454+ "Currently, only an inference model's weights can be exported."); 2455+ int status = Eval(); 2456+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Eval failed"); 2457+ 2458+ TrainExport texport(file_name); 2459+ status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_); 2460+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export"); 2461+ 2462+ status = texport.ExportNet(inference_kernels_, tensors_, eval_output_tensor_names_, model_.get(), QT_DEFAULT); 2463+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network."); 2464+ status = texport.TrainModelDrop(); 2465+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed."); 2466+ status = texport.TrainModelFusion(); 2467+ TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed."); 2468+ status = texport.SaveWeightsToFile(enable_fp16, changeable_weights_name); 2469+ if (status != RET_OK) { 2470+ MS_LOG(ERROR) << "Failed to save to " << file_name; 2471+ return status; 2472+ } 2473+ return RET_OK; 2474+} 2475+ 2476 std::vector<lite::Tensor *> TrainSession::GetFeatureMaps() const { 2477 std::vector<lite::Tensor *> features; 2478 for (auto cur_tensor : this->tensors_) { 2479- if (cur_tensor->category() ==lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) { 2480+ if (cur_tensor->category() == lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) { 2481 features.push_back(cur_tensor); 2482 } 2483 } 2484diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h 2485index 5acff82a..edcab32d 100644 2486--- a/mindspore/lite/src/train/train_session.h 2487+++ b/mindspore/lite/src/train/train_session.h 2488@@ -106,6 +106,9 @@ class TrainSession : virtual public lite::LiteSession { 2489 std::vector<std::string> out_put_tensor_name = {}) override; 2490 int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType, 2491 std::vector<std::string> out_put_tensor_name = {}) override; 2492+ int ExportWeightsCollaborateWithMicro(const std::string &file_name, lite::ModelType model_type, FormatType, 2493+ bool enable_fp16, 2494+ const std::vector<std::string> &changeable_weights_name) override; 2495 2496 std::vector<lite::Tensor *> GetFeatureMaps() const override; 2497 2498diff --git a/mindspore/lite/test/config_level0/micro/micro_arm64.cfg b/mindspore/lite/test/config_level0/micro/micro_arm64.cfg 2499index 0375ebf7..765549ab 100644 2500---
a/mindspore/lite/test/config_level0/micro/micro_arm64.cfg 2501+++ b/mindspore/lite/test/config_level0/micro/micro_arm64.cfg 2502@@ -25,3 +25,10 @@ support_parallel=false 2503 2504 # enable debug 2505 debug_mode=false 2506+ 2507+# false indicates that only the required weights are saved. When collaborating with lite-train, this parameter must be true. 2508+keep_original_weight=false 2509+ 2510+# the names of those weight-tensors whose shapes are changeable; currently only the embedding-table supports shape change. 2511+# the parameter is used to collaborate with lite-train. If set, `keep_original_weight` must be true. 2512+#changeable_weights_name=name0,name1 2513diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 2514index ea263e64..4c6bd237 100644 2515--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 2516+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 2517@@ -221,11 +221,14 @@ int ConfigFileParser::ParseAclOptionCfgString(const std::map<std::string, std::m 2518 int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps) { 2519 if (maps.find(kMicroParam) != maps.end()) { 2520 const auto &map = maps.at(kMicroParam); 2521- std::map<std::string, std::string &> parse_map{{"target", micro_param_string_.target}, 2522- {"codegen_mode", micro_param_string_.codegen_mode}, 2523- {"debug_mode", micro_param_string_.debug_mode}, 2524- {"support_parallel", micro_param_string_.support_parallel}, 2525- {"enable_micro", micro_param_string_.enable_micro}}; 2526+ std::map<std::string, std::string &> parse_map{ 2527+ {"target", micro_param_string_.target}, 2528+ {"codegen_mode", micro_param_string_.codegen_mode}, 2529+ {"debug_mode", micro_param_string_.debug_mode}, 2530+ {"support_parallel", micro_param_string_.support_parallel}, 2531+ {"enable_micro", micro_param_string_.enable_micro}, 2532+ {"keep_original_weight", micro_param_string_.keep_original_weight}, 2533+ {"changeable_weights_name", micro_param_string_.changeable_weights_name}}; 2534 return SetMapData(map, parse_map, kMicroParam); 2535 } 2536 return RET_OK; 2537diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h 2538index 0ada406e..8854e5f7 100644 2539--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h 2540+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h 2541@@ -86,6 +86,8 @@ struct MicroParamString { 2542 std::string support_parallel; 2543 std::string debug_mode; 2544 std::string enable_micro; 2545+ std::string keep_original_weight; 2546+ std::string changeable_weights_name; 2547 }; 2548 2549 struct ThirdPartyModelString { 2550diff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc 2551index 310b2398..559bee8b 100644 2552--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc 2553+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc 2554@@ -61,6 +61,23 @@ STATUS MicroParamParser::ParseEnableMicro(const std::string &enable_micro, micro 2555 return RET_OK; 2556 } 2557 2558+STATUS MicroParamParser::ParseKeepOriginalWeight(const std::string &save_all_weights, micro::MicroParam *micro_param) { 2559+ MS_LOG(DEBUG) << "Micro keep_original_weight: " << save_all_weights; 2560+ micro_param->keep_original_weight = false; // default 2561+
diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
index ea263e64..4c6bd237 100644
--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
@@ -221,11 +221,14 @@ int ConfigFileParser::ParseAclOptionCfgString(const std::map<std::string, std::m
 int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
   if (maps.find(kMicroParam) != maps.end()) {
     const auto &map = maps.at(kMicroParam);
-    std::map<std::string, std::string &> parse_map{{"target", micro_param_string_.target},
-                                                   {"codegen_mode", micro_param_string_.codegen_mode},
-                                                   {"debug_mode", micro_param_string_.debug_mode},
-                                                   {"support_parallel", micro_param_string_.support_parallel},
-                                                   {"enable_micro", micro_param_string_.enable_micro}};
+    std::map<std::string, std::string &> parse_map{
+      {"target", micro_param_string_.target},
+      {"codegen_mode", micro_param_string_.codegen_mode},
+      {"debug_mode", micro_param_string_.debug_mode},
+      {"support_parallel", micro_param_string_.support_parallel},
+      {"enable_micro", micro_param_string_.enable_micro},
+      {"keep_original_weight", micro_param_string_.keep_original_weight},
+      {"changeable_weights_name", micro_param_string_.changeable_weights_name}};
     return SetMapData(map, parse_map, kMicroParam);
   }
   return RET_OK;
diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
index 0ada406e..8854e5f7 100644
--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h
+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
@@ -86,6 +86,8 @@ struct MicroParamString {
   std::string support_parallel;
   std::string debug_mode;
   std::string enable_micro;
+  std::string keep_original_weight;
+  std::string changeable_weights_name;
 };
 
 struct ThirdPartyModelString {
diff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc
index 310b2398..559bee8b 100644
--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc
+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc
@@ -61,6 +61,23 @@ STATUS MicroParamParser::ParseEnableMicro(const std::string &enable_micro, micro
   return RET_OK;
 }
 
+STATUS MicroParamParser::ParseKeepOriginalWeight(const std::string &save_all_weights, micro::MicroParam *micro_param) {
+  MS_LOG(DEBUG) << "Micro keep_original_weight: " << save_all_weights;
+  micro_param->keep_original_weight = false;  // default
+  bool is_keep_original_weight;
+  if (ConvertBool(save_all_weights, &is_keep_original_weight)) {
+    micro_param->keep_original_weight = is_keep_original_weight;
+  }
+  return RET_OK;
+}
+
+STATUS MicroParamParser::ParseChangeableWeightsName(const std::string &changeable_weights_name,
+                                                    micro::MicroParam *micro_param) {
+  MS_LOG(DEBUG) << "Micro record changeable weights name: " << changeable_weights_name;
+  micro_param->changeable_weights_name = changeable_weights_name;
+  return RET_OK;
+}
+
 STATUS MicroParamParser::ParseMicroParam(const MicroParamString &micro_param_string, micro::MicroParam *micro_param) {
   CHECK_NULL_RETURN(micro_param);
   if (!micro_param_string.target.empty()) {
@@ -93,6 +110,22 @@ STATUS MicroParamParser::ParseMicroParam(const MicroParamString &micro_param_str
       return RET_INPUT_PARAM_INVALID;
     }
   }
+  if (!micro_param_string.keep_original_weight.empty()) {
+    if (ParseKeepOriginalWeight(micro_param_string.keep_original_weight, micro_param) != RET_OK) {
+      MS_LOG(ERROR) << "Parse keep_original_weight val: " << micro_param_string.keep_original_weight;
+      return RET_INPUT_PARAM_INVALID;
+    }
+  }
+  if (!micro_param_string.changeable_weights_name.empty()) {
+    if (!micro_param->keep_original_weight) {
+      MS_LOG(ERROR) << "When changeable_weights_name is set, keep_original_weight must be true.";
+      return RET_INPUT_PARAM_INVALID;
+    }
+    if (ParseChangeableWeightsName(micro_param_string.changeable_weights_name, micro_param) != RET_OK) {
+      MS_LOG(ERROR) << "Parse changeable_weights_name val: " << micro_param_string.changeable_weights_name;
+      return RET_INPUT_PARAM_INVALID;
+    }
+  }
   return RET_OK;
 }
 }  // namespace lite
diff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.h b/mindspore/lite/tools/converter/config_parser/micro_param_parser.h
index 860182af..93a30b39 100644
--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.h
+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.h
@@ -33,6 +33,8 @@ class MicroParamParser {
   STATUS ParseCodeGenMode(const std::string &codegen_mode, micro::MicroParam *micro_param);
   STATUS ParseSupportParallel(const std::string &support_parallel, micro::MicroParam *micro_param);
   STATUS ParseDebugMode(const std::string &debug_mode, micro::MicroParam *micro_param);
+  STATUS ParseKeepOriginalWeight(const std::string &save_all_weights, micro::MicroParam *micro_param);
+  STATUS ParseChangeableWeightsName(const std::string &changeable_weights_name, micro::MicroParam *micro_param);
 };
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
index 944ed29c..6177d379 100644
--- a/mindspore/lite/tools/converter/converter.cc
+++ b/mindspore/lite/tools/converter/converter.cc
@@ -186,6 +186,9 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara>
     }
   }
 
+  if (param->fmk_type == FmkType::kFmkTypeMsLite) {
+    // MSLITE input is already a lite model; RunConverter feeds it straight to micro codegen.
+    return nullptr;
+  }
   auto graph = BuildFuncGraph(param);
   if (graph == nullptr) {
     MS_LOG(ERROR) << "Parser/Import model return nullptr";
@@ -557,7 +560,7 @@ int CheckFmkType(const std::shared_ptr<ConverterPara> &param) {
   }
   const std::set kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx,
                                    FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch,
-                                   FmkType::kFmkTypeThirdParty};
+                                   FmkType::kFmkTypeThirdParty, FmkType::kFmkTypeMsLite};
   if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) {
-    MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|THIRDPARTY"
+    MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|THIRDPARTY|MSLITE"
                   << ", but got " << param->fmk_type;
@@ -780,6 +783,14 @@ int RunConverter(const std::shared_ptr<ConverterPara> &param, void **model_data,
   NotSupportOp::GetInstance()->PrintOps();
   status = ReturnCode::GetSingleReturnCode()->status_code();
   if (meta_graph == nullptr) {
+    if (param->fmk_type == FmkType::kFmkTypeMsLite && param->microParam.enable_micro) {
+      status = micro::Coder::MicroSourceCodeGeneration(param->model_file, param->output_file, param->microParam,
+                                                       param->weight_fp16);
+      if (status != RET_OK) {
+        CONVERTER_LOG_ERROR("MICRO CODEGEN FAILED:" << status << " " << GetErrorInfo(status));
+      }
+      return status;
+    }
     CONVERTER_LOG_ERROR("CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status));
     status = RET_ERROR;
     return status;
@@ -797,9 +808,8 @@ int RunConverter(const std::shared_ptr<ConverterPara> &param, void **model_data,
   }
 
   if (param->microParam.enable_micro) {
-    status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, param->output_file, param->microParam.codegen_mode,
-                                                     param->microParam.target, param->microParam.support_parallel,
-                                                     param->microParam.debug_mode, param->weight_fp16);
+    status =
+      micro::Coder::MicroSourceCodeGeneration(*meta_graph, param->output_file, param->microParam, param->weight_fp16);
     if (status != RET_OK) {
       delete meta_graph;
       CONVERTER_LOG_ERROR("MICRO CODEGEN FAILED:" << status << " " << GetErrorInfo(status));
diff --git a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
index 595b59ed..e30994cc 100644
--- a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
+++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
@@ -29,7 +29,7 @@ using mindspore::lite::RET_INPUT_PARAM_INVALID;
 using mindspore::lite::RET_OK;
 
 Flags::Flags() {
-  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
+  AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX | MSLITE", "");
   AddFlag(&Flags::modelFile, "modelFile",
           "Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
@@ -120,7 +120,7 @@ int Flags::InitFmk() {
   // value check not here, it is in converter c++ API's CheckValueParam method.
   std::map<std::string, FmkType> StrToEnumFmkTypeMap = {
     {"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs}, {"TFLITE", kFmkTypeTflite}, {"ONNX", kFmkTypeOnnx},
-    {"TF", kFmkTypeTf},       {"PYTORCH", kFmkTypePytorch}, {"THIRDPARTY", kFmkTypeThirdParty}};
+    {"TF", kFmkTypeTf},       {"PYTORCH", kFmkTypePytorch}, {"THIRDPARTY", kFmkTypeThirdParty}, {"MSLITE", kFmkTypeMsLite}};
   if (StrToEnumFmkTypeMap.find(this->fmkIn) != StrToEnumFmkTypeMap.end()) {
     this->fmk = StrToEnumFmkTypeMap.at(this->fmkIn);
   } else {
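With MSLITE accepted as an input framework, micro code generation can run directly from an already-converted .ms model. A sketch of the command line (--fmk, --modelFile and --outputFile are the flags registered above; --configFile is assumed to be the converter's existing mechanism for passing the [micro_param] section, and the file names are illustrative):

    ./converter_lite --fmk=MSLITE --modelFile=lenet_train.ms --outputFile=lenet_micro --configFile=micro_arm64.cfg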
diff --git a/mindspore/lite/tools/converter/micro/cmake/file_list.cmake b/mindspore/lite/tools/converter/micro/cmake/file_list.cmake
index 9ae54538..589ee81a 100644
--- a/mindspore/lite/tools/converter/micro/cmake/file_list.cmake
+++ b/mindspore/lite/tools/converter/micro/cmake/file_list.cmake
@@ -53,6 +53,7 @@ set(CODER_OPCODERS_SRC
         ${MICRO_DIR}/coder/opcoders/base/reduce_base_coder.cc
         ${MICRO_DIR}/coder/opcoders/base/resize_base_coder.cc
         ${MICRO_DIR}/coder/opcoders/base/reshape_base_coder.cc
+        ${MICRO_DIR}/coder/opcoders/base/stack_base_coder.cc
        ${MICRO_DIR}/coder/opcoders/base/softmax_base_coder.cc
        ${MICRO_DIR}/coder/opcoders/base/detection_post_process_base_coder.cc
        ${MICRO_DIR}/coder/opcoders/base/strided_slice_base_coder.cc
@@ -71,6 +72,7 @@ set(CODER_OPCODERS_SRC
        ${MICRO_DIR}/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp16/avg_pooling_fp16_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc
+        ${MICRO_DIR}/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc
@@ -90,6 +92,7 @@ set(CODER_OPCODERS_SRC
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
+        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/full_connection_fp32_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc
        ${MICRO_DIR}/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc
diff --git a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc
index 9c5839b4..be314ed6 100644
--- a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc
+++ b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc
@@ -79,18 +79,26 @@ void MemoryAllocator::Free() {
       iter++;
     }
   }
+  for (auto &item : auxiliary_weights_) {
+    delete item.second.first;
+  }
   malloc_weights_addr_.clear();
   for (auto &item : allocated_) {
     free(item);
     item = nullptr;
   }
   allocated_.clear();
+  origin_weights_.clear();
+  auxiliary_weights_.clear();
 }
 
 std::map<Tensor *, std::string> MemoryAllocator::tensors_map() const {
   std::map<Tensor *, std::string> res;
   res.insert(tensors_addr_.begin(), tensors_addr_.end());
   res.insert(malloc_weights_addr_.begin(), malloc_weights_addr_.end());
+  (void)std::for_each(
+    auxiliary_weights_.begin(), auxiliary_weights_.end(),
+    [&res](const std::pair<Tensor *, std::pair<Tensor *, std::string>> &item) { res.insert(item.second); });
   return res;
 }
 
@@ -121,17 +129,25 @@ void MemoryAllocator::AssignGraphInputs(const std::vector<Tensor *> &inputs) {
   }
 }
 
-void MemoryAllocator::RecordOriginWeightsAddr(const std::vector<std::unique_ptr<OperatorCoder>> &nodes) {
-  for (const auto &node : nodes) {
-    std::vector<Tensor *> inputs = node->input_tensors();
-    for (const auto &tensor : inputs) {
-      if (tensor->category() == lite::Category::CONST_TENSOR || tensor->category() == lite::Category::CONST_SCALAR) {
-        std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_);
-        origin_weights_addr_.insert(std::make_pair(tensor, runtime_addr));
-        weight_index_++;
+int MemoryAllocator::RecordOriginWeightsAddr(const std::vector<Tensor *> &all_tensors,
+                                             const std::string &changeable_weights_name) {
+  std::vector<std::string> weights_name;
+  if (!changeable_weights_name.empty()) {
+    weights_name = StrSplit(changeable_weights_name, ",");
+  }
+  for (const auto &tensor : all_tensors) {
+    if (tensor->category() == lite::Category::CONST_TENSOR || tensor->category() == lite::Category::CONST_SCALAR) {
+      if (std::find(weights_name.begin(), weights_name.end(), tensor->tensor_name()) != weights_name.end()) {
+        if (RecordChangeableWeights(tensor) != RET_OK) {
+          return RET_ERROR;
+        }
       }
+      std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_++);
+      origin_weights_addr_.insert(std::make_pair(tensor, runtime_addr));
+      origin_weights_.push_back(tensor);
     }
   }
+  return RET_OK;
 }
 
 int MemoryAllocator::AssignTensors(const std::vector<std::unique_ptr<OperatorCoder>> &nodes) {
@@ -150,9 +166,13 @@ int MemoryAllocator::AssignTensors(const std::vector<std::unique_ptr<OperatorCod
 }
 
 int MemoryAllocator::Assign(const std::vector<Tensor *> &inputs,
-                            const std::vector<std::unique_ptr<OperatorCoder>> &nodes) {
+                            const std::vector<std::unique_ptr<OperatorCoder>> &nodes,
+                            const std::vector<Tensor *> &all_tensors, const std::string &changeable_weights_name) {
   AssignGraphInputs(inputs);
-  RecordOriginWeightsAddr(nodes);
+  if (RecordOriginWeightsAddr(all_tensors, changeable_weights_name) != RET_OK) {
+    MS_LOG(ERROR) << "RecordOriginWeightsAddr failed.";
+    return RET_ERROR;
+  }
   return AssignTensors(nodes);
 }
 
@@ -163,4 +183,46 @@ void MemoryAllocator::MarkSharedWeight(const Tensor *src, void *pack_weight) {
 void *MemoryAllocator::GetSharedWeightAddr(const Tensor *src) {
   return shared_pack_weights_.find(src) == shared_pack_weights_.end() ? nullptr : shared_pack_weights_[src];
 }
+
+int MemoryAllocator::RecordChangeableWeights(Tensor *src) {
+  MS_ASSERT(src != nullptr);
+  auto variable_str = GetAuxiliaryWeight(src);
+  if (!variable_str.empty()) {
+    return RET_OK;
+  }
+  if (!src->IsConst()) {
+    MS_LOG(ERROR) << "Currently, the tensor must be a constant.";
+    return RET_NOT_SUPPORT;
+  }
+  auto shape = src->shape();
+  auto shape_tensor = new (std::nothrow)
+    Tensor(kNumberTypeInt32, {static_cast<int>(shape.size())}, src->format(), Category::CONST_TENSOR);
+  if (shape_tensor == nullptr) {
+    MS_LOG(ERROR) << "Create an assistant tensor failed.";
+    return RET_NULL_PTR;
+  }
+  auto data = shape_tensor->MutableData();
+  if (data == nullptr) {
+    MS_LOG(ERROR) << "Create an assistant tensor failed.";
+    delete shape_tensor;
+    return RET_NULL_PTR;
+  }
+  if (memcpy_s(data, shape_tensor->Size(), shape.data(), shape.size() * sizeof(int)) != EOK) {
+    MS_LOG(ERROR) << "Create an assistant tensor failed.";
+    delete shape_tensor;
+    return RET_ERROR;
+  }
+  shape_tensor->set_tensor_name(src->tensor_name() + "_shape");
+  std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_++);
+  auxiliary_weights_[src] = std::make_pair(shape_tensor, runtime_addr);
+  return RET_OK;
+}
+
+std::string MemoryAllocator::GetAuxiliaryWeight(Tensor *src) {
+  auto iter = auxiliary_weights_.find(src);
+  if (iter != auxiliary_weights_.end()) {
+    return iter->second.second;
+  }
+  return {};
+}
 }  // namespace mindspore::lite::micro
diff --git a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h
index 8a1331fb..f5bacf6f 100644
--- a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h
+++ b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h
@@ -56,7 +56,8 @@ class MemoryAllocator {
   /*
    * assign model's input, original weights and all tensors memory addr
    */
-  int Assign(const std::vector<Tensor *> &inputs, const std::vector<std::unique_ptr<OperatorCoder>> &nodes);
+  int Assign(const std::vector<Tensor *> &inputs, const std::vector<std::unique_ptr<OperatorCoder>> &nodes,
+             const std::vector<Tensor *> &all_tensors, const std::string &changeable_weights_name = {});
 
   // allocator holds the space malloced by opcoders, will free before session coder destroy
   void Free();
@@ -141,14 +142,18 @@ class MemoryAllocator {
   void *MallocWeightTensor(TypeId type_id, size_t size, MallocType type, const std::string &tensor_name = "");
   void MarkSharedWeight(const Tensor *src, void *pack_weight);
   void *GetSharedWeightAddr(const Tensor *src);
+  std::string GetAuxiliaryWeight(Tensor *src);
+  std::vector<Tensor *> origin_weights() const { return origin_weights_; }
+  std::map<Tensor *, std::pair<Tensor *, std::string>> auxiliary_weights() const { return auxiliary_weights_; }
 
 private:
   int AssignTensors(const std::vector<std::unique_ptr<OperatorCoder>> &nodes);
   void AssignGraphInputs(const std::vector<Tensor *> &inputs);
   void AssignWorkspaces(void *addr, size_t size);
-  void RecordOriginWeightsAddr(const std::vector<std::unique_ptr<OperatorCoder>> &nodes);
+  int RecordOriginWeightsAddr(const std::vector<Tensor *> &all_tensors,
+                              const std::string &changeable_weights_name = {});
   void RecordTensorsAddr(const std::map<Tensor *, size_t> &offsets);
-
+  int RecordChangeableWeights(Tensor *src);
   MemoryAllocator() = default;
   ~MemoryAllocator() = default;
 
@@ -160,11 +165,13 @@ class MemoryAllocator {
   bool is_next_{false};
   size_t offset_{0};
   std::vector<void *> allocated_;
+  std::vector<Tensor *> origin_weights_;
   std::map<std::string, Tensor *> saved_weights_addr_;
   std::map<Tensor *, std::string> origin_weights_addr_;
   std::map<Tensor *, std::string> malloc_weights_addr_;
   std::map<Tensor *, std::string> tensors_addr_;
   std::map<const Tensor *, void *> shared_pack_weights_;
+  std::map<Tensor *, std::pair<Tensor *, std::string>> auxiliary_weights_;
 };
 }  // namespace mindspore::lite::micro
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_ALLOCATOR_ALLOCATOR_H_
diff --git a/mindspore/lite/tools/converter/micro/coder/coder.cc b/mindspore/lite/tools/converter/micro/coder/coder.cc
index cca4687e..a94ac91b 100644
--- a/mindspore/lite/tools/converter/micro/coder/coder.cc
+++ b/mindspore/lite/tools/converter/micro/coder/coder.cc
@@ -93,25 +93,48 @@ bool Coder::InitPath(const std::string &output_path) {
 }
 
 int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path,
-                                     const std::string &codegen_mode, const std::string &device, bool support_parallel,
-                                     bool debug_mode, bool enableFp16) {
+                                     const MicroParam &param, bool enable_fp16) {
   flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
   auto offset = schema::MetaGraph::Pack(builder, &graph);
   builder.Finish(offset);
   schema::FinishMetaGraphBuffer(builder, offset);
   size_t size = builder.GetSize();
+  if (ExecuteMicroGeneration(builder.GetBufferPointer(), size, output_path, param, enable_fp16) != RET_OK) {
+    MS_LOG(ERROR) << "Execute Micro failed.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int Coder::MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path,
+                                     const MicroParam &param, bool enable_fp16) {
+  size_t buffer_size;
+  auto model_buf = lite::ReadFile(model_file.c_str(), &buffer_size);
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "Read model-file failed.";
+    return RET_NULL_PTR;
+  }
+  auto ret = ExecuteMicroGeneration(model_buf, buffer_size, output_path, param, enable_fp16);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Execute Micro failed.";
+  }
+  delete[] model_buf;
+  return ret;
+}
+
+int Coder::ExecuteMicroGeneration(const void *model_buf, size_t size, const std::string &output_path,
+                                  const MicroParam &param, bool enable_fp16) {
   micro::Coder code_gen;
   if (!code_gen.InitPath(output_path)) {
     MS_LOG(ERROR) << "Init path failed";
     return RET_ERROR;
   }
   // codegeneration for micro
-  STATUS status = code_gen.Init(codegen_mode, device, support_parallel, debug_mode);
+  STATUS status = code_gen.Init(param);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Codegen init Error";
     return RET_ERROR;
   }
-  status = code_gen.Run(builder.GetBufferPointer(), size, enableFp16);
+  status = code_gen.Run(model_buf, size, enable_fp16);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Codegen Run Error";
     return RET_ERROR;
@@ -120,29 +143,30 @@ int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std:
   return RET_OK;
 }
 
-int Coder::Init(const std::string code_mode, const std::string target, bool support_parallel, bool debug_mode) const {
+int Coder::Init(const MicroParam &param) const {
   static const std::map<std::string, Target> kTargetMap = {
     {"x86", kX86}, {"Cortex-M", kCortex_M}, {"ARM32", kARM32}, {"ARM64", kARM64}, {"All", kAllTargets}};
   static const std::map<std::string, CodeMode> kCodeModeMap = {{"Inference", Inference}, {"Train", Train}};
   Configurator *config = Configurator::GetInstance();
 
-  auto target_item = kTargetMap.find(target);
-  MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + target);
+  auto target_item = kTargetMap.find(param.target);
+  MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + param.target);
   config->set_target(target_item->second);
 
-  auto code_item = kCodeModeMap.find(code_mode);
-  MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + code_mode);
+  auto code_item = kCodeModeMap.find(param.codegen_mode);
+  MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + param.codegen_mode);
   config->set_code_mode(code_item->second);
 
-  if (support_parallel && config->target() == kCortex_M) {
+  if (param.support_parallel && config->target() == kCortex_M) {
     MS_LOG(ERROR) << "Cortex-M cannot support parallel.";
     return RET_ERROR;
   }
-  config->set_support_parallel(support_parallel);
-  config->set_debug_mode(debug_mode);
+  config->set_support_parallel(param.support_parallel);
+  config->set_debug_mode(param.debug_mode);
 
   config->set_proj_dir(model_name_);
-
+  config->set_keep_original_weight(param.keep_original_weight);
+  config->set_changeable_weights_name(param.changeable_weights_name);
   const std::string slash = std::string(kSlash);
   if (!save_path_.empty() && !DirExists(save_path_)) {
     MS_LOG(ERROR) << "code_gen code path " << save_path_ << " is not valid";
@@ -170,6 +194,7 @@ int Coder::Init(const std::string code_mode, const std::string target, bool supp
   print_parameter("codePath", config->code_path());
   print_parameter("codeMode", config->code_mode());
   print_parameter("debugMode", config->debug_mode());
+  print_parameter("keepOriginalWeight", config->keep_original_weight());
   return RET_OK;
 }
 }  // namespace mindspore::lite::micro
diff --git a/mindspore/lite/tools/converter/micro/coder/coder.h b/mindspore/lite/tools/converter/micro/coder/coder.h
index 96531e6f..0753156a 100644
--- a/mindspore/lite/tools/converter/micro/coder/coder.h
+++ b/mindspore/lite/tools/converter/micro/coder/coder.h
@@ -31,11 +31,14 @@ class Coder final {
 
   ~Coder() = default;
   static int MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path,
-                                       const std::string &codegen_mode, const std::string &device,
-                                       bool support_parallel, bool debug_mode, bool enableFp16);
+                                       const MicroParam &param, bool enable_fp16);
+  static int MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path,
+                                       const MicroParam &param, bool enable_fp16);
 
 private:
-  int Init(const std::string code_mode, const std::string target, bool support_parallel, bool debug_mode_) const;
+  static int ExecuteMicroGeneration(const void *model_buf, size_t size, const std::string &output_path,
+                                    const MicroParam &param, bool enable_fp16);
+  int Init(const MicroParam &param) const;
   int Run(const void *model_buff, size_t size, bool enableFp16);
   bool InitPath(const std::string &output_path);
   std::shared_ptr<CoderSession> session_{nullptr};
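For callers that drive the coder programmatically, the reworked entry point takes a single MicroParam instead of the old argument list. A minimal sketch (the include path and file names are illustrative; the MicroParam fields mirror config.h below):

    #include "coder/coder.h"

    int GenerateMicroProject() {
      mindspore::lite::micro::MicroParam param;
      param.target = "ARM64";
      param.codegen_mode = "Inference";
      param.support_parallel = false;
      param.debug_mode = false;
      param.keep_original_weight = true;
      param.changeable_weights_name = "my_embedding_table";  // hypothetical tensor name
      // Reads the .ms flatbuffer and emits the micro project sources plus net.bin.
      return mindspore::lite::micro::Coder::MicroSourceCodeGeneration("model.ms", "./micro_out/", param,
                                                                      /*enable_fp16=*/false);
    }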
diff --git a/mindspore/lite/tools/converter/micro/coder/config.h b/mindspore/lite/tools/converter/micro/coder/config.h
index 84285932..42e0f50e 100644
--- a/mindspore/lite/tools/converter/micro/coder/config.h
+++ b/mindspore/lite/tools/converter/micro/coder/config.h
@@ -26,9 +26,11 @@ enum CodeMode { Inference = 0, Train = 1, Code_Unknown = 99 };
 struct MicroParam {
   std::string codegen_mode = "Inference";
   std::string target;
+  std::string changeable_weights_name;
   bool enable_micro{false};
   bool support_parallel{false};
   bool debug_mode{false};
+  bool keep_original_weight{false};
 };
 
 class Configurator {
@@ -56,6 +58,12 @@ class Configurator {
   void set_proj_dir(std::string dir) { proj_dir_ = dir; }
   std::string proj_dir() const { return proj_dir_; }
 
+  void set_keep_original_weight(bool keep_weight) { keep_original_weight_ = keep_weight; }
+  bool keep_original_weight() const { return keep_original_weight_; }
+
+  void set_changeable_weights_name(const std::string &weights_name) { changeable_weights_name_ = weights_name; }
+  const std::string &changeable_weights_name() const { return changeable_weights_name_; }
+
 private:
   Configurator() = default;
   ~Configurator() = default;
@@ -64,7 +72,9 @@ class Configurator {
   CodeMode code_mode_{Code_Unknown};
   bool support_parallel_{false};
   bool debug_mode_{false};
+  bool keep_original_weight_{false};
   std::string proj_dir_;
+  std::string changeable_weights_name_;
 };
 }  // namespace mindspore::lite::micro
 #endif  // MICRO_CODER_CONFIG_H
diff --git a/mindspore/lite/tools/converter/micro/coder/context.h b/mindspore/lite/tools/converter/micro/coder/context.h
index 724475fe..cec385bb 100644
--- a/mindspore/lite/tools/converter/micro/coder/context.h
+++ b/mindspore/lite/tools/converter/micro/coder/context.h
@@ -69,6 +69,25 @@ class CoderContext {
   void set_saved_weights(const std::map<std::string, Tensor *> &saved_weights) { saved_weights_ = saved_weights; }
   std::map<std::string, Tensor *> saved_weights() const { return saved_weights_; }
 
+  void set_origin_weights(const std::vector<Tensor *> &origin_weights) { origin_weights_ = origin_weights; }
+  const std::vector<Tensor *> &origin_weights() const { return origin_weights_; }
+
+  void set_auxiliary_weights(const std::map<Tensor *, std::pair<Tensor *, std::string>> &auxiliary_weights) {
+    auxiliary_weights_ = auxiliary_weights;
+  }
+  const std::map<Tensor *, std::pair<Tensor *, std::string>> &auxiliary_weights() const { return auxiliary_weights_; }
+
+  // When keeping original weights, every packed weight must stem from a recorded origin weight.
+  bool JudgeIsValid(bool keep_origin_weight) {
+    if (!keep_origin_weight) {
+      return true;
+    }
+    return std::all_of(saved_weights_.begin(), saved_weights_.end(),
+                       [this](const std::pair<std::string, Tensor *> &item) {
+                         return std::find(this->origin_weights_.begin(), this->origin_weights_.end(), item.second) !=
+                                this->origin_weights_.end();
+                       });
+  }
+
   void set_total_buffer_size(size_t size) { total_buffer_size_ = size; }
   size_t total_buffer_size() const { return total_buffer_size_; }
 
@@ -107,7 +126,11 @@ class CoderContext {
 private:
   std::vector<Tensor *> graph_inputs_;
   std::vector<Tensor *> graph_outputs_;
-  // primitive const tensors, parsed from model, without packed.
+  // primitive const tensors, parsed from the model, not packed. Some of them may be unused.
+  std::vector<Tensor *> origin_weights_;
+  // auxiliary content (shape tensors) for origin weights, if needed.
+  std::map<Tensor *, std::pair<Tensor *, std::string>> auxiliary_weights_;
+  // primitive const tensors, parsed from the model, packed. All of these tensors are actually used.
   std::map<std::string, Tensor *> saved_weights_;
   // all tensors, include parsed from model and packed tensors.
   std::map<Tensor *, std::string> tensors_map_;
diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc
index 058f0ba0..d30e0133 100644
--- a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc
+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc
@@ -141,6 +141,7 @@ void CodeMSModelDestory(std::ofstream &ofs, const Configurator *config) {
   }
   ofs << "    MSTensorHandleArrayDestroy(micro_model->inputs);\n"
          "    MSTensorHandleArrayDestroy(micro_model->outputs);\n"
+         "    FreeResource();\n"
          "    free(*model);\n"
          "    *model = NULL;\n"
          "  }\n";
@@ -331,10 +332,12 @@ void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderCo
   }
   ofs << "  void **allocated[] = {\n";
   size_t num = 0;
+  auto &w_auxiliary = ctx->auxiliary_weights();
   for (const auto &item : ctx->tensors_map()) {
     Tensor *tensor = item.first;
     std::string name = item.second;
-    if (tensor->data() != nullptr && !(CheckConstantTensor(tensor))) {
+    if (tensor->data() != nullptr &&
+        (!(CheckConstantTensor(tensor)) || w_auxiliary.find(tensor) != w_auxiliary.end())) {
       ofs << "    (void**)&" << name << ",\n";
       num++;
     }
diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc
index 4d102391..d0824ecb 100644
--- a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc
+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc
@@ -22,6 +22,32 @@
 #include "coder/opcoders/parallel.h"
 
 namespace mindspore::lite::micro {
+namespace {
+// Order weight names by length first so that w2 sorts before w10; must be a strict weak ordering for std::map.
+struct cmp {
+  bool operator()(const std::string &a, const std::string &b) const {
+    return a.size() != b.size() ? a.size() < b.size() : a < b;
+  }
+};
+
+std::string GenerateArrayContent(const std::vector<size_t> &contents, const std::string &prefix) {
+  std::string lines;
+  std::string line = prefix;
+  for (auto content : contents) {
+    std::string append = std::to_string(content) + ", ";
+    if (line == prefix) {
+      line += append;
+      continue;
+    }
+    if (line.size() + append.size() > 120) {
+      lines += line + "\n";
+      line = prefix + append;
+    } else {
+      line += append;
+    }
+  }
+  lines += line + "\n";
+  return lines;
+}
+}  // namespace
+
 void CodeWeightFileHeader(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) {
   ofs << g_hwLicense;
   // include all operator header
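A predicate such as `a.size() < b.size() || a < b` would not qualify here: for a = "b" and b = "aa" both directions of the comparison return true, which violates the strict weak ordering std::map requires. A standalone illustration of the intended length-then-lexicographic order (a self-contained sketch, not part of the patch):

    #include <cassert>
    #include <map>
    #include <string>

    // Length-major, then lexicographic: index-suffixed names keep numeric order.
    struct cmp {
      bool operator()(const std::string &a, const std::string &b) const {
        return a.size() != b.size() ? a.size() < b.size() : a < b;
      }
    };

    int main() {
      std::map<std::string, int, cmp> m{{"weight10", 0}, {"weight9", 1}, {"weight2", 2}};
      assert(m.begin()->first == "weight2");    // shortest/lowest name first
      assert(m.rbegin()->first == "weight10");  // longest name last, after weight9
      return 0;
    }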
@@ -71,10 +97,11 @@ void CodeModelParamsData(std::ofstream &ofs, const std::map<std::string, Tensor
 void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::unique_ptr<CoderContext> &ctx,
                            const Configurator &config) {
   // reverse key and value of tensors_map
-  std::map<std::string, Tensor *> address_map;
+  std::map<std::string, Tensor *, cmp> address_map;
   for (const auto &item : ctx->tensors_map()) {
     address_map.insert(std::make_pair(item.second, item.first));
   }
+  auto &w_auxiliary = ctx->auxiliary_weights();
   for (auto &item : address_map) {
     std::string name = item.first;
     Tensor *tensor = item.second;
@@ -83,13 +110,22 @@ void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::
     }
     if (CheckConstantTensor(tensor)) {
       if (config.target() != kCortex_M) {
-        hofs << "extern " << GetTensorDataType(tensor->data_type()) << name << "[];\n";
-        cofs << GetTensorDataType(tensor->data_type()) << name << "[" << tensor->ElementsNum() << "];\n";
+        if (w_auxiliary.find(tensor) == w_auxiliary.end()) {
+          hofs << "extern " << GetTensorDataType(tensor->data_type()) << name << "[];  // " << tensor->tensor_name()
+               << std::endl;
+          cofs << GetTensorDataType(tensor->data_type()) << name << "[" << tensor->ElementsNum() << "];\n";
+        } else {
+          hofs << "extern " << GetTensorDataType(tensor->data_type()) << "*" << name << ";  // "
+               << tensor->tensor_name() << std::endl;
+          cofs << GetTensorDataType(tensor->data_type()) << "*" << name << " = NULL;\n";
+        }
       } else {
-        hofs << "extern const " << GetTensorDataType(tensor->data_type()) << name << "[];\n";
+        hofs << "extern const " << GetTensorDataType(tensor->data_type()) << name << "[];  // "
+             << tensor->tensor_name() << std::endl;
       }
     } else if (tensor->category() == lite::Category::VAR) {
-      hofs << "extern " << GetTensorDataType(tensor->data_type()) << "*" << name << ";\n";
+      hofs << "extern " << GetTensorDataType(tensor->data_type()) << "*" << name << ";  // " << tensor->tensor_name()
+           << std::endl;
       cofs << GetTensorDataType(tensor->data_type()) << "*" << name << " = NULL;\n";
     }
   }
@@ -104,6 +140,186 @@ void CodeInitWeightState(std::ofstream &ofs) {
      << "int Init(void *weight_buffer, int weight_size);\n\n";
 }
 
+void CodeWeightContentInit(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx,
+                           const std::map<Tensor *, int> &tensors_index) {
+  auto &w_auxiliary = ctx->auxiliary_weights();
+  std::map<std::string, Tensor *, cmp> real_need_tensors;
+  auto record_saved_tensors = ctx->saved_weights();
+  for (auto &item : record_saved_tensors) {
+    real_need_tensors.insert(std::make_pair(item.first, item.second));
+  }
+  std::string non_copy;
+  std::string copy_static;
+  std::string copy_dynamic;
+  int copy_static_num = 0;
+  int copy_dynamic_num = 0;
+  auto tensors_map = ctx->tensors_map();
+  for (const auto &item : real_need_tensors) {
+    if (!CheckConstantTensor(item.second) || item.second->data() == nullptr) {
+      continue;
+    }
+    auto iter = tensors_map.find(item.second);
+    if (iter == tensors_map.end()) {
+      TypeId data_type = item.second->data_type();
+      non_copy += "  " + GetTensorDataType(data_type) + "*" + item.first + " = (weight_buffer + offsets[" +
+                  std::to_string(tensors_index.at(item.second)) + "]);\n";
+      continue;
+    }
+    if (w_auxiliary.find(item.second) == w_auxiliary.end()) {
+      copy_static += "      {" + item.first + ", " + std::to_string(tensors_index.at(item.second)) + "},\n";
+      ++copy_static_num;
+    } else {
+      copy_dynamic += "      {&" + item.first + ", " + std::to_string(tensors_index.at(item.second)) + "},\n";
+      ++copy_dynamic_num;
+    }
+  }
+  for (const auto &item : w_auxiliary) {
+    copy_static += "      {" + item.second.second + ", " + std::to_string(tensors_index.at(item.second.first)) + "},\n";
+    ++copy_static_num;
+  }
+  ofs << non_copy << "\n";
+  if (copy_static_num > 0) {
+    ofs << "  {\n    struct ModelParameter static_copy[] = {\n" << copy_static << "    };\n";
+    ofs << "    for(int i = 0; i < " << copy_static_num << "; ++i) {\n"
        << "      int index = static_copy[i].index;\n"
        << "      if (offsets[index] + tensors_size[index] > weight_size) {\n"
           "        return RET_ERROR;\n"
           "      }\n"
        << "      memcpy(static_copy[i].addr, (weight_buffer + offsets[index]), tensors_size[index]);\n"
        << "    }\n  }\n\n";
+  }
+  ofs << "  size_t " << ctx->weight_size_name() << " = PackWeightSize() + dynamic_memory;\n";
+  ofs << "  if (" << ctx->weight_size_name() << " > 0) {\n";
+  ofs << "    " << ctx->weight_name() << " = malloc(" << ctx->weight_size_name() << ");\n";
+  ofs << "    if (" << ctx->weight_name() << " == NULL) {\n      return RET_ERROR;\n    }\n";
+  ofs << "    memset(" << ctx->weight_name() << ", 0, " << ctx->weight_size_name() << ");\n";
+  ofs << "  }\n";
+  ofs << "  size_t " << ctx->weight_offset_name() << " = 0;\n";
+  if (copy_dynamic_num > 0) {
+    ofs << "  {\n    struct ModelParameter dynamic_copy[] = {\n" << copy_dynamic << "    };\n";
+    ofs << "    for(int i = 0; i < " << copy_dynamic_num << "; ++i) {\n"
        << "      int index = dynamic_copy[i].index;\n"
        << "      memcpy(" << ctx->weight_name() << " + " << ctx->weight_offset_name()
        << ", (weight_buffer + offsets[index]), tensors_size[index]);\n"
        << "      *((void **)dynamic_copy[i].addr) = " << ctx->weight_name() << " + " << ctx->weight_offset_name()
        << ";\n"
        << "      " << ctx->weight_offset_name() << " += tensors_size[index];\n"
        << "    }\n  }\n\n";
+  }
+}
+
+void CodeWeightInitIfKeepWeight(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
+  auto &w_origin = ctx->origin_weights();
+  auto &w_auxiliary = ctx->auxiliary_weights();
+  std::vector<size_t> tensors_size;
+  std::vector<size_t> online_compute_index;
+  std::map<Tensor *, int> tensors_index;
+  for (auto tensor : w_origin) {
+    if (!(CheckConstantTensor(tensor)) || tensor->data() == nullptr) {
+      continue;
+    }
+    auto iter = w_auxiliary.find(tensor);
+    if (iter == w_auxiliary.end()) {
+      tensors_index[tensor] = tensors_size.size();
+      tensors_size.push_back(tensor->Size());
+    } else {
+      tensors_index[iter->second.first] = tensors_size.size();
+      tensors_size.push_back(iter->second.first->Size());
+      tensors_index[tensor] = tensors_size.size();
+      online_compute_index.push_back(tensors_size.size());
+      tensors_size.push_back(DataTypeSize(tensor->data_type()));
+    }
+  }
+  std::vector<size_t> offsets{0};
+  int last = online_compute_index.empty() ? tensors_size.size() - 1 : online_compute_index.front();
+  for (int i = 1; i <= last; ++i) {
+    offsets.push_back(offsets[i - 1] + tensors_size[i - 1]);
+  }
+  ofs << "int Init(void *weight_buffer, int weight_size) {\n"
      << "  if (weight_buffer == NULL) {\n"
      << "    return RET_ERROR;\n"
      << "  }\n";
+  ofs << "  struct ModelParameter {\n"
      << "    void *addr;\n"
      << "    int index;\n"
      << "  };\n";
+  ofs << "  int offsets[" << std::to_string(tensors_size.size()) << "] = {\n"
      << GenerateArrayContent(offsets, "    ") << "  };\n";
+  ofs << "  size_t tensors_size[" << std::to_string(tensors_size.size()) << "] = {\n"
      << GenerateArrayContent(tensors_size, "    ") << "  };\n";
+  ofs << "  size_t dynamic_memory = 0;\n";
+  offsets.insert(offsets.end(), tensors_size.size() - offsets.size(), 0);
+  if (!online_compute_index.empty()) {
+    ofs << "  int online_compute_index[] = {\n" << GenerateArrayContent(online_compute_index, "    ") << "  };\n";
+    ofs << "  for (size_t i = 0; i < " << std::to_string(online_compute_index.size()) + "; ++i) {\n";
+    ofs << "    int *shape = (int *)(weight_buffer + offsets[online_compute_index[i] - 1]);\n";
+    ofs << "    int dim_num = tensors_size[online_compute_index[i] - 1] / 4;\n";
+    ofs << "    size_t tensor_size = tensors_size[online_compute_index[i]];\n";
+    ofs << "    for (int j = 0; j < dim_num; ++j) {\n";
+    ofs << "      tensor_size *= shape[j];\n";
+    ofs << "    }\n";
+    ofs << "    tensors_size[online_compute_index[i]] = tensor_size;\n";
+    ofs << "    dynamic_memory += tensor_size;\n";
+    ofs << "    int next_index = (i + 1) < " << std::to_string(online_compute_index.size())
        << " ? online_compute_index[i + 1] : " << std::to_string(tensors_size.size()) << " - 1;\n";
+    ofs << "    for (int j = online_compute_index[i] + 1; j <= next_index; ++j) {\n";
+    ofs << "      offsets[j] = offsets[j - 1] + tensors_size[j - 1];\n";
+    ofs << "    }\n  }\n";
+  }
+  CodeWeightContentInit(ofs, ctx, tensors_index);
+}
+
+void CodeWeightInitIfNonKeepWeight(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) {
+  ofs << "int Init(void *weight_buffer, int weight_size) {\n"
      << "  if (weight_buffer == NULL) {\n"
      << "    return RET_ERROR;\n"
      << "  }\n";
+  ofs << "  struct ModelParameter {\n"
      << "    void *addr;\n"
      << "    size_t size;\n"
      << "    size_t offset;\n"
      << "  };\n";
+
+  ofs << "  size_t " << ctx->weight_size_name() << " = PackWeightSize();\n";
+  size_t params_num = 0;
+  size_t offset = 0;
+  std::string params;
+  std::string origins;
+  for (const auto &item : ctx->saved_weights()) {
+    std::string name = item.first;
+    Tensor *tensor = item.second;
+    if (!CheckConstantTensor(tensor)) {
+      continue;
+    }
+    std::map<Tensor *, std::string> ctx_tensor_map = ctx->tensors_map();
+    auto iter = ctx_tensor_map.find(tensor);
+    if (iter != ctx_tensor_map.end()) {
+      origins += "      {" + name + ", " + std::to_string(tensor->Size()) + ", " + std::to_string(offset) + "},\n";
+      params_num++;
+    } else {
+      TypeId data_type = tensor->data_type();
+      params +=
+        "  " + GetTensorDataType(data_type) + "*" + name + " = (weight_buffer + " + std::to_string(offset) + ");\n";
+    }
+    offset += tensor->Size();
+  }
+  ofs << params << "\n";
+  ofs << "  struct ModelParameter model_params[] = {\n" << origins << "  };\n";
+  ofs << "\n";
+  ofs << "  for(int i = 0; i < " << params_num << "; ++i) {\n"
      << "    if (model_params[i].offset + model_params[i].size > weight_size) {\n"
         "      return RET_ERROR;\n"
         "    }\n"
      << "    memcpy(model_params[i].addr, (weight_buffer + model_params[i].offset), model_params[i].size);\n"
      << "  }\n";
+  ofs << "  if (" << ctx->weight_size_name() << " > 0) {\n";
+  ofs << "    " << ctx->weight_name() << " = malloc(" << ctx->weight_size_name() << ");\n";
+  ofs << "    if (" << ctx->weight_name() << " == NULL) {\n      return RET_ERROR;\n    }\n";
+  ofs << "    memset(" << ctx->weight_name() << ", 0, " << ctx->weight_size_name() << ");\n";
+  ofs << "  }\n";
+  ofs << "  size_t " << ctx->weight_offset_name() << " = 0;\n";
+}
+
 void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) {
   if (config.target() != kCortex_M) {
     ofs << "static size_t PackWeightSize() {\n";
@@ -114,58 +330,16 @@ void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext>
     ofs << "  return w_size;\n";
     ofs << "}\n\n";
 
-    ofs << "int Init(void *weight_buffer, int weight_size) {\n"
-        << "  if (weight_buffer == NULL) {\n"
-        << "    return RET_ERROR;\n"
-        << "  }\n";
-    ofs << "  struct ModelParameter {\n"
-        << "    void *addr;\n"
-        << "    size_t size;\n"
-        << "    size_t offset;\n"
-        << "  };\n";
-
-    ofs << "  size_t " << ctx->weight_size_name() << " = PackWeightSize();\n";
-    size_t params_num = 0;
-    size_t offset = 0;
-    std::string params;
-    std::string origins;
-    for (const auto &item : ctx->saved_weights()) {
-      std::string name = item.first;
-      Tensor *tensor = item.second;
-      if (!CheckConstantTensor(tensor)) {
-        continue;
-      }
-      std::map<Tensor *, std::string> ctx_tensor_map = ctx->tensors_map();
-      auto iter = ctx_tensor_map.find(tensor);
-      if (iter != ctx_tensor_map.end()) {
-        origins += "      {" + name + ", " + std::to_string(tensor->Size()) + ", " + std::to_string(offset) + "},\n";
-        params_num++;
-      } else {
-        TypeId data_type = tensor->data_type();
-        params +=
-          "  " + GetTensorDataType(data_type) + "*" + name + " = (weight_buffer + " + std::to_string(offset) + ");\n";
-      }
-      offset += tensor->Size();
-    }
-    ofs << params << "\n";
-    ofs << "  struct ModelParameter model_params[] = {\n" << origins << "  };\n";
-    ofs << "\n";
-    ofs << "  for(int i = 0; i < " << params_num << "; ++i) {\n"
-        << "    if (model_params[i].offset + model_params[i].size > weight_size) {\n"
-           "      return RET_ERROR;\n"
-           "    }\n"
-        << "    memcpy(model_params[i].addr, (weight_buffer + model_params[i].offset), model_params[i].size);\n"
-        << "  }\n";
-    ofs << "  if (" << ctx->weight_size_name() << " > 0) {\n";
-    ofs << "    " << ctx->weight_name() << " = malloc(" << ctx->weight_size_name() << ");\n";
-    ofs << "    if (" << ctx->weight_name() << " == NULL) {\n      return RET_ERROR;\n    }\n";
-    ofs << "    memset(" << ctx->weight_name() << ", 0, " << ctx->weight_size_name() << ");\n";
-    ofs << "  }\n";
+    if (config.keep_original_weight()) {
+      CodeWeightInitIfKeepWeight(ofs, ctx);
+    } else {
+      CodeWeightInitIfNonKeepWeight(ofs, ctx);
+    }
  } else {
     ofs << "int Init(void *weight_buffer, int weight_size) {\n";
     ofs << "  const size_t w_size = " << ctx->weight_buffer_size() << ";\n";
+    ofs << "  size_t " << ctx->weight_offset_name() << " = 0;\n";
   }
-  ofs << "  size_t " << ctx->weight_offset_name() << " = 0;\n";
   for (const auto &block : ctx->init_contents()) {
     ofs << "{\n" << block << "}\n";
   }
@@ -175,11 +349,26 @@ void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr<CoderContext>
   ofs << "}\n\n";
 }
 
-void SaveDataToNet(const std::map<std::string, Tensor *> &saved_weights, const std::string &net_file) {
+void SaveDataToNet(const std::unique_ptr<CoderContext> &ctx, const std::string &net_file, bool keep_weight) {
   std::ofstream net(net_file, std::ios::out | std::ios::trunc | std::ios::binary);
   MS_CHECK_TRUE_WITHOUT_RET(net.is_open(), "net file open failed!");
-  for (auto &item : saved_weights) {
-    Tensor *tensor = item.second;
+  std::vector<Tensor *> save_tensors;
+  if (keep_weight) {
+    auto &w_origin = ctx->origin_weights();
+    auto &w_auxiliary = ctx->auxiliary_weights();
+    (void)std::for_each(w_origin.begin(), w_origin.end(), [&save_tensors, &w_auxiliary](Tensor *tensor) {
+      auto iter = w_auxiliary.find(tensor);
+      if (iter != w_auxiliary.end()) {
+        save_tensors.push_back(iter->second.first);
+      }
+      save_tensors.push_back(tensor);
+    });
+  } else {
+    auto recorded_saved_tensors = ctx->saved_weights();
+    (void)std::transform(recorded_saved_tensors.begin(), recorded_saved_tensors.end(),
+                         std::back_inserter(save_tensors),
+                         [](const std::pair<std::string, Tensor *> &item) { return item.second; });
+  }
+  for (auto tensor : save_tensors) {
     if ((CheckConstantTensor(tensor)) && tensor->data() != nullptr) {
       net.write(reinterpret_cast<const char *>(tensor->data()), tensor->Size());
     }
diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h
index 3a68a540..98c56afd 100644
--- a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h
+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h
@@ -31,7 +31,7 @@ void CodeWeightFileHeader(std::ofstream &ofs, const std::unique_ptr<CoderContext
 void CodeModelParamsState(std::ofstream &ofs, const std::map<std::string, Tensor *> &weights);
 void CodeModelParamsData(std::ofstream &ofs, const std::map<std::string, Tensor *> &weights);
 
-void SaveDataToNet(const std::map<std::string, Tensor *> &saved_weights, const std::string &net_file);
+void SaveDataToNet(const std::unique_ptr<CoderContext> &ctx, const std::string &net_file, bool keep_weight);
 void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::unique_ptr<CoderContext> &ctx,
                            const Configurator &config);
 
diff --git a/mindspore/lite/tools/converter/micro/coder/generator/generator.cc b/mindspore/lite/tools/converter/micro/coder/generator/generator.cc
index 5b29978f..8add577f 100644
--- a/mindspore/lite/tools/converter/micro/coder/generator/generator.cc
+++ b/mindspore/lite/tools/converter/micro/coder/generator/generator.cc
@@ -259,7 +259,7 @@ int Generator::CodeWeightFile() {
     cofs << "unsigned char * " << ctx_->buffer_name() << " = 0; \n";
     cofs << "unsigned char * " << ctx_->weight_name() << " = 0; \n";
     std::string net_file = net_src_file_path_ + "net.bin";
-    SaveDataToNet(ctx_->saved_weights(), net_file);
+    SaveDataToNet(ctx_, net_file, config_->keep_original_weight());
   } else {
     if (!ctx_->weight_buffer_size_code_blocks().empty()) {
       MS_LOG(ERROR) << "Weight init code generation error ";
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc
index d25b3e6b..56b22333 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc
@@ -41,13 +41,18 @@ int ReshapeBaseCoder::DoCode(CoderContext *const context) {
 }
 
 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Reshape, CPUOpCoderCreator<ReshapeBaseCoder>)
+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Reshape, CPUOpCoderCreator<ReshapeBaseCoder>)
 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Reshape, CPUOpCoderCreator<ReshapeBaseCoder>)
 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Flatten, CPUOpCoderCreator<ReshapeBaseCoder>)
+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Flatten, CPUOpCoderCreator<ReshapeBaseCoder>)
 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Flatten, CPUOpCoderCreator<ReshapeBaseCoder>)
 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_ExpandDims, CPUOpCoderCreator<ReshapeBaseCoder>)
+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_ExpandDims, CPUOpCoderCreator<ReshapeBaseCoder>)
 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_ExpandDims, CPUOpCoderCreator<ReshapeBaseCoder>)
 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Squeeze, CPUOpCoderCreator<ReshapeBaseCoder>)
+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Squeeze, CPUOpCoderCreator<ReshapeBaseCoder>)
 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Squeeze, CPUOpCoderCreator<ReshapeBaseCoder>)
 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Unsqueeze, CPUOpCoderCreator<ReshapeBaseCoder>)
+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Unsqueeze, CPUOpCoderCreator<ReshapeBaseCoder>)
 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Unsqueeze, CPUOpCoderCreator<ReshapeBaseCoder>)
 }  // namespace mindspore::lite::micro
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc
new file mode 100644
index 00000000..ee887342
--- /dev/null
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "coder/opcoders/base/stack_base_coder.h"
+#include <string>
+#include <vector>
+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
+#include "coder/opcoders/file_collector.h"
+#include "coder/opcoders/parallel.h"
+
+using mindspore::schema::PrimitiveType_Stack;
+
+namespace mindspore::lite::micro::nnacl {
+int StackFP32Coder::Prepare(CoderContext *const context) {
+  stack_param_ = reinterpret_cast<StackParameter *>(parameter_);
+  return ReSize();
+}
+
+int StackFP32Coder::ReSize() {
+  axis_ = stack_param_->axis_ >= 0 ? stack_param_->axis_
+                                   : static_cast<int>(input_tensor_->shape().size()) + stack_param_->axis_ + 1;
+  if (axis_ < 0 || axis_ > static_cast<int>(input_tensor_->shape().size())) {
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int StackFP32Coder::DoCode(CoderContext *const context) {
+  Collect(context,
+          {
+            "nnacl/base/stack_base.h",
+          },
+          {
+            "stack_base.c",
+          });
+
+  size_t input_num = input_tensors_.size();
+
+  NNaclFp32Serializer code;
+  code << "\t\tvoid *inputs_addr[] = {";
+  for (size_t i = 0; i < input_num; ++i) {
+    code << allocator_->GetRuntimeAddr(input_tensors_.at(i)) << ", ";
+  }
+  code << "};\n";
+
+  size_t copy_size = 0;
+  int outer_size = 1;
+  auto shape = input_tensor_->shape();
+  if (input_tensors_.empty()) {
+    copy_size = 0;
+    outer_size = 0;
+  } else if (input_tensors_.size() == 1) {
+    copy_size = input_tensor_->ElementsNum();
+    outer_size = 1;
+  } else {
+    copy_size = 1;
+    for (int i = axis_; i < static_cast<int>(shape.size()); ++i) {
+      copy_size *= shape[i];
+    }
+    for (int i = 0; i < axis_; ++i) {
+      outer_size *= shape[i];
+    }
+  }
+  copy_size *= DataTypeSize(input_tensor_->data_type());
+  code.CodeFunction("Stack", "inputs_addr", output_tensor_, input_num, copy_size, 0, outer_size);
+  context->AppendCode(code.str());
+  return RET_OK;
+}
+
+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Stack, CPUOpCoderCreator<StackFP32Coder>)
+REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Stack, CPUOpCoderCreator<StackFP32Coder>)
+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Stack, CPUOpCoderCreator<StackFP32Coder>)
+}  // namespace mindspore::lite::micro::nnacl
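A worked example of the size computation in DoCode above (assuming fp32 inputs): stacking four [2, 3] tensors with axis=1 keeps axis_ = 1, so copy_size = 3 * sizeof(float) = 12 bytes and outer_size = 2, and the generated call is Stack(inputs_addr, <output addr>, 4, 12, 0, 2) — for each of the 2 outer rows, 12 bytes are copied in turn from each of the 4 inputs. With axis=0, copy_size covers the whole input (24 bytes) and outer_size is 1.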
3645+ */ 3646+ 3647+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STACK_FP32_CODER_H_ 3648+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STACK_FP32_CODER_H_ 3649+ 3650+#include <vector> 3651+#include "coder/opcoders/op_coder.h" 3652+#include "nnacl/stack_parameter.h" 3653+ 3654+namespace mindspore::lite::micro::nnacl { 3655+class StackFP32Coder final : public OperatorCoder { 3656+ public: 3657+ StackFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 3658+ const LiteGraph::Node *node, size_t node_index, Target target) 3659+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 3660+ ~StackFP32Coder() override = default; 3661+ 3662+ int Prepare(CoderContext *const context) override; 3663+ int DoCode(CoderContext *const context) override; 3664+ 3665+ private: 3666+ int ReSize(); 3667+ 3668+ int axis_{0}; 3669+ StackParameter *stack_param_{nullptr}; 3670+}; 3671+} // namespace mindspore::lite::micro::nnacl 3672+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STACK_FP32_CODER_H_ 3673diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc 3674index ba9fbaa1..ffc70e1c 100644 3675--- a/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc 3676+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc 3677@@ -33,6 +33,8 @@ size_t GetInnerSize(TypeId type_id, int inner_elements) { 3678 return inner_elements * sizeof(float); 3679 case kNumberTypeInt32: 3680 return inner_elements * sizeof(int32_t); 3681+ case kNumberTypeFloat16: 3682+ return inner_elements * sizeof(uint16_t); 3683 default: 3684 MS_LOG(ERROR) << "Not supported data type: " << type_id; 3685 return 0; 3686@@ -142,6 +144,23 @@ int StridedSliceBaseCoder::DoFastCode(CoderContext *ctx) { 3687 } 3688 3689 int StridedSliceBaseCoder::DoNormalCode(CoderContext *ctx) { 3690+ switch (input_tensor_->data_type()) { 3691+ case kNumberTypeInt8: 3692+ strided_slice_parameter_->data_type = ::kNumberTypeInt8; 3693+ break; 3694+ case kNumberTypeFloat32: 3695+ strided_slice_parameter_->data_type = ::kNumberTypeFloat32; 3696+ break; 3697+ case kNumberTypeInt32: 3698+ strided_slice_parameter_->data_type = ::kNumberTypeInt32; 3699+ break; 3700+ case kNumberTypeFloat16: 3701+ strided_slice_parameter_->data_type = ::kNumberTypeFloat16; 3702+ break; 3703+ default: 3704+ MS_LOG(ERROR) << "Not supported data type: " << input_tensor_->data_type(); 3705+ return RET_ERROR; 3706+ } 3707 nnacl::NNaclFp32Serializer code; 3708 code.CodeStruct("strided_slice_parameter", *strided_slice_parameter_); 3709 code.CodeFunction("DoStridedSlice", input_tensor_, output_tensor_, 3710@@ -166,6 +185,8 @@ int StridedSliceBaseCoder::DoCode(CoderContext *ctx) { 3711 } 3712 REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_StridedSlice, 3713 CPUOpCoderCreator<StridedSliceBaseCoder>) 3714+REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_StridedSlice, 3715+ CPUOpCoderCreator<StridedSliceBaseCoder>) 3716 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_StridedSlice, CPUOpCoderCreator<StridedSliceBaseCoder>) 3717 REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_StridedSlice, CPUOpCoderCreator<StridedSliceBaseCoder>) 3718 } // namespace mindspore::lite::micro 3719diff --git 
a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
new file mode 100644
index 00000000..5470b56a
--- /dev/null
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h"
+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
+#include "coder/opcoders/file_collector.h"
+
+using mindspore::schema::PrimitiveType_Custom;
+
+namespace mindspore::lite::micro::nnacl {
+void CustomGruFP16Coder::InitNnaclFile(CoderContext *const context) {
+  Collect(context, {"nnacl/fp16/custom_gru_fp16.h"},
+          {"custom_gru_fp16.c", "pack_fp16.c", "matmul_fp16.c", "arithmetic_fp16.c", "activation_fp16.c"});
+}
+
+void CustomGruFP16Coder::InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst,
+                                         int row, int col) {
+  init_code->CodeFunction("RowMajor2Col8MajorFp16", src, dst, row, col, false);
+}
+
+REG_BUILIN_CUSTOM_CODER(kARM64, kNumberTypeFloat16, "CustomGRU", CPUOpCoderCreator<CustomGruFP16Coder>)
+}  // namespace mindspore::lite::micro::nnacl
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h
new file mode 100644
index 00000000..eb76faf6
--- /dev/null
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
3779+ */ 3780+ 3781+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CUSTOM_GRU_FP16_CODER_H_ 3782+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CUSTOM_GRU_FP16_CODER_H_ 3783+ 3784+#include <string> 3785+#include <vector> 3786+#include "coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h" 3787+#include "nnacl/custom_gru_parameter.h" 3788+ 3789+namespace mindspore::lite::micro::nnacl { 3790+class CustomGruFP16Coder : public CustomGruFP32Coder { 3791+ public: 3792+ CustomGruFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 3793+ const LiteGraph::Node *node, size_t node_index, Target target) 3794+ : CustomGruFP32Coder(in_tensors, out_tensors, node, node_index, target) { 3795+ data_type_ = kNumberTypeFloat16; 3796+ row_tile_ = C4NUM; 3797+ col_tile_ = C8NUM; 3798+ op_func_ = "CustomGruFp16"; 3799+ } 3800+ ~CustomGruFP16Coder() override = default; 3801+ 3802+ protected: 3803+ void InitNnaclFile(CoderContext *const context) override; 3804+ void InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst, int row, 3805+ int col) override; 3806+}; 3807+} // namespace mindspore::lite::micro::nnacl 3808+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CUSTOM_GRU_FP16_CODER_H_ 3809diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 3810index f2aec9d2..37b90b65 100644 3811--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 3812+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 3813@@ -30,13 +30,12 @@ int MatMulFP16BaseCoder::InitBiasData() { 3814 if (bias_ptr_) { 3815 return RET_OK; 3816 } 3817- bias_pack_ptr_size_ = static_cast<size_t>(params_->col_align_ * data_type_size_); 3818+ bias_pack_ptr_size_ = static_cast<size_t>(params_->col_align_ * DataTypeSize(data_type_)); 3819 if (input_tensors_.size() == C3NUM) { 3820- bias_ptr_ = allocator_->Malloc(kNumberTypeUInt8, kOnlineSize, kOnlinePackWeight, 3821- bias_tensor_->tensor_name() + "_online_pack"); 3822- } else { 3823 bias_ptr_ = 3824- allocator_->Malloc(kNumberTypeUInt8, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack"); 3825+ allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack"); 3826+ } else { 3827+ bias_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack"); 3828 } 3829 return RET_OK; 3830 } 3831@@ -45,18 +44,19 @@ int MatMulFP16BaseCoder::InitBufferA() { 3832 if (a_pack_ptr_ != nullptr || vec_matmul_) { 3833 return RET_OK; 3834 } 3835- a_pack_ptr_size_ = static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * sizeof(uint16_t)); 3836+ a_pack_ptr_size_ = 3837+ static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * DataTypeSize(data_type_)); 3838 if (params_->a_const_) { 3839 a_pack_ptr_ = allocator_->GetSharedWeightAddr(input_tensors_.at(0)); 3840 if (a_pack_ptr_ == nullptr) { 3841- a_pack_ptr_ = allocator_->Malloc(kNumberTypeFloat16, kOnlineSize, kOnlinePackWeight, 3842+ a_pack_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, 3843 input_tensors_.at(0)->tensor_name() + "_online_pack"); 3844 allocator_->MarkSharedWeight(input_tensors_.at(0), a_pack_ptr_); 3845 } else { 3846 a_packed_ = true; 3847 
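+      // The shared weight was already packed by the first node that emitted it, so this node skips repacking matrix A.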
} 3848 } else { 3849- a_pack_ptr_ = allocator_->Malloc(kNumberTypeFloat16, a_pack_ptr_size_, kWorkspace); 3850+ a_pack_ptr_ = allocator_->Malloc(data_type_, a_pack_ptr_size_, kWorkspace); 3851 } 3852 MS_CHECK_PTR(a_pack_ptr_); 3853 return RET_OK; 3854@@ -77,7 +77,7 @@ std::string MatMulFP16BaseCoder::InitMatrixA(NNaclFp32Serializer *const code, NN 3855 return allocator_->GetRuntimeAddr(input_tensor_, input_tensor_->IsConst()); 3856 } 3857 std::string input_a_str = allocator_->GetRuntimeAddr(input_tensor_); 3858- std::string input_a_pack_str = "(float16_t *)" + allocator_->GetRuntimeAddr(a_pack_ptr_); 3859+ std::string input_a_pack_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(a_pack_ptr_)); 3860 if (params_->a_const_) { 3861 init_code->CodeBufferOffsetExpression(a_pack_ptr_, context->weight_name(), context->weight_offset_name(), 3862 context->weight_size_name(), a_pack_ptr_size_); 3863@@ -132,7 +132,7 @@ std::string MatMulFP16BaseCoder::InitMatrixB(NNaclFp32Serializer *const code, NN 3864 return allocator_->GetRuntimeAddr(filter_tensor_, filter_tensor_->IsConst()); 3865 } 3866 std::string input_b_str = allocator_->GetRuntimeAddr(filter_tensor_); 3867- std::string input_b_pack_str = "(float16_t *)" + allocator_->GetRuntimeAddr(b_pack_ptr_); 3868+ std::string input_b_pack_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(b_pack_ptr_)); 3869 if (params_->b_const_) { 3870 init_code->CodeBufferOffsetExpression(b_pack_ptr_, context->weight_name(), context->weight_offset_name(), 3871 context->weight_size_name(), b_pack_ptr_size_); 3872@@ -248,7 +248,7 @@ int MatMulFP16BaseCoder::DoCode(CoderContext *const context) { 3873 init_code.CodeBufferOffsetExpression(bias_ptr_, context->weight_name(), context->weight_offset_name(), 3874 context->weight_size_name(), bias_pack_ptr_size_); 3875 w_buf_size += bias_pack_ptr_size_; 3876- std::string bias_str = "(float16_t *)" + allocator_->GetRuntimeAddr(bias_ptr_); 3877+ std::string bias_str = allocator_->GetRuntimeAddr(bias_ptr_); 3878 if (input_tensors_.size() == DIMENSION_3D) { 3879 auto origin_bias_str = allocator_->GetRuntimeAddr(bias_tensor_); 3880 init_code.CodeFunction("memcpy", bias_str, origin_bias_str, bias_tensor_->Size()); 3881diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h 3882index 864f54ae..38270456 100644 3883--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h 3884+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h 3885@@ -27,7 +27,9 @@ class MatMulFP16BaseCoder : public MatMulFP32BaseCoder { 3886 public: 3887 MatMulFP16BaseCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 3888 const LiteGraph::Node *node, size_t node_index, Target target) 3889- : MatMulFP32BaseCoder(in_tensors, out_tensors, node, node_index, target) {} 3890+ : MatMulFP32BaseCoder(in_tensors, out_tensors, node, node_index, target) { 3891+ data_type_ = kNumberTypeFloat16; 3892+ } 3893 3894 ~MatMulFP16BaseCoder() override = default; 3895 3896diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h 3897index 3a1cb66a..c5ea36cd 100644 3898--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h 3899+++ 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h
@@ -26,9 +26,7 @@ class MatMulFP16Coder final : public MatMulFP16BaseCoder {
  public:
   MatMulFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
                   const LiteGraph::Node *node, size_t node_index, Target target)
-      : MatMulFP16BaseCoder(in_tensors, out_tensors, node, node_index, target) {
-    data_type_size_ = sizeof(uint16_t);
-  }
+      : MatMulFP16BaseCoder(in_tensors, out_tensors, node, node_index, target) {}
 
   ~MatMulFP16Coder() override = default;
 
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
index f46005c6..e53472ca 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc
@@ -20,42 +20,88 @@
 #include "coder/opcoders/parallel.h"
 #include "coder/opcoders/file_collector.h"
 #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
+#include "src/common/utils.h"
 
 namespace mindspore::lite::micro::nnacl {
 int ConvolutionDepthwiseFP32Coder::Prepare(CoderContext *const context) {
   MS_CHECK_RET_CODE(Conv2DBaseCoder::Init(), "Conv2DBaseCoder::Init() failed!");
-  MS_CHECK_RET_CODE(InitWeightBias(), "dwconvolution do init weightbais failed");
+  MS_CHECK_RET_CODE(InitParameter(), "dwconvolution do InitParameter failed");
+  if (Configurator::GetInstance()->keep_original_weight()) {
+    MS_CHECK_RET_CODE(InitWeightBiasOnline(), "dwconvolution do InitWeightBiasOnline failed");
+  } else {
+    MS_CHECK_RET_CODE(InitWeightBiasOffline(), "dwconvolution do InitWeightBiasOffline failed");
+  }
   conv_param_->thread_num_ = MSMIN(thread_num_, conv_param_->output_h_);
   return RET_OK;
 }
 
-int ConvolutionDepthwiseFP32Coder::InitWeightBias() {
+int ConvolutionDepthwiseFP32Coder::InitParameter() {
+  auto shape = filter_tensor_->shape();
+  MS_CHECK_TRUE_MSG(shape.size() == C4NUM, RET_ERROR, "Conv: filter-weight's shape must be 4D.");
+  packed_weight_size_ =
+    filter_tensor_->Batch() * filter_tensor_->Height() * filter_tensor_->Width() * DataTypeSize(data_type_);
+  packed_bias_size_ = filter_tensor_->Batch() * DataTypeSize(data_type_);
+  return RET_OK;
+}
+
+int ConvolutionDepthwiseFP32Coder::InitWeightBiasOffline() {
   auto *origin_weight = reinterpret_cast<float *>(filter_tensor_->data());
   MS_CHECK_PTR(origin_weight);
   int channel = filter_tensor_->Batch();
-  size_t pack_weight_size = filter_tensor_->Batch() * filter_tensor_->Height() * filter_tensor_->Width();
-  size_t packed_weight_data_size = pack_weight_size * sizeof(float);
-  packed_weight_ =
-    reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, packed_weight_data_size, kOfflinePackWeight));
+  packed_weight_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, packed_weight_size_, kOfflinePackWeight));
   MS_CHECK_PTR(packed_weight_);
-  MS_CHECK_RET_CODE(memset_s(packed_weight_, packed_weight_data_size, 0, packed_weight_data_size),
+  MS_CHECK_RET_CODE(memset_s(packed_weight_, packed_weight_size_, 0, packed_weight_size_),
                     "memset packed weight failed!");
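+  // Depthwise weights arrive as NCHW; the pack below lays them out NHWC so each spatial tap keeps its channels contiguous.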
 PackNCHWToNHWCFp32(origin_weight, packed_weight_, 1, filter_tensor_->Height() * filter_tensor_->Width(), channel,
                    kDefaultTaskId, 0);
 
-  auto bias_size = static_cast<size_t>(channel * sizeof(float));
-  bias_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, bias_size, kOfflinePackWeight));
+  bias_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, packed_bias_size_, kOfflinePackWeight));
   MS_CHECK_PTR(bias_);
-  MS_CHECK_RET_CODE(memset_s(bias_, bias_size, 0, bias_size), "memset bias failed!");
+  MS_CHECK_RET_CODE(memset_s(bias_, packed_bias_size_, 0, packed_bias_size_), "memset bias failed!");
   // init bias
   if (input_tensors_.size() == kInputSize2) {
     auto *ori_bias = reinterpret_cast<float *>(bias_tensor_->data());
     MS_CHECK_TRUE(bias_tensor_->ElementsNum() > 0, "invalid bias length");
-    MS_CHECK_RET_CODE(memcpy_s(bias_, bias_size, ori_bias, bias_tensor_->Size()), "memcpy_s bias failed!");
+    MS_CHECK_RET_CODE(memcpy_s(bias_, packed_bias_size_, ori_bias, bias_tensor_->Size()), "memcpy_s bias failed!");
   }
   return RET_OK;
 }
 
+int ConvolutionDepthwiseFP32Coder::InitWeightBiasOnline() {
+  packed_weight_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight));
+  MS_CHECK_PTR(packed_weight_);
+  bias_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight));
+  MS_CHECK_PTR(bias_);
+  return RET_OK;
+}
+
+void ConvolutionDepthwiseFP32Coder::InitCodeOnline(CoderContext *const context) {
+  if (!Configurator::GetInstance()->keep_original_weight()) {
+    return;
+  }
+  Collect(context,
+          {
+            "nnacl/fp32/pack_fp32.h",
+          },
+          {"pack_fp32.c"});
+  NNaclFp32Serializer init_code;
+  init_code.CodeBufferOffsetExpression(packed_weight_, context->weight_name(), context->weight_offset_name(),
+                                       context->weight_size_name(), packed_weight_size_);
+  auto filter_str = allocator_->GetRuntimeAddr(filter_tensor_);
+  init_code.CodeFunction("PackNCHWToNHWCFp32", filter_str, packed_weight_, 1,
+                         filter_tensor_->Height() * filter_tensor_->Width(), filter_tensor_->Batch(), 0, 0);
+  init_code.CodeBufferOffsetExpression(bias_, context->weight_name(), context->weight_offset_name(),
+                                       context->weight_size_name(), packed_bias_size_);
+  if (input_tensors_.size() == kInputSize2) {
+    auto bias_str = allocator_->GetRuntimeAddr(bias_tensor_);
+    init_code.CodeFunction("memcpy", bias_, bias_str, bias_tensor_->Size());
+  } else {
+    init_code.CodeFunction("memset", bias_, 0, packed_bias_size_);
+  }
+  context->AppendInitWeightSizeCode(packed_weight_size_ + packed_bias_size_);
+  context->AppendInitCode(init_code.str());
+}
+
 int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) {
   MS_CHECK_TRUE(conv_param_->input_channel_ == conv_param_->output_channel_,
                 "Only support input channel equals output channel.");
@@ -78,6 +124,7 @@ int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) {
         "activation_fp32.c",
       },
       {});
+  InitCodeOnline(context);
   nnacl::NNaclFp32Serializer code;
   // call the op function
   std::string param_name = "conv_parameter";
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h
4024index a5827f4f..39757871 100644 4025--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h 4026+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h 4027@@ -34,9 +34,15 @@ class ConvolutionDepthwiseFP32Coder final : public Conv2DBaseCoder { 4028 int DoCode(CoderContext *const context) override; 4029 4030 private: 4031- int InitWeightBias(); 4032+ int InitParameter(); 4033+ int InitWeightBiasOffline(); 4034+ int InitWeightBiasOnline(); 4035+ void InitCodeOnline(CoderContext *const context); 4036+ size_t packed_weight_size_{0}; 4037 float *packed_weight_{nullptr}; 4038+ size_t packed_bias_size_{0}; 4039 float *bias_{nullptr}; 4040+ TypeId data_type_{kNumberTypeFloat32}; 4041 }; 4042 } // namespace mindspore::lite::micro::nnacl 4043 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CONVOLUTION_DEPTHWISE_FP32_CODER_H_ 4044diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc 4045index 556f851a..466db21a 100644 4046--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc 4047+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc 4048@@ -77,13 +77,6 @@ const std::array<std::string, kEight> OutputTransFuncRelu6List8 = {"", 4049 "OutputTransform8x6Relu6Unit", 4050 "OutputTransform8x7Relu6Unit"}; 4051 4052-int ConvolutionWinogradFP32Coder::WinogradFilterTransform(const float *weight_data, float *matrix_g, 4053- const float *matrix_gt, int oc_block) { 4054- MS_CHECK_TRUE(oc_block, "Divide by zero!"); 4055- return WinogradWeightTransform(weight_data, trans_weight_, matrix_g, matrix_gt, oc_block, input_unit_, kernel_unit_, 4056- conv_param_->input_channel_, conv_param_->output_channel_, true); 4057-} 4058- 4059 int ConvolutionWinogradFP32Coder::InitTmpBuffer() { 4060 int channel_out = conv_param_->output_channel_; 4061 int oc8 = UP_DIV(channel_out, C8NUM); 4062@@ -115,12 +108,16 @@ int ConvolutionWinogradFP32Coder::Prepare(CoderContext *const context) { 4063 input_unit_ = output_unit_ + kernel_unit_ - 1; 4064 conv_param_->input_unit_ = input_unit_; 4065 conv_param_->output_unit_ = output_unit_; 4066- ret = InitWeightBias(); 4067- MS_CHECK_RET_CODE(ret, "Init weight bias failed."); 4068+ MS_CHECK_RET_CODE(InitParameter(), "Winograd convolution do InitParameter failed"); 4069+ if (Configurator::GetInstance()->keep_original_weight()) { 4070+ MS_CHECK_RET_CODE(InitWeightBiasOnline(), "Winograd convolution do InitWeightBiasOnline failed"); 4071+ } else { 4072+ MS_CHECK_RET_CODE(InitWeightBiasOffline(), "Winograd convolution do InitWeightBiasOffline failed"); 4073+ } 4074 return ReSize(); 4075 } // namespace micro 4076 4077-int ConvolutionWinogradFP32Coder::InitWeightBias() { 4078+int ConvolutionWinogradFP32Coder::InitParameter() { 4079 int in_channel = filter_tensor_->Channel(); 4080 int out_channel = filter_tensor_->Batch(); 4081 MS_CHECK_TRUE(in_channel > 0, "invalid in channel size"); 4082@@ -132,14 +129,10 @@ int ConvolutionWinogradFP32Coder::InitWeightBias() { 4083 const int oc_block = C8NUM; 4084 int oc_block_num = UP_DIV(out_channel, C8NUM); 4085 // init weight 4086- int trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block; 4087- trans_weight_ = reinterpret_cast<float *>( 4088- 
allocator_->Malloc(kNumberTypeFloat32, trans_matrix_data_size * sizeof(float), kOfflinePackWeight));
-  MS_CHECK_PTR(trans_weight_);
-  int ret = memset_s(trans_weight_, trans_matrix_data_size * sizeof(float), 0, trans_matrix_data_size * sizeof(float));
-  MS_CHECK_RET_CODE(ret, "memset_s failed!");
-  float matrix_g[k64];
-  float matrix_gt[k64];
+  trans_weight_size_ = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * DataTypeSize(data_type_);
+  packed_bias_size_ = oc4 * C4NUM * DataTypeSize(data_type_);
+  matrix_g_.resize(k64);
+  matrix_gt_.resize(k64);
   float matrix_a[k64];
   float matrix_at[k64];
   float matrix_b[k64];
@@ -148,31 +141,41 @@
   if (input_unit_ == DIMENSION_8D) {
     coef = 0.5f;
   }
-  ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, coef, output_unit_, kernel_unit_);
+  auto ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g_.data(), matrix_gt_.data(), coef,
+                            output_unit_, kernel_unit_);
   MS_CHECK_RET_CODE(ret, "CookToomFilter failed!");
-  auto out_channel_size = static_cast<size_t>(out_channel);
+  return RET_OK;
+}
+
+int ConvolutionWinogradFP32Coder::InitWeightBiasOffline() {
+  trans_weight_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, trans_weight_size_, kOfflinePackWeight));
+  MS_CHECK_PTR(trans_weight_);
+  int ret = memset_s(trans_weight_, trans_weight_size_, 0, trans_weight_size_);
   auto weight_data = reinterpret_cast<float *>(filter_tensor_->data());
   MS_CHECK_PTR(weight_data);
-  ret = WinogradFilterTransform(weight_data, matrix_g, matrix_gt, oc_block);
+  ret = WinogradWeightTransform(weight_data, trans_weight_, matrix_g_.data(), matrix_gt_.data(), C8NUM, input_unit_,
+                                kernel_unit_, conv_param_->input_channel_, conv_param_->output_channel_, true);
   MS_CHECK_RET_CODE(ret, "winograd filter transform failed!");
-  // init bias
-  int new_bias_ele_num = oc4 * C4NUM;
-  auto new_bias_ele_size = static_cast<size_t>(new_bias_ele_num * sizeof(float));
-  new_bias_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, new_bias_ele_size, kOfflinePackWeight));
+  new_bias_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, packed_bias_size_, kOfflinePackWeight));
   MS_CHECK_PTR(new_bias_);
-  ret = memset_s(new_bias_, new_bias_ele_size, 0, new_bias_ele_size);
+  ret = memset_s(new_bias_, packed_bias_size_, 0, packed_bias_size_);
   MS_CHECK_RET_CODE(ret, "memset_s failed!");
   if (input_tensors_.size() == kInputSize2) {
     auto ori_bias_addr = reinterpret_cast<float *>(bias_tensor_->data());
     MS_CHECK_PTR(ori_bias_addr);
-    MS_CHECK_RET_CODE(memcpy_s(new_bias_, new_bias_ele_size, ori_bias_addr, out_channel_size * sizeof(float)),
-                      "memcpy_s failed!");
-  } else {
-    MS_CHECK_RET_CODE(memset_s(new_bias_, new_bias_ele_size, 0, new_bias_ele_size), "memset_s failed!");
+    MS_CHECK_RET_CODE(memcpy_s(new_bias_, packed_bias_size_, ori_bias_addr, bias_tensor_->Size()), "memcpy_s failed!");
   }
   return RET_OK;
 }
 
+int ConvolutionWinogradFP32Coder::InitWeightBiasOnline() {
+  trans_weight_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight));
+  MS_CHECK_PTR(trans_weight_);
+  new_bias_ = reinterpret_cast<float *>(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight));
+  MS_CHECK_PTR(new_bias_);
+  return RET_OK;
+}
+
 int ConvolutionWinogradFP32Coder::ConfigInputOutput() {
   trans_func_str_.in_func_ = GetInputTransFunc(input_unit_);
   MS_CHECK_TRUE(!trans_func_str_.in_func_.empty(), "Get input_trans_func failed.");
@@ -217,6 +220,36 @@ std::string ConvolutionWinogradFP32Coder::GetOutputTransFunc(int input_unit, int
   }
 }
 
+void ConvolutionWinogradFP32Coder::InitCodeOnline(CoderContext *const context) {
+  if (!Configurator::GetInstance()->keep_original_weight()) {
+    return;
+  }
+  Collect(context,
+          {
+            "nnacl/base/minimal_filtering_generator.h",
+            "nnacl/fp32/pack_fp32.h",
+          },
+          {"minimal_filtering_generator.c", "pack_fp32.c"});
+  NNaclFp32Serializer init_code;
+  init_code.CodeBufferOffsetExpression(trans_weight_, context->weight_name(), context->weight_offset_name(),
+                                       context->weight_size_name(), trans_weight_size_);
+  auto filter_str = allocator_->GetRuntimeAddr(filter_tensor_);
+  init_code.CodeArray("matrix_g", matrix_g_.data(), k64);
+  init_code.CodeArray("matrix_gt", matrix_gt_.data(), k64);
+  init_code.CodeFunction("WinogradWeightTransform", filter_str, trans_weight_, "matrix_g", "matrix_gt", C8NUM,
+                         input_unit_, kernel_unit_, conv_param_->input_channel_, conv_param_->output_channel_, true);
+  init_code.CodeBufferOffsetExpression(new_bias_, context->weight_name(), context->weight_offset_name(),
+                                       context->weight_size_name(), packed_bias_size_);
+  if (input_tensors_.size() == kInputSize2) {
+    auto bias_str = allocator_->GetRuntimeAddr(bias_tensor_);
+    init_code.CodeFunction("memcpy", new_bias_, bias_str, bias_tensor_->Size());
+  } else {
+    init_code.CodeFunction("memset", new_bias_, 0, packed_bias_size_);
+  }
+  context->AppendInitWeightSizeCode(trans_weight_size_ + packed_bias_size_);
+  context->AppendInitCode(init_code.str());
+}
+
 int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) {
   Collect(context,
           {
@@ -253,6 +286,7 @@ int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) {
   } else if (target_ == kARM64) {
     Collect(context, {}, {},
             {
+              "BigMatmulFp32Opt.S",
               "MatmulFp32.S",
               "MatmulFp32Opt.S",
               "PreSum4x16Int8Peroc.S",
@@ -263,14 +297,14 @@ int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) {
               "MatmulInt8.S",
             });
   }
-
+  InitCodeOnline(context);
   NNaclFp32Serializer code;
   // call the op function
   code.CodeFunction("memset", trans_input_, "0", tile_buffer_size_);
   code.CodeFunction("memset", gemm_out_, "0", gemm_out_size_);
   code.CodeFunction("memset", tmp_data_, "0", tmp_data_size_);
   code.CodeFunction("memset", col_buffer_, "0", col_buffer_size_);
-  code << "\t\tfloat *tmp_buffer_address_list[4] = {" << allocator_->GetRuntimeAddr(trans_input_) << ", "
+  code << "    float *tmp_buffer_address_list[4] = {" << allocator_->GetRuntimeAddr(trans_input_) << ", "
        << allocator_->GetRuntimeAddr(gemm_out_) << ", " << allocator_->GetRuntimeAddr(tmp_data_) << ", "
        << allocator_->GetRuntimeAddr(col_buffer_) << "};\n";
   code.CodeStruct("conv_parameter", *conv_param_);
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h
index d583312a..a4a0438f 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h
@@ -38,7 +38,13 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder {
   ~ConvolutionWinogradFP32Coder() override = default;
 
  private:
-  int InitWeightBias();
+  int InitParameter();
+
+  int InitWeightBiasOffline();
+
+  int InitWeightBiasOnline();
+
+  void InitCodeOnline(CoderContext *const context);
 
   int ConfigInputOutput();
 
@@ -46,13 +52,13 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder {
 
   int ReSize();
 
-  int WinogradFilterTransform(const float *weight_data, float *matrix_g, const float *matrix_gt, int oc_block);
-
   std::string GetInputTransFunc(int input_unit);
 
   std::string GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);
 
+  size_t trans_weight_size_{0};
   float *trans_weight_{nullptr};
+  size_t packed_bias_size_{0};
   float *new_bias_{nullptr};
 
   int kernel_unit_{0};
@@ -70,6 +76,9 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder {
   float *col_buffer_{nullptr};
 
   TransFuncStr trans_func_str_;
+  TypeId data_type_{kNumberTypeFloat32};
+  std::vector<float> matrix_g_;
+  std::vector<float> matrix_gt_;
 };
 }  // namespace mindspore::lite::micro::nnacl
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CONVOLUTION_WINOGRAD_FP32_CODER_H_
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
new file mode 100644
index 00000000..50146e72
--- /dev/null
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
@@ -0,0 +1,214 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h"
+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
+#include "coder/opcoders/file_collector.h"
+#include "nnacl/custom_gru_parameter.h"
+
+using mindspore::schema::PrimitiveType_Custom;
+
+namespace mindspore::lite::micro::nnacl {
+namespace {
+constexpr size_t kOutputNum = 3;
+constexpr size_t kInputDims = 3;
+constexpr size_t kWeightDims = 2;
+constexpr size_t kInputSize = 6;
+}  // namespace
+int CustomGruFP32Coder::Prepare(CoderContext *const context) {
+  if (input_tensors_.size() != kInputSize) {
+    MS_LOG(ERROR) << "built-in CustomGru must have 6 inputs." << node_->name_;
+    return RET_ERROR;
+  }
+  for (size_t i = 1; i < kInputSize - 1; ++i) {
+    if (!input_tensors_[i]->IsConst()) {
+      MS_LOG(ERROR) << "built-in CustomGru only supports variable first and last inputs." << node_->name_;
+      return RET_NOT_SUPPORT;
+    }
+  }
+  if (InitParameter() != RET_OK) {
+    MS_LOG(ERROR) << "Init built-in CustomGru Parameter failed." << node_->name_;
+    return RET_ERROR;
+  }
+  return ReSize();
+}
+
+int CustomGruFP32Coder::InitParameter() {
+  param_ = reinterpret_cast<CustomGruParameter *>(parameter_);
+  param_->op_parameter_.thread_num_ = 1;
+  auto weight_in_shape = input_tensors_[1]->shape();
+  auto weight_hidden_shape = input_tensors_[C2NUM]->shape();
+  if (weight_in_shape.size() != kWeightDims || weight_hidden_shape.size() != kWeightDims) {
+    MS_LOG(ERROR) << "built-in CustomGru's weight must be 2D." << node_->name_;
+    return RET_ERROR;
+  }
+  if (weight_in_shape[0] != weight_hidden_shape[0]) {
+    MS_LOG(ERROR) << "Built-in CustomGru's weight-in and weight-hidden first-dim must be the same." << node_->name_;
+    return RET_ERROR;
+  }
+  if (weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) {
+    MS_LOG(ERROR) << "Built-in CustomGru's weight-hidden first-dim must be 3 * second-dim." << node_->name_;
+    return RET_ERROR;
+  }
+  auto bias_in_shape = input_tensors_[C3NUM]->shape();
+  auto bias_hidden_shape = input_tensors_[C4NUM]->shape();
+  if (bias_in_shape.size() != 1) {
+    MS_LOG(ERROR) << "built-in CustomGru's bias must be 1D." << node_->name_;
+    return RET_ERROR;
+  }
+  if (bias_in_shape != bias_hidden_shape) {
+    MS_LOG(ERROR) << "built-in CustomGru's bias-in and bias-hidden must have the same shape." << node_->name_;
+    return RET_ERROR;
+  }
+  if (bias_in_shape.back() != weight_in_shape.front()) {
+    MS_LOG(ERROR) << "built-in CustomGru's bias-in shape doesn't match the first-dim of weight." << node_->name_;
+    return RET_ERROR;
+  }
+  if (bias_in_shape.front() % C3NUM != 0) {
+    MS_LOG(ERROR) << "The first-dim of CustomGru's weight must be 3 * hidden.";
+    return RET_ERROR;
+  }
+  param_->input_size = weight_in_shape.back();
+  param_->hidden_size = bias_in_shape.front() / C3NUM;
+  return RET_OK;
+}
+
+int CustomGruFP32Coder::ReSize() {
+  auto in_shape = input_tensor_->shape();
+  if (in_shape.size() != kInputDims) {
+    MS_LOG(ERROR) << "built-in CustomGru's first-input must be 3D." << node_->name_;
+    return RET_ERROR;
+  }
+  param_->num_step = in_shape[0];
+  param_->batch_size = in_shape[1];
+  if (in_shape.back() != param_->input_size) {
+    MS_LOG(ERROR) << "built-in CustomGru's first-input doesn't match its weight."
<< node_->name_; 4365+ return RET_ERROR; 4366+ } 4367+ return InitWeightAndBias(); 4368+} 4369+ 4370+int CustomGruFP32Coder::InitWeightAndBias() { 4371+ auto col_align = UP_ROUND(param_->hidden_size, col_tile_); 4372+ auto data_type_size = DataTypeSize(data_type_); 4373+ bias_pack_size_ = col_align * data_type_size; 4374+ weight_in_pack_size_ = static_cast<size_t>(col_align * param_->input_size) * data_type_size; 4375+ weight_input_ = allocator_->Malloc(data_type_, weight_in_pack_size_ * C3NUM, kOnlinePackWeight, 4376+ input_tensors_.at(1)->tensor_name() + "_online_pack"); 4377+ MS_CHECK_TRUE_MSG(weight_input_ != nullptr, RET_NULL_PTR, "Init weight-in failed."); 4378+ weight_hidden_pack_size_ = static_cast<size_t>(col_align * param_->hidden_size) * data_type_size; 4379+ weight_hidden_ = allocator_->Malloc(data_type_, weight_hidden_pack_size_ * C3NUM, kOnlinePackWeight, 4380+ input_tensors_.at(C2NUM)->tensor_name() + "_online_pack"); 4381+ MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, RET_NULL_PTR, "Init weight-hidden failed."); 4382+ bias_input_ = allocator_->Malloc(data_type_, bias_pack_size_ * C3NUM, kOnlinePackWeight, 4383+ input_tensors_.at(C3NUM)->tensor_name() + "_online_pack"); 4384+ MS_CHECK_TRUE_MSG(bias_input_ != nullptr, RET_NULL_PTR, "Init bias-in failed."); 4385+ bias_hidden_ = allocator_->Malloc(data_type_, bias_pack_size_ * C3NUM, kOnlinePackWeight, 4386+ input_tensors_.at(C4NUM)->tensor_name() + "_online_pack"); 4387+ MS_CHECK_TRUE_MSG(bias_hidden_ != nullptr, RET_NULL_PTR, "Init bias-hidden failed."); 4388+ auto row_align = UP_ROUND(param_->batch_size, row_tile_); 4389+ auto work_space = 4390+ (row_align * (param_->input_size + param_->hidden_size) + param_->batch_size * param_->hidden_size * C6NUM) * 4391+ data_type_size; 4392+ run_buffer_ = allocator_->Malloc(data_type_, work_space, kWorkspace); 4393+ MS_CHECK_TRUE_MSG(run_buffer_ != nullptr, RET_NULL_PTR, "Init run_buffer failed."); 4394+ return RET_OK; 4395+} 4396+ 4397+void CustomGruFP32Coder::InitNnaclFile(CoderContext *const context) { 4398+ Collect(context, {"nnacl/fp32/custom_gru_fp32.h"}, 4399+ {"custom_gru_fp32.c", "pack_fp32.c", "matmul_fp32.c", "arithmetic_fp32.c", "activation_fp32.c"}); 4400+} 4401+ 4402+void CustomGruFP32Coder::InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst, 4403+ int row, int col) { 4404+ init_code->CodeFunction("RowMajor2Col8Major", src, dst, row, col); 4405+} 4406+ 4407+void CustomGruFP32Coder::InitBiasCode(CoderContext *const context, NNaclFp32Serializer *init_code) { 4408+ auto data_type_size = DataTypeSize(data_type_); 4409+ init_code->CodeBufferOffsetExpression(bias_input_, context->weight_name(), context->weight_offset_name(), 4410+ context->weight_size_name(), bias_pack_size_ * C3NUM); 4411+ auto bias_in_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_input_); 4412+ auto bias_in_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[C3NUM]); 4413+ for (int i = 0; i < C3NUM; ++i) { 4414+ auto dst_bias_in = bias_in_str + " + " + std::to_string(i * bias_pack_size_ / data_type_size); 4415+ auto src_bias_in = bias_in_tensor + " + " + std::to_string(i * param_->hidden_size); 4416+ init_code->CodeFunction("memcpy", dst_bias_in, src_bias_in, param_->hidden_size * data_type_size); 4417+ } 4418+ init_code->CodeBufferOffsetExpression(bias_hidden_, context->weight_name(), context->weight_offset_name(), 4419+ context->weight_size_name(), bias_pack_size_ * C3NUM); 4420+ auto bias_hidden_str = 
MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_hidden_); 4421+ auto bias_hidden_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[C4NUM]); 4422+ for (int i = 0; i < C3NUM; ++i) { 4423+ auto dst_bias_hidden = bias_hidden_str + " + " + std::to_string(i * bias_pack_size_ / data_type_size); 4424+ auto src_bias_hidden = bias_hidden_tensor + " + " + std::to_string(i * param_->hidden_size); 4425+ init_code->CodeFunction("memcpy", dst_bias_hidden, src_bias_hidden, param_->hidden_size * data_type_size); 4426+ } 4427+} 4428+ 4429+void CustomGruFP32Coder::InitWeightCode(CoderContext *const context, NNaclFp32Serializer *init_code) { 4430+ auto data_type_size = DataTypeSize(data_type_); 4431+ init_code->CodeBufferOffsetExpression(weight_input_, context->weight_name(), context->weight_offset_name(), 4432+ context->weight_size_name(), weight_in_pack_size_ * C3NUM); 4433+ auto weight_input_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_input_); 4434+ auto weight_in_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[1]); 4435+ for (int i = 0; i < C3NUM; ++i) { 4436+ auto dst_weight_in = weight_input_str + " + " + std::to_string(i * weight_in_pack_size_ / data_type_size); 4437+ auto src_weight_in = weight_in_tensor + " + " + std::to_string(i * param_->hidden_size * param_->input_size); 4438+ InitPackMatrixB(init_code, src_weight_in, dst_weight_in, param_->hidden_size, param_->input_size); 4439+ } 4440+ 4441+ init_code->CodeBufferOffsetExpression(weight_hidden_, context->weight_name(), context->weight_offset_name(), 4442+ context->weight_size_name(), weight_hidden_pack_size_ * C3NUM); 4443+ auto weight_hidden_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_hidden_); 4444+ auto weight_hidden_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[C2NUM]); 4445+ for (int i = 0; i < C3NUM; ++i) { 4446+ auto dst_weight_hidden = weight_hidden_str + " + " + std::to_string(i * weight_hidden_pack_size_ / data_type_size); 4447+ auto src_weight_hidden = 4448+ weight_hidden_tensor + " + " + std::to_string(i * param_->hidden_size * param_->hidden_size); 4449+ InitPackMatrixB(init_code, src_weight_hidden, dst_weight_hidden, param_->hidden_size, param_->hidden_size); 4450+ } 4451+} 4452+ 4453+int CustomGruFP32Coder::DoCode(CoderContext *const context) { 4454+ NNaclFp32Serializer code, init_code; 4455+ code.CodeStruct("custom_gru_parm", *param_); 4456+ InitNnaclFile(context); 4457+ InitWeightCode(context, &init_code); 4458+ InitBiasCode(context, &init_code); 4459+ auto row_align = UP_ROUND(param_->batch_size, row_tile_); 4460+ auto data_type_str = GetTensorDataType(data_type_); 4461+ auto buffer_name = "( " + data_type_str + "*)" + MemoryAllocator::GetInstance()->GetRuntimeAddr(run_buffer_); 4462+ int offset1 = row_align * param_->input_size; 4463+ int offset2 = offset1 + param_->batch_size * param_->hidden_size * C3NUM; 4464+ int offset3 = offset2 + row_align * param_->hidden_size; 4465+ code << data_type_str + "*buffer[4] = {" << buffer_name << ", " << buffer_name + " + " << offset1 << ", " 4466+ << buffer_name + " + " << offset2 << ", " << buffer_name + " + " << offset3 << "};\n"; 4467+ auto weight_input_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_input_); 4468+ auto weight_hidden_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_hidden_); 4469+ auto bias_input_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_input_); 4470+ auto bias_hidden_str = 
MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_hidden_);
+  code.CodeFunction(op_func_, output_tensor_, input_tensor_, weight_input_str, weight_hidden_str, bias_input_str,
+                    bias_hidden_str, input_tensors_[C5NUM], "buffer", "&custom_gru_parm");
+  context->AppendInitWeightSizeCode((weight_in_pack_size_ + weight_hidden_pack_size_) * C3NUM +
+                                    bias_pack_size_ * C6NUM);
+  context->AppendInitCode(init_code.str());
+  context->AppendCode(code.str());
+  return RET_OK;
+}
+
+REG_BUILIN_CUSTOM_CODER(kARM64, kNumberTypeFloat32, "CustomGRU", CPUOpCoderCreator<CustomGruFP32Coder>)
+}  // namespace mindspore::lite::micro::nnacl
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
new file mode 100644
index 00000000..27db0f94
--- /dev/null
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CUSTOM_GRU_FP32_CODER_H_
+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CUSTOM_GRU_FP32_CODER_H_
+
+#include <string>
+#include <vector>
+#include "coder/opcoders/op_coder.h"
+#include "nnacl/custom_gru_parameter.h"
+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h"
+
+namespace mindspore::lite::micro::nnacl {
+class CustomGruFP32Coder : public OperatorCoder {
+ public:
+  CustomGruFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors,
+                     const LiteGraph::Node *node, size_t node_index, Target target)
+      : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {}
+  ~CustomGruFP32Coder() override = default;
+
+  int Prepare(CoderContext *const context) override;
+
+  int DoCode(CoderContext *const context) override;
+
+ protected:
+  virtual void InitNnaclFile(CoderContext *const context);
+  virtual void InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst, int row,
+                               int col);
+  TypeId data_type_{kNumberTypeFloat32};
+  int row_tile_{C12NUM};
+  int col_tile_{C8NUM};
+  void *weight_input_{nullptr};
+  void *weight_hidden_{nullptr};
+  void *bias_input_{nullptr};
+  void *bias_hidden_{nullptr};
+  size_t weight_in_pack_size_{0};
+  size_t weight_hidden_pack_size_{0};
+  size_t bias_pack_size_{0};
+  std::string op_func_{"CustomGru"};
+  CustomGruParameter *param_{nullptr};
+
+ private:
+  int InitParameter();
+  int InitWeightAndBias();
+  int ReSize();
+  void InitWeightCode(CoderContext *const context, NNaclFp32Serializer
*init_code); 4547+ void InitBiasCode(CoderContext *const context, NNaclFp32Serializer *init_code); 4548+ void *run_buffer_{nullptr}; 4549+}; 4550+} // namespace mindspore::lite::micro::nnacl 4551+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CUSTOM_GRU_FP32_CODER_H_ 4552diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc 4553index 3c31479c..c6b93abf 100644 4554--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc 4555+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc 4556@@ -27,11 +27,30 @@ using mindspore::schema::PrimitiveType_Gather; 4557 namespace mindspore::lite::micro::nnacl { 4558 int GatherFP32Coder::Prepare(CoderContext *const context) { return RET_OK; } 4559 4560+void GatherFP32Coder::InitCodeInChange(CoderContext *const context, std::string *auxiliary_variable) { 4561+ auto input0_shape_str = allocator_->GetAuxiliaryWeight(input_tensor_); 4562+ if (input0_shape_str.empty()) { 4563+ return; 4564+ } 4565+ *auxiliary_variable = input0_shape_str; 4566+ NNaclFp32Serializer init_code; 4567+ auto in_shape = input_tensor_->shape(); 4568+ int in_rank = static_cast<int>(in_shape.size()); 4569+ init_code.CodeArray("shape", in_shape.data(), in_rank); 4570+ init_code << " for (int i = 0; i < " << in_rank << "; ++i) {\n"; 4571+ init_code << " if (i != " << axis_ << " && " << input0_shape_str << "[i] != shape[i]) {\n"; 4572+ init_code << " return RET_ERROR;\n"; 4573+ init_code << " }\n"; 4574+ init_code << " }\n"; 4575+ context->AppendInitCode(init_code.str()); 4576+} 4577+ 4578 int GatherFP32Coder::DoCode(CoderContext *context) { 4579 Tensor *input0 = input_tensors_.at(0); 4580 Tensor *input1 = input_tensors_.at(1); 4581 MS_CHECK_PTR(input0); 4582 MS_CHECK_PTR(input1); 4583+ MS_CHECK_PTR(parameter_); 4584 MS_CHECK_TRUE_MSG(input1->data_type() == kNumberTypeInt32 || input1->data_type() == kNumberTypeInt, RET_ERROR, 4585 "index's data-type is not int32"); 4586 // generate code .h .c 4587@@ -44,18 +63,16 @@ int GatherFP32Coder::DoCode(CoderContext *context) { 4588 }); 4589 4590 NNaclFp32Serializer code; 4591- std::vector<int> in_shape = input0->shape(); 4592+ auto in_shape = input0->shape(); 4593 int in_rank = static_cast<int>(in_shape.size()); 4594- MS_CHECK_PTR(parameter_); 4595- int axis = (reinterpret_cast<GatherParameter *>(parameter_))->axis_; 4596- MS_CHECK_TRUE(static_cast<int>(in_shape.size()) >= axis, "invalid axis in gather parameter"); 4597- const int limit = in_shape.at(axis); 4598- 4599- int outer_size = 1, inner_size = 1; 4600- for (int i = 0; i < axis; ++i) { 4601+ axis_ = *(reinterpret_cast<int *>(input_tensors_.at(THIRD_INPUT)->data())); 4602+ MS_CHECK_TRUE(static_cast<int>(in_shape.size()) >= axis_, "invalid axis in gather parameter"); 4603+ int outer_size = 1; 4604+ for (int i = 0; i < axis_; ++i) { 4605 outer_size *= in_shape.at(i); 4606 } 4607- for (int i = axis + 1; i < in_rank; ++i) { 4608+ int inner_size = 1; 4609+ for (int i = axis_ + 1; i < in_rank; ++i) { 4610 inner_size *= in_shape.at(i); 4611 } 4612 auto data_size = static_cast<int>(lite::DataTypeSize(input0->data_type())); 4613@@ -67,22 +84,22 @@ int GatherFP32Coder::DoCode(CoderContext *context) { 4614 int start = stride * kDefaultTaskId; 4615 int count = MSMIN(stride, outer_size - stride * kDefaultTaskId); 4616 std::string input0_data = 
MemoryAllocator::GetInstance()->GetRuntimeAddr(input0, true); 4617- if (input0_data.empty()) { 4618- MS_LOG(ERROR) << "pointer is not allocated by the allocator"; 4619- return RET_ERROR; 4620- } 4621+ MS_CHECK_TRUE_MSG(!input0_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 4622 std::string input1_data = MemoryAllocator::GetInstance()->GetRuntimeAddr(input1, true); 4623- if (input1_data.empty()) { 4624- MS_LOG(ERROR) << "pointer is not allocated by the allocator"; 4625- return RET_ERROR; 4626- } 4627+ MS_CHECK_TRUE_MSG(!input1_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 4628 std::string output_data = MemoryAllocator::GetInstance()->GetRuntimeAddr(output_tensor_, true); 4629- if (output_data.empty()) { 4630- MS_LOG(ERROR) << "pointer is not allocated by the allocator"; 4631- return RET_ERROR; 4632+ MS_CHECK_TRUE_MSG(!output_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 4633+ 4634+ std::string limit = std::to_string(in_shape[axis_]); 4635+ std::string in_offset = std::to_string(start * in_shape[axis_] * byte_inner_size); 4636+ std::string auxiliary_variable; 4637+ InitCodeInChange(context, &auxiliary_variable); 4638+ if (!auxiliary_variable.empty()) { 4639+ limit = auxiliary_variable + "[" + std::to_string(axis_) + "]"; 4640+ in_offset = std::to_string(start) + " * " + limit + " * " + std::to_string(byte_inner_size); 4641 } 4642 code << "\t\tconst int8_t *int8_in = (const int8_t *)" << input0_data << ";\n"; 4643- code << "\t\tint8_in += " << std::to_string(start * limit * byte_inner_size) << ";\n"; 4644+ code << "\t\tint8_in += " << in_offset << ";\n"; 4645 code << "\t\tconst int *index_data = (const int *)" << input1_data << ";\n"; 4646 code << "\t\tint8_t *int8_out = (int8_t *)" << output_data << ";\n"; 4647 code << "\t\tint8_out += " << std::to_string(start * byte_out_stride) << ";\n"; 4648diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h 4649index a14d9c3c..a175d694 100644 4650--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h 4651+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h 4652@@ -35,7 +35,9 @@ class GatherFP32Coder final : public OperatorCoder { 4653 int DoCode(CoderContext *const context) override; 4654 4655 private: 4656+ void InitCodeInChange(CoderContext *const context, std::string *auxiliary_variable); 4657 int32_t *indices_{nullptr}; 4658+ int axis_{0}; 4659 }; 4660 } // namespace mindspore::lite::micro::nnacl 4661 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GATHER_FP32_CODER_H_ 4662diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc 4663index 790a142e..6115edb5 100644 4664--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc 4665+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc 4666@@ -50,13 +50,13 @@ int MatMulFP32BaseCoder::ReSize() { 4667 int MatMulFP32BaseCoder::InitBiasData() { 4668 if (input_tensors_.size() == DIMENSION_3D) { 4669 int max_bias_data = params_->col_align_; 4670- bias_pack_ptr_size_ = static_cast<size_t>(max_bias_data * sizeof(float)); 4671+ bias_pack_ptr_size_ = static_cast<size_t>(max_bias_data * DataTypeSize(data_type_)); 
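+    // Buffer sizes here and below now follow DataTypeSize(data_type_) instead of a hard-coded sizeof(float).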
4672 if (bias_tensor_->ElementsNum() == 1) { 4673 is_bias_broadcast_ = true; 4674 } 4675- ori_bias_pack_ptr_size_ = bias_tensor_->ElementsNum() * sizeof(float); 4676- bias_ptr_ = allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight, 4677- bias_tensor_->tensor_name() + "_online_pack"); 4678+ ori_bias_pack_ptr_size_ = bias_tensor_->ElementsNum() * DataTypeSize(data_type_); 4679+ bias_ptr_ = 4680+ allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack"); 4681 MS_CHECK_PTR(bias_ptr_); 4682 } 4683 return RET_OK; 4684@@ -83,18 +83,19 @@ int MatMulFP32BaseCoder::InitBufferA() { 4685 if (a_pack_ptr_ != nullptr) { 4686 return RET_OK; 4687 } 4688- a_pack_ptr_size_ = static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * sizeof(float)); 4689+ a_pack_ptr_size_ = 4690+ static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * DataTypeSize(data_type_)); 4691 if (params_->a_const_) { 4692- a_pack_ptr_ = reinterpret_cast<float *>(allocator_->GetSharedWeightAddr(input_tensors_.at(0))); 4693+ a_pack_ptr_ = allocator_->GetSharedWeightAddr(input_tensors_.at(0)); 4694 if (a_pack_ptr_ == nullptr) { 4695- a_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight, 4696- input_tensors_.at(0)->tensor_name() + "_online_pack")); 4697+ a_pack_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, 4698+ input_tensors_.at(0)->tensor_name() + "_online_pack"); 4699 allocator_->MarkSharedWeight(input_tensors_.at(0), a_pack_ptr_); 4700 } else { 4701 a_packed_ = true; 4702 } 4703 } else { 4704- a_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, a_pack_ptr_size_, kWorkspace)); 4705+ a_pack_ptr_ = allocator_->Malloc(data_type_, a_pack_ptr_size_, kWorkspace); 4706 } 4707 MS_CHECK_PTR(a_pack_ptr_); 4708 return RET_OK; 4709@@ -104,18 +105,19 @@ int MatMulFP32BaseCoder::InitBufferB() { 4710 if (b_pack_ptr_ != nullptr) { 4711 return RET_OK; 4712 } 4713- b_pack_ptr_size_ = static_cast<size_t>(params_->batch * params_->col_align_ * params_->deep_ * data_type_size_); 4714+ b_pack_ptr_size_ = 4715+ static_cast<size_t>(params_->batch * params_->col_align_ * params_->deep_ * DataTypeSize(data_type_)); 4716 if (params_->b_const_) { 4717- b_pack_ptr_ = reinterpret_cast<float *>(allocator_->GetSharedWeightAddr(input_tensors_.at(1))); 4718+ b_pack_ptr_ = allocator_->GetSharedWeightAddr(input_tensors_.at(1)); 4719 if (b_pack_ptr_ == nullptr) { 4720- b_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeUInt8, b_pack_ptr_size_, kOnlinePackWeight, 4721- input_tensors_.at(1)->tensor_name() + "_online_pack")); 4722+ b_pack_ptr_ = allocator_->Malloc(data_type_, b_pack_ptr_size_, kOnlinePackWeight, 4723+ input_tensors_.at(1)->tensor_name() + "_online_pack"); 4724 allocator_->MarkSharedWeight(input_tensors_.at(1), b_pack_ptr_); 4725 } else { 4726 b_packed_ = true; 4727 } 4728 } else { 4729- b_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeUInt8, b_pack_ptr_size_, kWorkspace)); 4730+ b_pack_ptr_ = allocator_->Malloc(data_type_, b_pack_ptr_size_, kWorkspace); 4731 } 4732 MS_CHECK_PTR(b_pack_ptr_); 4733 return RET_OK; 4734@@ -194,7 +196,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) { 4735 NNaclFp32Serializer code, init_code; 4736 size_t w_buf_size = 0; 4737 std::string param_name = "mat_mul_parameter"; 4738- std::string bias_ptr_str = "((float *)(" + allocator_->GetRuntimeAddr(bias_ptr_) + "))"; 
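+  // bias_ptr_ is handed to CodeFunction directly below; the allocator emits a typed address, so the "(float *)" cast string is no longer needed.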
4739+ 4740 code.CodeStruct(param_name, *params_); 4741 if (support_parallel_) { 4742 code << " " << param_name << ".op_parameter_.thread_num_ = 1;\n"; 4743@@ -207,6 +209,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) { 4744 int max_bias_data = params_->col_align_; 4745 if (is_bias_broadcast_) { 4746 float broad_cast_data = (reinterpret_cast<float *>(bias_tensor_->data()))[0]; 4747+ std::string bias_ptr_str = allocator_->GetRuntimeAddr(bias_ptr_); 4748 init_code << "\t for (int i = 0; i < " << max_bias_data << "; ++i) {\n"; 4749 init_code << "\t\t " << bias_ptr_str << "[i] = " << broad_cast_data << ";\n"; 4750 init_code << " }\n"; 4751@@ -219,8 +222,8 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) { 4752 std::string a_str = allocator_->GetRuntimeAddr(input_tensor_); 4753 std::string b_str = allocator_->GetRuntimeAddr(filter_tensor_); 4754 std::string c_str = allocator_->GetRuntimeAddr(output_tensor_); 4755- std::string a_pack_str = allocator_->GetRuntimeAddr(a_pack_ptr_); 4756- std::string b_pack_str = allocator_->GetRuntimeAddr(b_pack_ptr_); 4757+ std::string a_pack_str = allocator_->GetRuntimeAddr(static_cast<float *>(a_pack_ptr_)); 4758+ std::string b_pack_str = allocator_->GetRuntimeAddr(static_cast<float *>(b_pack_ptr_)); 4759 // do const value packing to init 4760 if ((params_->a_const_ && !a_packed_) || (params_->b_const_ && !b_packed_)) { 4761 init_code.CodeStruct("mat_mul_parameter", *params_); 4762@@ -271,7 +274,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) { 4763 code << " const float *batch_b_ptr = " << b_pack_str << " + i * " << params_->deep_ * params_->col_ << ";\n"; 4764 code << " float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n "; 4765 4766- code.CodeFunction("MatVecMulFp32", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_str, params_->act_type_, 4767+ code.CodeFunction("MatVecMulFp32", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_, params_->act_type_, 4768 params_->deep_, cur_oc); 4769 } else { 4770 code << " const float *batch_a_ptr = " << a_pack_str << " + i * " << params_->row_align_ * params_->deep_ 4771@@ -280,7 +283,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) { 4772 << ";\n"; 4773 code << " float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n "; 4774 4775- code.CodeFunction("MatMulOpt", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_str, params_->act_type_, 4776+ code.CodeFunction("MatMulOpt", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_, params_->act_type_, 4777 params_->deep_, params_->row_, cur_oc, params_->col_, "OutType_Nhwc"); 4778 } 4779 code << " }\n"; 4780diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h 4781index 68b2658a..a5ef9277 100644 4782--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h 4783+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h 4784@@ -69,7 +69,7 @@ class MatMulFP32BaseCoder : public OperatorCoder { 4785 size_t a_pack_ptr_size_{0}; 4786 size_t b_pack_ptr_size_{0}; 4787 bool is_bias_broadcast_{false}; 4788- size_t data_type_size_{C4NUM}; 4789+ TypeId data_type_{kNumberTypeFloat32}; 4790 }; 4791 } // namespace mindspore::lite::micro::nnacl 4792 #endif // 
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc
index a107c3cf..45b2e37f 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc
@@ -27,6 +27,14 @@ std::unique_ptr<OperatorCoder> OpCoderBuilder::build(int schema_version) {
   MS_CHECK_PTR_RET_NULL(node_->primitive_);
   int primitive_type = GetPrimitiveType(node_->primitive_, schema_version);
   CoderKey coder_key(target_, data_type_, primitive_type);
+  if (builtin_custom_) {
+    auto custom_type = reinterpret_cast<const schema::Primitive *>(node_->primitive_)->value_as_Custom()->type();
+    if (custom_type == nullptr || custom_type->str().empty()) {
+      MS_LOG(ERROR) << "Builtin custom-op has no type.";
+      return nullptr;
+    }
+    coder_key = CoderKey(target_, data_type_, schema::PrimitiveType_Custom, custom_type->str());
+  }
   CoderCreatorFunc creator_func = OpCoderFactory::GetInstance()->FindOpCoder(coder_key);
   if (creator_func == nullptr) {
     MS_LOG(ERROR) << "caught unsupported layer: " << node_->name_;
@@ -112,5 +120,10 @@ OpCoderBuilder &OpCoderBuilder::support_parallel(bool parallel) {
   return *this;
 }

+OpCoderBuilder &OpCoderBuilder::is_builtin_custom(bool builtin_custom) {
+  builtin_custom_ = builtin_custom;
+  return *this;
+}
+
 void OpCoderBuilder::Reset() {}
 }  // namespace mindspore::lite::micro
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h
index adce6c73..d85f1c32 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h
@@ -46,6 +46,8 @@ class OpCoderBuilder {

   OpCoderBuilder &support_parallel(bool parallel);

+  OpCoderBuilder &is_builtin_custom(bool builtin_custom);
+
   void Reset();

  private:
@@ -70,6 +72,8 @@ class OpCoderBuilder {
   std::vector<uint32_t> output_indices_;

   bool support_parallel_{false};
+
+  bool builtin_custom_{false};
 };
 }  // namespace mindspore::lite::micro
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_BUILDER_H_
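The builder change above threads one more option through OpCoderBuilder's fluent interface. A compilable sketch of the pattern, with a stripped-down builder carrying only the flag this patch adds (the real class holds many more fields; cf. CreateOpCoders() in session.cc further below):

#include <iostream>

// Stripped-down OpCoderBuilder: each setter stores its value and returns
// *this, so callers can chain configuration before build().
class OpCoderBuilderSketch {
 public:
  OpCoderBuilderSketch &is_builtin_custom(bool builtin_custom) {
    builtin_custom_ = builtin_custom;
    return *this;
  }
  const char *build() const {
    // With the flag set, build() re-keys the coder lookup for builtin custom ops.
    return builtin_custom_ ? "lookup by (target, dtype, Custom, type-string)"
                           : "lookup by (target, dtype, primitive type)";
  }

 private:
  bool builtin_custom_{false};
};

int main() {
  std::cout << OpCoderBuilderSketch().is_builtin_custom(true).build() << '\n';
  return 0;
}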
exist: " << key.ToString(); 4861@@ -63,7 +63,7 @@ CoderCreatorFunc OpCoderFactory::FindOpCoder(const CoderKey &key) { 4862 } 4863 4864 OpCoderRegister::OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, 4865- const CoderCreatorFunc &creatorFunc) { 4866- OpCoderFactory::GetInstance()->RegistOpCoder(target, data_type, operator_type, creatorFunc); 4867+ const std::string &builtin_custom_type, const CoderCreatorFunc &creatorFunc) { 4868+ OpCoderFactory::GetInstance()->RegistOpCoder(target, data_type, operator_type, builtin_custom_type, creatorFunc); 4869 } 4870 } // namespace mindspore::lite::micro 4871diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h 4872index 9a1aed63..acbd3a22 100644 4873--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h 4874+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h 4875@@ -18,6 +18,7 @@ 4876 #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_REGISTER_H_ 4877 4878 #include <map> 4879+#include <utility> 4880 #include <vector> 4881 #include <memory> 4882 #include <string> 4883@@ -34,10 +35,14 @@ class CoderKey { 4884 public: 4885 CoderKey() = delete; 4886 4887- CoderKey(Target target, TypeId data_type, int op_type) : target_(target), data_type_(data_type), op_type_(op_type) {} 4888+ CoderKey(Target target, TypeId data_type, int op_type, std::string builtin_custom_type = "") 4889+ : target_(target), 4890+ data_type_(data_type), 4891+ op_type_(op_type), 4892+ builtin_custom_type_(std::move(builtin_custom_type)) {} 4893 4894 CoderKey AllKey() const { 4895- CoderKey key(kAllTargets, data_type_, op_type_); 4896+ CoderKey key(kAllTargets, data_type_, op_type_, builtin_custom_type_); 4897 return key; 4898 } 4899 4900@@ -50,6 +55,7 @@ class CoderKey { 4901 Target target_ = kTargetUnknown; 4902 TypeId data_type_ = kTypeUnknown; 4903 int op_type_ = schema::PrimitiveType_NONE; 4904+ std::string builtin_custom_type_; 4905 }; 4906 4907 class OpCoderFactory { 4908@@ -59,7 +65,7 @@ class OpCoderFactory { 4909 static OpCoderFactory *GetInstance(); 4910 4911 int RegistOpCoder(Target target, TypeId data_type, schema::PrimitiveType operator_type, 4912- const CoderCreatorFunc &creator_func); 4913+ const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func); 4914 4915 CoderCreatorFunc FindOpCoder(const CoderKey &key); 4916 4917@@ -75,11 +81,16 @@ class OpCoderRegister { 4918 OpCoderRegister() = delete; 4919 4920 OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, 4921- const CoderCreatorFunc &creator_func); 4922+ const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func); 4923 4924 ~OpCoderRegister() = default; 4925 }; 4926-#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func) \ 4927- static OpCoderRegister g_##target##data_type##operator_type##Creator(target, data_type, operator_type, creator_func); 4928+#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func) \ 4929+ static OpCoderRegister g_##target##data_type##operator_type##Creator(target, data_type, operator_type, "", \ 4930+ creator_func); 4931+ 4932+#define REG_BUILIN_CUSTOM_CODER(target, data_type, custom_type, creator_func) \ 4933+ static OpCoderRegister g_##target##data_type##operator_type##Creator( \ 4934+ target, data_type, schema::PrimitiveType_Custom, custom_type, creator_func); 4935 } // namespace 
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc
index c333b621..cde08fd8 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc
@@ -196,6 +196,11 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const BroadcastShapeInfo &op_param) {
                 ToString(op_param.output_shape_), op_param.output_shape_size_);
 }

+void NNaclFp32Serializer::CodeStruct(const std::string &name, const CustomGruParameter &op_param) {
+  CodeBaseStruct<false>("CustomGruParameter", name, op_param.op_parameter_, op_param.num_step, op_param.batch_size,
+                        op_param.input_size, op_param.hidden_size);
+}
+
 void NNaclFp32Serializer::CodeArrayStruct(const std::string &name, TensorC *tensorC, std::vector<Tensor *> tensor) {
   std::vector<std::string> tensor_names;
   int size = tensor.size();
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h
index f52ced20..797a9574 100644
--- a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h
+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h
@@ -44,6 +44,7 @@
 #include "nnacl/layer_norm_parameter.h"
 #include "nnacl/broadcast_to_parameter.h"
 #include "nnacl/split_parameter.h"
+#include "nnacl/custom_gru_parameter.h"

 namespace mindspore::lite::micro::nnacl {
 class NNaclFp32Serializer : public Serializer {
@@ -74,6 +75,7 @@ class NNaclFp32Serializer : public Serializer {
   void CodeStruct(const std::string &name, const SplitParameter &split_parameter);
   void CodeStruct(const std::string &name, const LayerNormParameter &param);
   void CodeStruct(const std::string &name, const BroadcastShapeInfo &param);
+  void CodeStruct(const std::string &name, const CustomGruParameter &param);
   void CodeArrayStruct(const std::string &name, TensorC *tensorC, std::vector<Tensor *> tensor);

  private:
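For reference, the new CodeStruct() overload funnels into CodeBaseStruct, which prints a brace-initialized struct definition into the generated C source. A rough sketch of the output shape, assuming the CustomGruParameter field order used above (the op_parameter_ sub-struct is elided to {} here; the real serializer expands its fields, and all values below are made up):

#include <iostream>
#include <sstream>
#include <string>

// Emits roughly what the generated source receives for a CustomGruParameter.
std::string EmitCustomGruParameter(const std::string &name, int num_step, int batch_size,
                                   int input_size, int hidden_size) {
  std::ostringstream os;
  os << "CustomGruParameter " << name << " = {{}, " << num_step << ", " << batch_size
     << ", " << input_size << ", " << hidden_size << "};\n";
  return os.str();
}

int main() {
  // Hypothetical shapes: 16 time steps, batch 1, 128-d input, 128-d hidden state.
  std::cout << EmitCustomGruParameter("custom_gru_parameter", 16, 1, 128, 128);
  return 0;
}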
diff --git a/mindspore/lite/tools/converter/micro/coder/session.cc b/mindspore/lite/tools/converter/micro/coder/session.cc
index 471f1491..756b7222 100644
--- a/mindspore/lite/tools/converter/micro/coder/session.cc
+++ b/mindspore/lite/tools/converter/micro/coder/session.cc
@@ -40,11 +40,38 @@
 #include "coder/opcoders/nnacl/dequant/de_quant.h"

 namespace mindspore::lite::micro {
+namespace {
+bool IsBuiltInCustomNode(const void *primitive, int schema_version) {
+  if (!IsCustomNode(primitive, schema_version)) {
+    return false;
+  }
+  const auto &custom = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_Custom();
+  if (custom == nullptr) {
+    return false;
+  }
+  const auto &attrs = custom->attr();
+  if (attrs == nullptr) {
+    return false;
+  }
+  for (size_t i = 0; i < attrs->size(); ++i) {
+    if (attrs->Get(i) == nullptr || attrs->Get(i)->name() == nullptr) {
+      continue;
+    }
+    if (attrs->Get(i)->name()->str() == "builtin") {
+      return true;
+    }
+  }
+  return false;
+}
+}  // namespace
+
 CoderSession::CoderSession() { allocator_ = MemoryAllocator::GetInstance(); }

-void CoderSession::EndCode() {
+int CoderSession::PassArgsToContext() {
   context_->set_tensor_map(allocator_->tensors_map());
   context_->set_saved_weights(allocator_->saved_weights());
+  context_->set_origin_weights(allocator_->origin_weights());
+  context_->set_auxiliary_weights(allocator_->auxiliary_weights());
   size_t de_quant_max_workspace_size = nnacl::Dequant::GetInstance()->de_quant_max_workspace();
   size_t final_total_size = allocator_->total_buffer_size() > de_quant_max_workspace_size
                               ? allocator_->total_buffer_size()
@@ -61,13 +88,20 @@ void CoderSession::EndCode() {
   if (config->code_mode() == Train) {
     Train::TransformGraphForTrain(context_.get(), op_coders_, schema_version_);
   }
+  if (!context_->JudgeIsValid(Configurator::GetInstance()->keep_original_weight())) {
+    MS_LOG(ERROR) << "The current model cannot keep the original weights, because generated tensor data exists. "
+                     "Please set 'keep_original_weight' to false.";
+    return RET_NOT_SUPPORT;
+  }
+  return RET_OK;
 }

 int CoderSession::Run() {
   MS_LOG(INFO) << "start run opcoders";
   // 1. assign memory
   std::vector<lite::Tensor *> inputs = coder_graph_->input_tensors();
-  int ret = allocator_->Assign(inputs, op_coders_);
+  int ret = allocator_->Assign(inputs, op_coders_, coder_graph_->all_tensors(),
+                               Configurator::GetInstance()->changeable_weights_name());
   MS_CHECK_RET_CODE(ret, "assign memory failed");
   // 2. prepare, init model parameters
   for (const auto &op_coder : op_coders_) {
@@ -84,10 +118,10 @@ int CoderSession::Run() {
     ret = op_coder->DoCode(this->context_.get());
     MS_CHECK_RET_CODE(ret, "do coder " << op_coder->name() << " failed");
   }
-
-  this->EndCode();
+  ret = PassArgsToContext();
+  MS_CHECK_RET_CODE(ret, "PassArgsToContext failed");
   MS_LOG(INFO) << "run opcoders success";
-  return RET_OK;
+  return ret;
 }

 int CoderSession::GenerateCode() {
@@ -269,7 +303,9 @@ int CoderSession::CreateOpCoders() {
     }

     OpParameter *parameter = nullptr;
-    if (IsCustomNode(node->primitive_, schema_version_)) {
+    bool is_custom_op = IsCustomNode(node->primitive_, schema_version_);
+    bool is_built_in_custom_op = IsBuiltInCustomNode(node->primitive_, schema_version_);
+    if (is_custom_op && !is_built_in_custom_op) {
       KernelRegistry::GetInstance()->RegisterKernel(schema::PrimitiveType_Custom);
     } else {
       parameter = GenParameterAndInfer(node, inputs, &outputs);  // built-in ops infer
@@ -287,6 +323,7 @@ int CoderSession::CreateOpCoders() {
                                             .mode(code_mode)
                                             .input_indices(input_indices)
                                             .output_indices(output_indices)
+                                            .is_builtin_custom(is_built_in_custom_op)
                                             .build(schema_version_);
     if (op_coder == nullptr) {
       coder_graph_->DumpUnSupportLayer(code_target);
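The IsBuiltInCustomNode() helper above decides the dispatch in CreateOpCoders(): a Custom node carrying an attribute literally named "builtin" is handled by the micro opcoders instead of being routed to the user kernel registry. A self-contained sketch of that scan, with a vector of name/value pairs standing in for the flatbuffers attr() table:

#include <iostream>
#include <string>
#include <utility>
#include <vector>

using AttrSketch = std::pair<std::string, std::string>;  // attribute name -> raw bytes

bool HasBuiltinAttr(const std::vector<AttrSketch> &attrs) {
  for (const auto &attr : attrs) {
    if (attr.first == "builtin") {
      return true;  // mirrors attrs->Get(i)->name()->str() == "builtin"
    }
  }
  return false;
}

int main() {
  std::vector<AttrSketch> attrs = {{"transpose_a", "0"}, {"builtin", "1"}};
  std::cout << std::boolalpha << HasBuiltinAttr(attrs) << '\n';  // true
  return 0;
}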
diff --git a/mindspore/lite/tools/converter/micro/coder/session.h b/mindspore/lite/tools/converter/micro/coder/session.h
index 3a8f7290..20f6b2b5 100644
--- a/mindspore/lite/tools/converter/micro/coder/session.h
+++ b/mindspore/lite/tools/converter/micro/coder/session.h
@@ -50,7 +50,7 @@ class CoderSession {
   int CreateOpCoders();
   int InitCodeGraph();
   int CompileGraph();
-  void EndCode();
+  int PassArgsToContext();

   std::unique_ptr<CoderGraph> coder_graph_{nullptr};
   std::unique_ptr<CoderContext> context_{nullptr};
diff --git a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc
index a552da05..3b868b41 100644
--- a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc
+++ b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc
@@ -103,6 +103,8 @@ std::string GetTensorDataType(TypeId type) {
       return "uint32_t ";
     case kNumberTypeInt64:
       return "int64_t ";
+    case kNumberTypeFloat16:
+      return "float16_t ";
     default:
       MS_LOG(ERROR) << "unsupported data type: " << EnumNameDataType(type);
       return "";
@@ -152,7 +154,6 @@ std::string EnumMicroTensorDataType(TypeId type) {
     case kNumberTypeUInt16:
       return "DataType_DT_UINT16";
     case kNumberTypeFloat16:
-      MS_LOG(WARNING) << "unsupported data type: kNumberTypeFloat16";
       return "DataType_DT_FLOAT16";
     default:
       MS_LOG(WARNING) << "unsupported data type: " << type << ", reference: " << kNumberTypeInt;
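The type_cast.cc hunks extend the TypeId-to-C-type mapping so fp16 tensors are no longer rejected by the code generator. A reduced sketch of GetTensorDataType() covering only the cases relevant here (enum values are illustrative placeholders; the real function handles many more types and logs an error on unsupported ones):

#include <iostream>
#include <string>

enum TypeIdSketch { kFloat32Sketch, kFloat16Sketch, kInt32Sketch };

std::string GetTensorDataTypeSketch(TypeIdSketch type) {
  switch (type) {
    case kFloat32Sketch:
      return "float ";
    case kFloat16Sketch:
      return "float16_t ";  // newly accepted: fp16 buffers in generated code
    case kInt32Sketch:
      return "int32_t ";
    default:
      return "";  // the real implementation logs an error here
  }
}

int main() {
  std::cout << GetTensorDataTypeSketch(kFloat16Sketch) << "x[16];\n";  // float16_t x[16];
  return 0;
}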
diff --git a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h
index 7753e123..61c7c923 100644
--- a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h
+++ b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h
@@ -27,6 +27,7 @@
 #include "src/common/log_adapter.h"
 #include "nnacl/op_base.h"
 #include "tools/converter/micro/coder/config.h"
+#include "base/float16.h"

 namespace mindspore::lite::micro {
 std::string EnumNameDataType(TypeId type);
@@ -63,7 +64,8 @@ std::string GetVariableTypeName() {
     {std::type_index(typeid(int16_t *)), "int16_t *"},
     {std::type_index(typeid(int8_t *)), "int8_t *"},
     {std::type_index(typeid(uint8_t *)), "uint8_t *"},
-    {std::type_index(typeid(float *)), "float *"}};
+    {std::type_index(typeid(float *)), "float *"},
+    {std::type_index(typeid(float16 *)), "float16_t *"}};
   auto item = types_name.find(std::type_index(typeid(T)));
   if (item != types_name.end()) {
     return item->second;
--
2.17.1