From 46c42f4387b8280a25b16592aa25179093d99fa2 Mon Sep 17 00:00:00 2001 From: z00574805 Date: Wed, 24 May 2023 11:10:00 +0800 Subject: [PATCH 3/5] xiaoyi-0003 --- .jenkins/check/config/filter_cppcheck.txt | 1 + cmake/package_micro.cmake | 4 +- include/api/serialization.h | 26 + .../nnacl/base/minimal_filtering_generator.c | 2 +- .../nnacl/base/minimal_filtering_generator.h | 1 - .../cpu/kernel/nnacl/custom_gru_parameter.h | 31 + .../cpu/kernel/nnacl/fp16/custom_gru_fp16.c | 70 ++ .../cpu/kernel/nnacl/fp16/custom_gru_fp16.h | 32 + .../cpu/kernel/nnacl/fp32/custom_gru_fp32.c | 71 ++ .../cpu/kernel/nnacl/fp32/custom_gru_fp32.h | 32 + .../device/cpu/kernel/nnacl/fp32/lstm_fp32.h | 2 +- .../cpu/kernel/nnacl/infer/custom_gru_infer.c | 45 + .../cpu/kernel/nnacl/infer/custom_gru_infer.h | 30 + .../plugin/device/cpu/kernel/nnacl/op_base.h | 1 + .../lite/include/registry/converter_context.h | 3 +- mindspore/lite/src/CMakeLists.txt | 1 + mindspore/lite/src/common/graph_util.cc | 7 +- .../common/ops/populate/custom_populate.cc | 14 + .../lite/src/runtime/cxx_api/serialization.cc | 31 + .../kernel/cpu/base/group_convolution_base.cc | 34 +- .../cpu/base/group_convolution_creator.cc | 24 +- .../cpu/base/group_convolution_creator.h | 8 +- .../cpu/fp16/convolution_delegate_fp16.cc | 2 +- .../kernel/cpu/fp16/custom_gru_fp16.cc | 132 +++ .../runtime/kernel/cpu/fp16/custom_gru_fp16.h | 40 + .../cpu/fp32/convolution_delegate_fp32.cc | 2 +- .../kernel/cpu/fp32/custom_gru_fp32.cc | 251 ++++++ .../runtime/kernel/cpu/fp32/custom_gru_fp32.h | 51 ++ .../cpu/int8/convolution_int8_creator.cc | 2 +- mindspore/lite/src/runtime/lite_session.h | 6 + mindspore/lite/src/train/graph_fusion.cc | 7 + .../train/optimizer/fusion/gru_fusion_pass.cc | 809 ++++++++++++++++++ .../train/optimizer/fusion/gru_fusion_pass.h | 45 + mindspore/lite/src/train/static_allocator.h | 6 +- mindspore/lite/src/train/train_export.cc | 38 + mindspore/lite/src/train/train_export.h | 1 + mindspore/lite/src/train/train_session.cc | 37 +- mindspore/lite/src/train/train_session.h | 3 + .../test/config_level0/micro/micro_arm64.cfg | 7 + .../config_parser/config_file_parser.cc | 13 +- .../config_parser/config_file_parser.h | 2 + .../config_parser/micro_param_parser.cc | 33 + .../config_parser/micro_param_parser.h | 2 + mindspore/lite/tools/converter/converter.cc | 18 +- .../converter_lite/converter_flags.cc | 4 +- .../converter/micro/cmake/file_list.cmake | 3 + .../micro/coder/allocator/allocator.cc | 82 +- .../micro/coder/allocator/allocator.h | 13 +- .../lite/tools/converter/micro/coder/coder.cc | 51 +- .../lite/tools/converter/micro/coder/coder.h | 9 +- .../lite/tools/converter/micro/coder/config.h | 10 + .../tools/converter/micro/coder/context.h | 25 +- .../generator/component/common_component.cc | 5 +- .../generator/component/weight_component.cc | 301 +++++-- .../generator/component/weight_component.h | 2 +- .../micro/coder/generator/generator.cc | 2 +- .../coder/opcoders/base/reshape_base_coder.cc | 5 + .../coder/opcoders/base/stack_base_coder.cc | 85 ++ .../coder/opcoders/base/stack_base_coder.h | 42 + .../opcoders/base/strided_slice_base_coder.cc | 21 + .../nnacl/fp16/custom_gru_fp16_coder.cc | 34 + .../nnacl/fp16/custom_gru_fp16_coder.h | 44 + .../nnacl/fp16/matmul_fp16_base_coder.cc | 22 +- .../nnacl/fp16/matmul_fp16_base_coder.h | 4 +- .../opcoders/nnacl/fp16/matmul_fp16_coder.h | 4 +- .../fp32/convolution_depthwise_fp32_coder.cc | 69 +- .../fp32/convolution_depthwise_fp32_coder.h | 8 +- .../fp32/convolution_winograd_fp32_coder.cc | 98 ++- 
.../fp32/convolution_winograd_fp32_coder.h | 15 +- .../nnacl/fp32/custom_gru_fp32_coder.cc | 214 +++++ .../nnacl/fp32/custom_gru_fp32_coder.h | 64 ++ .../opcoders/nnacl/fp32/gather_fp32_coder.cc | 59 +- .../opcoders/nnacl/fp32/gather_fp32_coder.h | 2 + .../nnacl/fp32/matmul_fp32_base_coder.cc | 41 +- .../nnacl/fp32/matmul_fp32_base_coder.h | 2 +- .../micro/coder/opcoders/op_coder_builder.cc | 13 + .../micro/coder/opcoders/op_coder_builder.h | 4 + .../micro/coder/opcoders/op_coder_register.cc | 8 +- .../micro/coder/opcoders/op_coder_register.h | 23 +- .../nnacl_serializer/nnacl_fp32_serializer.cc | 5 + .../nnacl_serializer/nnacl_fp32_serializer.h | 2 + .../tools/converter/micro/coder/session.cc | 49 +- .../tools/converter/micro/coder/session.h | 2 +- .../converter/micro/coder/utils/type_cast.cc | 3 +- .../converter/micro/coder/utils/type_cast.h | 4 +- 85 files changed, 3163 insertions(+), 267 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h create mode 100644 mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc create mode 100644 mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt index 6aeb515a..6f6a08f5 100644 --- a/.jenkins/check/config/filter_cppcheck.txt +++ b/.jenkins/check/config/filter_cppcheck.txt @@ -56,6 +56,7 @@ "mindspore/mindspore/lite/tools/converter/quantizer/quantize_util.cc" "useStlAlgorithm" "mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/" "unreadVariable" "mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/" "unreadVariable" +"mindspore/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc" "stlFindInsert" "mindspore/mindspore/lite/examples/quick_start_micro/" "syntaxError" "mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental" "unreadVariable" "mindspore/mindspore/lite/python/src/pybind_module.cc" "syntaxError" diff --git a/cmake/package_micro.cmake b/cmake/package_micro.cmake index 3c6da3db..0481e3c3 100644 --- 
a/cmake/package_micro.cmake +++ b/cmake/package_micro.cmake @@ -10,6 +10,8 @@ function(__install_micro_wrapper) COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(DIRECTORY ${NNACL_DIR}/fp32 DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") + install(DIRECTORY ${NNACL_DIR}/fp16 DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl + COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(DIRECTORY ${NNACL_DIR}/kernel DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(DIRECTORY ${NNACL_DIR}/infer DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl @@ -34,4 +36,4 @@ function(__install_micro_codegen) COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(TARGETS cmsis_nn ARCHIVE DESTINATION ${CODEGEN_ROOT_DIR}/third_party/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) -endfunction() \ No newline at end of file +endfunction() diff --git a/include/api/serialization.h b/include/api/serialization.h index 1a0c1f57..76d5dbec 100644 --- a/include/api/serialization.h +++ b/include/api/serialization.h @@ -105,6 +105,21 @@ class MS_API Serialization { QuantizationType quantization_type = kNoQuant, bool export_inference_only = true, std::vector<std::string> output_tensor_name = {}); + /// \brief Export a model's weights, which can only be used in micro. + /// + /// \param[in] model The model data. + /// \param[in] model_type The model file type. + /// \param[in] weight_file The path of the exported weight file. + /// \param[in] is_inference Whether to export the weights of an inference model. Currently, only `true` is supported. + /// \param[in] enable_fp16 Whether to save float weights in float16 format. + /// \param[in] changeable_weights_name The names of the weight tensors whose shapes are changeable. + /// + /// \return Status.
+ inline static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type, + const std::string &weight_file, bool is_inference = true, + bool enable_fp16 = false, + const std::vector<std::string> &changeable_weights_name = {}); + private: static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key, const std::vector<char> &dec_mode); @@ -119,6 +134,10 @@ class MS_API Serialization { static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data, QuantizationType quantization_type, bool export_inference_only, const std::vector<std::vector<char>> &output_tensor_name); + static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type, + const std::vector<char> &weight_file, bool is_inference, + bool enable_fp16, + const std::vector<std::vector<char>> &changeable_weights_name); }; Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, @@ -150,5 +169,12 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff VectorStringToChar(output_tensor_name)); } +Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type, + const std::string &weight_file, bool is_inference, + bool enable_fp16, + const std::vector<std::string> &changeable_weights_name) { + return ExportWeightsCollaborateWithMicro(model, model_type, StringToChar(weight_file), is_inference, enable_fp16, + VectorStringToChar(changeable_weights_name)); +} } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c index 3796e47b..81bf8ddf 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c @@ -16,9 +16,9 @@ #include "nnacl/base/minimal_filtering_generator.h" #include #include -#include "nnacl/fp32/winograd_utils.h" #include "nnacl/errorcode.h" #include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/pack_fp32.h" void Polynomial(const float *interval, float *m, int degree) { for (int i = 0; i < degree; ++i) { diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h index 01b013e8..fc0fa0e6 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h @@ -21,7 +21,6 @@ #include #endif #include -#include "nnacl/pack.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h new file mode 100644 index 00000000..3bb8a444 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_ +#define MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct CustomGruParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int num_step; + int batch_size; + int input_size; + int hidden_size; +} CustomGruParameter; + +#endif // MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c new file mode 100644 index 00000000..6e754569 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c @@ -0,0 +1,70 @@ +#ifdef ENABLE_ARM64 +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/custom_gru_fp16.h" +#include "nnacl/fp16/activation_fp16.h" +#include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl/fp16/matmul_fp16.h" + +void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input, + const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden, + const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param) { + int num_step = gru_param->num_step; + int batch_size = gru_param->batch_size; + int input_size = gru_param->input_size; + int hidden_size = gru_param->hidden_size; + int output_size = batch_size * hidden_size; + int double_output_size = output_size * C2NUM; + int col_align = UP_ROUND(hidden_size, C8NUM); + int weight_in_offset = col_align * input_size; + int weight_hidden_offset = col_align * hidden_size; + float16_t *input_gate = buffer[1]; + float16_t *hidden_gate = buffer[C3NUM]; + for (int i = 0; i < num_step; ++i) { + if (batch_size != 1) { + RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size); + for (int j = 0; j < C3NUM; ++j) { + MatmulBaseFp16Neon(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size); + for (int j = 0; j < C3NUM; ++j) { + MatmulBaseFp16Neon(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + } else { + for (int j = 0; j < C3NUM; ++j) { + VecMatmulFp16(input + i * input_size, weight_input + j * 
weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, hidden_size); + VecMatmulFp16(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size); + } + } + ElementAddFp16(input_gate, hidden_gate, input_gate, double_output_size); + SigmoidFp16(input_gate, input_gate, double_output_size); + ElementMulFp16(input_gate, hidden_gate + double_output_size, input_gate, output_size); + ElementAddFp16(input_gate, input_gate + double_output_size, input_gate, output_size); + TanhFp16(input_gate, input_gate, output_size); + ElementSubFp16(init_h, input_gate, hidden_gate, output_size); + ElementMulFp16(input_gate + output_size, hidden_gate, hidden_gate, output_size); + ElementAddFp16(input_gate, hidden_gate, output, output_size); + init_h = output; + output += output_size; + } +} +#endif diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h new file mode 100644 index 00000000..67008f03 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_ +#define MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_ +#ifdef ENABLE_ARM64 +#include "nnacl/custom_gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input, + const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden, + const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param); +#ifdef __cplusplus +} +#endif + +#endif +#endif // MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c new file mode 100644 index 00000000..caeece4a --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c @@ -0,0 +1,71 @@ +#ifdef ENABLE_ARM64 +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "nnacl/fp32/custom_gru_fp32.h" +#include "nnacl/fp32/activation_fp32.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/pack_fp32.h" + +void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden, + const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4], + const CustomGruParameter *gru_param) { + int num_step = gru_param->num_step; + int batch_size = gru_param->batch_size; + int input_size = gru_param->input_size; + int hidden_size = gru_param->hidden_size; + int output_size = batch_size * hidden_size; + int double_output_size = output_size * C2NUM; + int col_align = UP_ROUND(hidden_size, C8NUM); + int weight_in_offset = col_align * input_size; + int weight_hidden_offset = col_align * hidden_size; + float *input_gate = buffer[1]; + float *hidden_gate = buffer[C3NUM]; + for (int i = 0; i < num_step; ++i) { + if (batch_size != 1) { + RowMajor2Col12Major(input + i * batch_size * input_size, buffer[0], batch_size, input_size); + for (int j = 0; j < C3NUM; ++j) { + MatMulOpt(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + RowMajor2Col12Major(init_h, buffer[C2NUM], batch_size, hidden_size); + for (int j = 0; j < C3NUM; ++j) { + MatMulOpt(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + } else { + for (int j = 0; j < C3NUM; ++j) { + MatVecMulFp32Neon64(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, hidden_size, col_align); + MatVecMulFp32Neon64(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size, col_align); + } + } + ElementAdd(input_gate, hidden_gate, input_gate, double_output_size); + Sigmoid(input_gate, double_output_size, input_gate); + ElementMul(input_gate, hidden_gate + double_output_size, input_gate, output_size); + ElementAdd(input_gate, input_gate + double_output_size, input_gate, output_size); + Tanh(input_gate, output_size, input_gate); + ElementSub(init_h, input_gate, hidden_gate, output_size); + ElementMul(input_gate + output_size, hidden_gate, hidden_gate, output_size); + ElementAdd(input_gate, hidden_gate, output, output_size); + init_h = output; + output += output_size; + } +} +#endif diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h new file mode 100644 index 00000000..576726c5 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ +#define MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ +#ifdef ENABLE_ARM64 +#include "nnacl/custom_gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden, + const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4], + const CustomGruParameter *gru_param); +#ifdef __cplusplus +} +#endif + +#endif +#endif // MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h index b608e1e0..8e217d02 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h @@ -43,7 +43,7 @@ int ElementOptMulAcc(const float *input0, const float input1, float *output, con void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate, const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state, - float *buffer[6], const LstmParameter *lstm_param); + float *buffer[C6NUM], const LstmParameter *lstm_param); void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state, float *buffer[7], diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c new file mode 100644 index 00000000..060d04cf --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl/infer/custom_gru_infer.h" +#include "nnacl/infer/infer_register.h" + +int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C6NUM, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != C3NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + SetShapeTensor(output, input); + const TensorC *weight_in = inputs[1]; + if (weight_in->shape_size_ != C2NUM || weight_in->shape_[0] % C3NUM != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + output->shape_[C2NUM] = weight_in[0].shape_[0] / C3NUM; + return NNACL_OK; +} + +REG_INFER(CustomGru, PrimType_Inner_CustomGru, CustomGruInferShape) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h new file mode 100644 index 00000000..830150d5 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUSTOM_GRU_INFER_H +#define MINDSPORE_NNACL_CUSTOM_GRU_INFER_H +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_GRU_INFER_H diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h index 5876bdf6..8c219212 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h @@ -520,6 +520,7 @@ enum PrimType { PrimType_Inner_ShapeFusion = 10003, PrimType_Inner_GraphKernel = 10004, PrimType_Inner_ThirdPartyModel = 10005, + PrimType_Inner_CustomGru = 10006, PrimType_InnerOpMax, PrimType_InnerOpMin = PrimType_Inner_ToFormat }; diff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h index dd6e6d08..0d2de256 100644 --- a/mindspore/lite/include/registry/converter_context.h +++ b/mindspore/lite/include/registry/converter_context.h @@ -34,7 +34,8 @@ enum MS_API FmkType : int { kFmkTypeTflite = 4, kFmkTypePytorch = 5, kFmkTypeThirdParty = 6, - kFmkTypeEnd = 7, // For range check purpose, valid range: [0, kFmkTypeEnd) + kFmkTypeMsLite = 7, + kFmkTypeEnd = 8, // For range check purpose, valid range: [0, kFmkTypeEnd) }; /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser. 
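[Reviewer note, not part of the patch] For reference when reading the two kernels above: CustomGru/CustomGruFp16 evaluate one GRU cell per time step, with the three gates packed as [reset; update; candidate] along the first weight dimension. Below is a minimal scalar sketch, in plain C, of the per-step math implied by the kernel's elementwise tail (sigmoid over the first two gate blocks, n = tanh(n_x + r * n_h), h' = n + z * (h - n)). All names here are illustrative, and the plain row-major layout stands in for the kernel's col-aligned packed weights.

/* Scalar single-step GRU reference (reviewer sketch, assumed gate order [r; z; n]). */
#include <math.h>

static float sigmoid_ref(float x) { return 1.0f / (1.0f + expf(-x)); }

/* x: [input_size], h_in/h_out: [hidden_size], wx: [3*hidden_size x input_size],
 * wh: [3*hidden_size x hidden_size], bx/bh: [3*hidden_size]; row-major, no padding. */
static void gru_step_ref(float *h_out, const float *h_in, const float *x, const float *wx,
                         const float *wh, const float *bx, const float *bh, int input_size,
                         int hidden_size) {
  for (int k = 0; k < hidden_size; ++k) {
    float g[3]; /* Wx*x + bx for gates r, z, n */
    float m[3]; /* Wh*h + bh for gates r, z, n */
    for (int j = 0; j < 3; ++j) {
      int row = j * hidden_size + k;
      g[j] = bx[row];
      m[j] = bh[row];
      for (int i = 0; i < input_size; ++i) {
        g[j] += wx[row * input_size + i] * x[i];
      }
      for (int i = 0; i < hidden_size; ++i) {
        m[j] += wh[row * hidden_size + i] * h_in[i];
      }
    }
    float r = sigmoid_ref(g[0] + m[0]); /* reset gate */
    float z = sigmoid_ref(g[1] + m[1]); /* update gate */
    float n = tanhf(g[2] + r * m[2]);   /* candidate state */
    h_out[k] = n + z * (h_in[k] - n);   /* same update form the kernel computes */
  }
}

This sketch is only a numerical cross-check; the real kernels additionally pack the weights (RowMajor2Col12Major / RowMajor2ColNMajorFp16, col-aligned biases) so the three gate matmuls can run through MatMulOpt / MatmulBaseFp16Neon.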
diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index 48e0fe7c..a80656f2 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -380,6 +380,7 @@ set(TRAIN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/common/fusion_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/gru_fusion_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_activation_fusion_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc diff --git a/mindspore/lite/src/common/graph_util.cc b/mindspore/lite/src/common/graph_util.cc index 1688fb79..d5cb2a98 100644 --- a/mindspore/lite/src/common/graph_util.cc +++ b/mindspore/lite/src/common/graph_util.cc @@ -23,6 +23,7 @@ #include "src/common/log_adapter.h" #include "src/common/version_manager.h" #include "include/errorcode.h" +#include "nnacl/op_base.h" namespace mindspore { namespace lite { @@ -86,9 +87,9 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, size_t tensor // only support op_type from current schema bool IsPackedOp(int op_type) { - static const std::vector<int> packed_ops = {schema::PrimitiveType_Conv2DFusion, - schema::PrimitiveType_Conv2dTransposeFusion, - schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMulFusion}; + static const std::vector<int> packed_ops = { + schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion, + schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMulFusion, PrimType::PrimType_Inner_CustomGru}; return IsContain(packed_ops, op_type); } diff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc index f1506ece..391a587b 100644 --- a/mindspore/lite/src/common/ops/populate/custom_populate.cc +++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc @@ -17,10 +17,22 @@ #include "src/common/log_adapter.h" #include "src/tensor.h" #include "nnacl/custom_parameter.h" +#include "nnacl/custom_gru_parameter.h" using mindspore::schema::PrimitiveType_Custom; namespace mindspore { namespace lite { +OpParameter *CreateCustomGruParameter() { + auto *param = static_cast<CustomGruParameter *>(malloc(sizeof(CustomGruParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "malloc CustomGruParameter failed."; + return nullptr; + } + memset(param, 0, sizeof(CustomGruParameter)); + param->op_parameter_.type_ = PrimType_Inner_CustomGru; + return reinterpret_cast<OpParameter *>(param); +} + OpParameter *PopulateCustomParameter(const void *prim) { MS_CHECK_TRUE_RET(prim != nullptr, nullptr); auto primitive = static_cast<const schema::Primitive *>(prim); @@ -62,6 +74,8 @@ OpParameter *PopulateCustomParameter(const void *prim) { // Just use the attr_data pointer to save the prim directly, the inner value is parsed as necessary.
param->attr_data[0] = static_cast<void *>(const_cast<void *>(prim)); return reinterpret_cast<OpParameter *>(param); + } else if (type == "CustomGRU") { + return CreateCustomGruParameter(); } else { MS_LOG(ERROR) << "Unsupported custom type: " << type; } diff --git a/mindspore/lite/src/runtime/cxx_api/serialization.cc b/mindspore/lite/src/runtime/cxx_api/serialization.cc index 8405f4b2..ddf69d23 100644 --- a/mindspore/lite/src/runtime/cxx_api/serialization.cc +++ b/mindspore/lite/src/runtime/cxx_api/serialization.cc @@ -212,4 +212,35 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError; } + +Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type, + const std::vector<char> &weight_file, bool is_inference, + bool enable_fp16, + const std::vector<std::vector<char>> &changeable_weights_name) { + if (model.impl_ == nullptr) { + MS_LOG(ERROR) << "Model implement is null."; + return kLiteUninitializedObj; + } + if (!model.impl_->IsTrainModel()) { + MS_LOG(ERROR) << "Model is not TrainModel."; + return kLiteError; + } + if (model_type != kMindIR && model_type != kMindIR_Lite) { + MS_LOG(ERROR) << "Unsupported Export Format " << model_type; + return kLiteParamInvalid; + } + if (model.impl_->session_ == nullptr) { + MS_LOG(ERROR) << "Model session is nullptr."; + return kLiteError; + } + if (!is_inference) { + MS_LOG(ERROR) << "Currently, only the weights of an inference model can be exported."; + return kLiteNotSupport; + } + auto ret = model.impl_->session_->ExportWeightsCollaborateWithMicro(CharToString(weight_file), lite::MT_INFERENCE, + lite::FT_FLATBUFFERS, enable_fp16, + VectorCharToString(changeable_weights_name)); + + return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError; +} } // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc index b5370ddd..0352ad19 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc @@ -35,6 +35,8 @@ int GroupConvolutionBaseCPUKernel::Prepare() { return ret; } } + conv_param_->input_channel_ *= group_num_; + conv_param_->output_channel_ *= group_num_; // if infer shape is done, resize func will be invoked in sub kernels return RET_OK; } @@ -99,11 +101,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() { auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front(); CHECK_NULL_RETURN(sub_kernel_in_tensor); sub_kernel_in_tensor->set_shape(in_shape); - ret = sub_kernel_in_tensor->MallocData(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "sub kernel in tensor malloc data failed."; - return ret; - } // out auto out_tensor = out_tensors_.front(); CHECK_NULL_RETURN(out_tensor); @@ -113,11 +110,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() { for (auto tensor : sub_kernel_out_tensors) { CHECK_NULL_RETURN(tensor); tensor->set_shape(out_shape); - ret = tensor->MallocData(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "sub kernel out tensor malloc data failed."; - return ret; - } } } ret = ReSize(); @@ -177,7 +169,22 @@ int GroupConvolutionBaseCPUKernel::Run() { ori_out_data_ = out_tensors_[0]->data(); CHECK_NULL_RETURN(ori_out_data_); for (int i = 0; i < group_num_; ++i) { - // first, separate group conv input into several parts. This step must be in runtime stage.
+ // first, malloc data for sub_kernel's tensors + auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front(); + ret = sub_kernel_in_tensor->MallocData(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "sub kernel in tensor malloc data failed."; + return ret; + } + auto sub_kernel_out_tensors = group_convs_.at(i)->out_tensors(); + for (auto tensor : sub_kernel_out_tensors) { + ret = tensor->MallocData(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "sub kernel out tensor malloc data failed."; + return ret; + } + } + // second, separate group conv input into several parts. This step must be in runtime stage. ret = SeparateInput(i); if (ret != RET_OK) { MS_LOG(ERROR) << "Separate input failed."; return ret; } @@ -195,6 +202,11 @@ MS_LOG(ERROR) << "Concat output failed."; return ret; } + // free data + sub_kernel_in_tensor->FreeData(); + for (auto tensor : sub_kernel_out_tensors) { + tensor->FreeData(); + } } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc index fc78a887..81b0aac2 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc @@ -53,15 +53,6 @@ void FreeCurrentConv(ConvParameter *conv_param, std::vector *new } } -static inline lite::Tensor *TensorMalloc(lite::Tensor *tensor) { - if (tensor->MallocData() != lite::RET_OK) { - delete tensor; - MS_LOG(ERROR) << "malloc tensor data failed."; - return nullptr; - } - return tensor; -} - lite::Tensor *CreateConstTensor(const lite::Tensor *tensor, const std::vector<int> &shape, const int index) { auto new_tensor = new (std::nothrow) lite::Tensor(tensor->data_type(), shape, mindspore::NHWC, lite::Category::CONST_TENSOR); @@ -87,7 +78,7 @@ lite::Tensor *CreateConstTensor(const lite::Tensor *tensor, const std::vector<int> tensor->set_format(tensor_info.format_); tensor->set_category(tensor_info.tensor_type_); tensor->set_shape(tensor_info.shape_); - - if (inferred) { - // set shape of out tensor - return TensorMalloc(tensor); - } + tensor->set_allocator(tensor_info.allocator_); return tensor; } @@ -129,7 +116,8 @@ void GroupConvCreator::FreeGroupConvs() { } int GroupConvCreator::NewInputTensor(std::vector<lite::Tensor *> *tensors) { - auto in_tensor = CreateVarTensor({input_shape_, mindspore::NHWC, data_type_, lite::Category::VAR, true}, infered_); + auto allocator = ms_context_ != nullptr ? ms_context_->allocator : nullptr; + auto in_tensor = CreateVarTensor({input_shape_, allocator, mindspore::NHWC, data_type_, lite::Category::VAR, true}); if (in_tensor == nullptr) { return lite::RET_ERROR; } @@ -138,7 +126,9 @@ int GroupConvCreator::NewInputTensor(std::vector<lite::Tensor *> *tensors) { int GroupConvCreator::NewOutputTensor(std::vector<lite::Tensor *> *tensors, const lite::Tensor *output) const { - auto out_tensor = CreateVarTensor({output_shape_, output->format(), data_type_, output->category(), false}, infered_); + auto allocator = ms_context_ != nullptr ?
ms_context_->allocator : nullptr; + auto out_tensor = + CreateVarTensor({output_shape_, allocator, output->format(), data_type_, output->category(), false}); if (out_tensor == nullptr) { return lite::RET_ERROR; } diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h index b4a4f768..27aa0cc8 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h +++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h @@ -22,10 +22,12 @@ #include "src/runtime/lite_kernel.h" #include "nnacl/conv_parameter.h" #include "src/runtime/tensor_category.h" +#include "include/api/allocator.h" namespace mindspore::kernel { struct TensorInfo { std::vector<int> shape_; + AllocatorPtr allocator_; mindspore::Format format_; TypeId data_type_; lite::Category tensor_type_; @@ -35,8 +37,9 @@ struct TensorInfo { class GroupConvCreator { public: GroupConvCreator(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs, OpParameter *op_parameter, - bool is_quant, TypeId data_type) - : origin_inputs_(std::move(inputs)), + bool is_quant, TypeId data_type, const lite::InnerContext *ctx) - : ms_context_(ctx), + origin_inputs_(std::move(inputs)), origin_outputs_(std::move(outputs)), is_quant_(is_quant), data_type_(data_type) { @@ -64,6 +67,7 @@ class GroupConvCreator { int NewOutputTensor(std::vector<lite::Tensor *> *tensors, const lite::Tensor *output) const; private: + const lite::InnerContext *ms_context_ = nullptr; std::vector<lite::Tensor *> origin_inputs_; std::vector<lite::Tensor *> origin_outputs_; std::vector<LiteKernel *> group_convs_; diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc index 7aa823b0..17ba38ff 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc @@ -202,7 +202,7 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, const InnerContext *ctx) { auto *group_conv_creator = - new (std::nothrow) GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat16); + new (std::nothrow) GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat16, ctx); if (group_conv_creator == nullptr) { MS_LOG(ERROR) << "new GroupConvCreator fail"; free(op_parameter); diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc new file mode 100644 index 00000000..7851eecb --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc @@ -0,0 +1,132 @@ +#ifdef ENABLE_ARM64 +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include "src/runtime/kernel/cpu/fp16/custom_gru_fp16.h" +#include +#include "src/runtime/kernel_registry.h" +#include "include/errorcode.h" +#include "src/common/log_adapter.h" +#include "src/runtime/pack_weight_manager.h" +#include "nnacl/custom_gru_parameter.h" +#include "nnacl/fp16/custom_gru_fp16.h" +#include "nnacl/fp16/matmul_fp16.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NOT_SUPPORT; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +int CustomGruFp16CPUKernel::InitWeightAndBias() { + auto weight_shape = in_tensors_[1]->shape(); + auto hidden_size = weight_shape[0] / C3NUM; + auto col_align = UP_ROUND(hidden_size, col_tile_); + auto weight_in_pack_size = static_cast<size_t>(col_align * weight_shape[1]) * sizeof(float16_t); + bool is_packed = false; + weight_in_ = lite::PackWeightManager::GetInstance()->GetPackData( + in_tensors_[SECOND_INPUT]->data(), static_cast<size_t>(weight_in_pack_size * C3NUM), &is_packed); + MS_CHECK_TRUE_MSG(weight_in_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-in failed."); + if (!is_packed) { + auto weight_in_src = static_cast<const float16_t *>(in_tensors_[SECOND_INPUT]->data()); + for (int i = 0; i < C3NUM; ++i) { + RowMajor2Col8MajorFp16(weight_in_src + i * hidden_size * weight_shape[1], + static_cast<float16_t *>(weight_in_) + i * col_align * weight_shape[1], hidden_size, + weight_shape[1], false); + } + } + auto weight_hidden_pack_size = static_cast<size_t>(col_align * hidden_size) * sizeof(float16_t); + is_packed = false; + weight_hidden_ = lite::PackWeightManager::GetInstance()->GetPackData( + in_tensors_[THIRD_INPUT]->data(), static_cast<size_t>(weight_hidden_pack_size * C3NUM), &is_packed); + MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-hidden failed."); + if (!is_packed) { + auto weight_hidden_src = static_cast<const float16_t *>(in_tensors_[THIRD_INPUT]->data()); + for (int i = 0; i < C3NUM; ++i) { + RowMajor2Col8MajorFp16(weight_hidden_src + i * hidden_size * weight_shape[1], + static_cast<float16_t *>(weight_hidden_) + i * col_align * weight_shape[1], hidden_size, + hidden_size, false); + } + } + auto bias_pack_size = static_cast<size_t>(col_align) * sizeof(float16_t); + auto bias = reinterpret_cast<float16_t *>(malloc(bias_pack_size * C6NUM)); + if (bias == nullptr) { + MS_LOG(ERROR) << "malloc for packing bias failed."; + return lite::RET_NULL_PTR; + } + (void)memset(bias, 0, bias_pack_size * C6NUM); + bias_in_ = bias; + bias_hidden_ = bias + col_align * C3NUM; + auto bias_in_src = static_cast<const float16_t *>(in_tensors_[FOURTH_INPUT]->data()); + for (int i = 0; i < C3NUM; ++i) { + (void)memcpy(bias + i * col_align, bias_in_src + i * hidden_size, hidden_size * sizeof(float16_t)); + } + auto bias_hidden_src = static_cast<const float16_t *>(in_tensors_[FIFTH_INPUT]->data()); + for (int i = 0; i < C3NUM; ++i) { + (void)memcpy(bias + (C3NUM + i) * col_align, bias_hidden_src + i * hidden_size, hidden_size * sizeof(float16_t)); + } + if (in_tensors_[SIXTH_INPUT]->IsConst()) { + init_h_ = malloc(in_tensors_[SIXTH_INPUT]->Size()); + MS_CHECK_TRUE_MSG(init_h_ != nullptr, lite::RET_NULL_PTR, "malloc for init-h failed."); + (void)memcpy(init_h_, in_tensors_[SIXTH_INPUT]->data(), in_tensors_[SIXTH_INPUT]->Size()); + } + return RET_OK; +} + +int CustomGruFp16CPUKernel::Run() { + auto input = reinterpret_cast<float16_t *>(in_tensors_[FIRST_INPUT]->data()); + if (input == nullptr) { + MS_LOG(ERROR) << "Built-in CustomGru's first-input is nullptr."
<< name_; + return lite::RET_NULL_PTR; + } + if (!in_tensors_[SIXTH_INPUT]->IsConst()) { + init_h_ = in_tensors_[SIXTH_INPUT]->data(); + } + if (init_h_ == nullptr) { + MS_LOG(ERROR) << "Built-in CustomGru's sixth-input is nullptr." << name_; + return lite::RET_NULL_PTR; + } + auto output = reinterpret_cast<float16_t *>(out_tensors_.front()->data()); + if (output == nullptr) { + MS_LOG(ERROR) << "Built-in CustomGru's output is nullptr." << name_; + return lite::RET_NULL_PTR; + } + MallocRunBuffer(sizeof(float16_t)); + if (run_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc running buffer failed." << name_; + return lite::RET_NULL_PTR; + } + auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_); + auto row_align = UP_ROUND(param->batch_size, row_tile_); + auto run_buffer = reinterpret_cast<float16_t *>(run_buffer_); + float16_t *buffer[C4NUM] = { + run_buffer, run_buffer + row_align * param->input_size, + run_buffer + row_align * param->input_size + param->batch_size * param->hidden_size * C3NUM, + run_buffer + row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C3NUM}; + CustomGruFp16(output, input, static_cast<float16_t *>(weight_in_), static_cast<float16_t *>(weight_hidden_), + static_cast<float16_t *>(bias_in_), static_cast<float16_t *>(bias_hidden_), + static_cast<float16_t *>(init_h_), buffer, param); + if (ms_context_->allocator != nullptr) { + ms_context_->allocator->Free(run_buffer_); + } else { + free(run_buffer_); + } + run_buffer_ = nullptr; + return RET_OK; +} + +REG_KERNEL(kCPU, kNumberTypeFloat16, PrimType_Inner_CustomGru, LiteKernelCreator<CustomGruFp16CPUKernel>) +} // namespace mindspore::kernel +#endif diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h new file mode 100644 index 00000000..d7ed313a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h @@ -0,0 +1,40 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_ +#ifdef ENABLE_ARM64 +#include +#include "src/runtime/kernel/cpu/fp32/custom_gru_fp32.h" + +namespace mindspore::kernel { +class CustomGruFp16CPUKernel : public CustomGruCPUKernel { + public: + CustomGruFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, + const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) + : CustomGruCPUKernel(parameter, inputs, outputs, ctx) { + row_tile_ = C4NUM; + col_tile_ = C8NUM; + } + ~CustomGruFp16CPUKernel() override = default; + int Run() override; + + protected: + int InitWeightAndBias() override; +}; +} // namespace mindspore::kernel +#endif +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc index bbbf8488..3514e5b4 100644 --- a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc @@ -359,7 +359,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter, const lite::InnerContext *ctx) { - auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat32); + auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat32, ctx); auto group_kernel = new (std::nothrow) GroupConvolutionFp32CPUKernel( op_parameter, inputs, outputs, ctx, group_conv_creator, reinterpret_cast<ConvParameter *>(op_parameter)->group_); if (group_kernel == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc new file mode 100644 index 00000000..c85a1283 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc @@ -0,0 +1,251 @@ +#ifdef ENABLE_ARM64 +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include "src/runtime/kernel/cpu/fp32/custom_gru_fp32.h" +#include +#include "src/runtime/kernel_registry.h" +#include "include/errorcode.h" +#include "src/common/log_adapter.h" +#include "src/runtime/pack_weight_manager.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/custom_gru_parameter.h" +#include "nnacl/fp32/custom_gru_fp32.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_NOT_SUPPORT; +using mindspore::lite::RET_OK; + +namespace mindspore::kernel { +CustomGruCPUKernel::~CustomGruCPUKernel() { + if (weight_in_) { + lite::PackWeightManager::GetInstance()->Free(weight_in_); + weight_in_ = nullptr; + } + if (weight_hidden_) { + lite::PackWeightManager::GetInstance()->Free(weight_hidden_); + weight_hidden_ = nullptr; + } + if (bias_in_) { + free(bias_in_); + bias_in_ = nullptr; + bias_hidden_ = nullptr; + } + if (in_tensors_[SIXTH_INPUT]->IsConst() && init_h_) { + free(init_h_); + init_h_ = nullptr; + } +} + +int CustomGruCPUKernel::Prepare() { + CHECK_LESS_RETURN(in_tensors_.size(), C6NUM); + CHECK_LESS_RETURN(out_tensors_.size(), 1); + if (in_tensors_[FIRST_INPUT]->IsConst()) { + MS_LOG(ERROR) << "Built-in CustomGru's first-input must be a variable." << name_; + return RET_NOT_SUPPORT; + } + for (size_t i = 1; i < C5NUM; ++i) { + if (!in_tensors_[i]->IsConst()) { + MS_LOG(ERROR) << "Built-in CustomGru only supports a variable first-input and last-input." << name_; + return RET_NOT_SUPPORT; + } + } + if (InitParameter() != RET_OK) { + MS_LOG(ERROR) << "Init Built-in CustomGru Parameter failed." << name_; + return RET_ERROR; + } + if (InitWeightAndBias() != RET_OK) { + MS_LOG(ERROR) << "Init Built-in CustomGru Weight and bias failed." << name_; + return RET_ERROR; + } + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int CustomGruCPUKernel::InitParameter() { + auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_); + thread_num_ = 1; + auto weight_in_shape = in_tensors_[1]->shape(); + auto weight_hidden_shape = in_tensors_[C2NUM]->shape(); + if (weight_in_shape.size() != C2NUM || weight_hidden_shape.size() != C2NUM) { + MS_LOG(ERROR) << "Built-in CustomGru's weight must be 2D." << name_; + return RET_ERROR; + } + if (weight_in_shape[0] != weight_hidden_shape[0]) { + MS_LOG(ERROR) << "Built-in CustomGru's weight-in and weight-hidden first-dim must be the same." << name_; + return RET_ERROR; + } + if (weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) { + MS_LOG(ERROR) << "Built-in CustomGru's weight-hidden first-dim must be 3 * second-dim." << name_; + return RET_ERROR; + } + auto bias_in_shape = in_tensors_[C3NUM]->shape(); + auto bias_hidden_shape = in_tensors_[C4NUM]->shape(); + if (bias_in_shape.size() != 1) { + MS_LOG(ERROR) << "Built-in CustomGru's bias must be 1D." << name_; + return RET_ERROR; + } + if (bias_in_shape != bias_hidden_shape) { + MS_LOG(ERROR) << "Built-in CustomGru's bias-in and bias-hidden must have same shape." << name_; + return RET_ERROR; + } + if (bias_in_shape.back() != weight_in_shape.front()) { + MS_LOG(ERROR) << "Built-in CustomGru's bias-in shape doesn't match the first-dim of weight."
<< name_; + return RET_ERROR; + } + if (bias_in_shape.front() % C3NUM != 0) { + MS_LOG(ERROR) << "The first-dim of CustomGru's weight must be 3 * hidden."; + return RET_ERROR; + } + param->input_size = weight_in_shape.back(); + param->hidden_size = bias_in_shape.front() / C3NUM; + return RET_OK; +} + +int CustomGruCPUKernel::InitWeightAndBias() { + auto weight_shape = in_tensors_[1]->shape(); + auto hidden_size = weight_shape[0] / C3NUM; + auto col_align = UP_ROUND(hidden_size, col_tile_); + auto weight_in_pack_size = static_cast<size_t>(col_align * weight_shape[1]) * sizeof(float); + bool is_packed = false; + weight_in_ = lite::PackWeightManager::GetInstance()->GetPackData( + in_tensors_[SECOND_INPUT]->data(), static_cast<size_t>(weight_in_pack_size * C3NUM), &is_packed); + MS_CHECK_TRUE_MSG(weight_in_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-in failed."); + if (!is_packed) { + auto weight_in_src = static_cast<const float *>(in_tensors_[SECOND_INPUT]->data()); + for (int i = 0; i < C3NUM; ++i) { + RowMajor2Col8Major(weight_in_src + i * hidden_size * weight_shape[1], + static_cast<float *>(weight_in_) + i * col_align * weight_shape[1], hidden_size, + weight_shape[1]); + } + } + auto weight_hidden_pack_size = static_cast<size_t>(col_align * hidden_size) * sizeof(float); + is_packed = false; + weight_hidden_ = lite::PackWeightManager::GetInstance()->GetPackData( + in_tensors_[THIRD_INPUT]->data(), static_cast<size_t>(weight_hidden_pack_size * C3NUM), &is_packed); + MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-hidden failed."); + if (!is_packed) { + auto weight_hidden_src = static_cast<const float *>(in_tensors_[THIRD_INPUT]->data()); + for (int i = 0; i < C3NUM; ++i) { + RowMajor2Col8Major(weight_hidden_src + i * hidden_size * weight_shape[1], + static_cast<float *>(weight_hidden_) + i * col_align * weight_shape[1], hidden_size, + hidden_size); + } + } + auto bias_pack_size = static_cast<size_t>(col_align) * sizeof(float); + auto bias = reinterpret_cast<float *>(malloc(bias_pack_size * C6NUM)); + if (bias == nullptr) { + MS_LOG(ERROR) << "malloc for packing bias failed."; + return lite::RET_NULL_PTR; + } + (void)memset(bias, 0, bias_pack_size * C6NUM); + bias_in_ = bias; + bias_hidden_ = bias + col_align * C3NUM; + auto bias_in_src = static_cast<const float *>(in_tensors_[FOURTH_INPUT]->data()); + for (int i = 0; i < C3NUM; ++i) { + (void)memcpy(bias + i * col_align, bias_in_src + i * hidden_size, hidden_size * sizeof(float)); + } + auto bias_hidden_src = static_cast<const float *>(in_tensors_[FIFTH_INPUT]->data()); + for (int i = 0; i < C3NUM; ++i) { + (void)memcpy(bias + (C3NUM + i) * col_align, bias_hidden_src + i * hidden_size, hidden_size * sizeof(float)); + } + if (in_tensors_[SIXTH_INPUT]->IsConst()) { + init_h_ = malloc(in_tensors_[SIXTH_INPUT]->Size()); + MS_CHECK_TRUE_MSG(init_h_ != nullptr, lite::RET_NULL_PTR, "malloc for init-h failed."); + (void)memcpy(init_h_, in_tensors_[SIXTH_INPUT]->data(), in_tensors_[SIXTH_INPUT]->Size()); + } + return RET_OK; +} + +int CustomGruCPUKernel::ReSize() { + auto in_shape = in_tensors_.front()->shape(); + if (in_shape.size() != C3NUM) { + MS_LOG(ERROR) << "Built-in CustomGru's first-input must be 3D." << name_; + return RET_ERROR; + } + auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_); + param->num_step = in_shape[0]; + param->batch_size = in_shape[1]; + if (in_shape.back() != param->input_size) { + MS_LOG(ERROR) << "Built-in CustomGru's first-input doesn't match its weight."
<< name_; + return RET_ERROR; + } + return RET_OK; +} + +void CustomGruCPUKernel::MallocRunBuffer(size_t data_type_size) { + if (run_buffer_ != nullptr) { + return; + } + auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_); + auto row_align = UP_ROUND(param->batch_size, row_tile_); + auto run_buffer_size = + (row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C6NUM) * + data_type_size; + if (ms_context_->allocator != nullptr) { + run_buffer_ = ms_context_->allocator->Malloc(run_buffer_size); + } else { + run_buffer_ = malloc(run_buffer_size); + } +} + +int CustomGruCPUKernel::Run() { + auto input = reinterpret_cast<float *>(in_tensors_[FIRST_INPUT]->data()); + if (input == nullptr) { + MS_LOG(ERROR) << "Built-in CustomGru's first-input is nullptr." << name_; + return lite::RET_NULL_PTR; + } + if (!in_tensors_[SIXTH_INPUT]->IsConst()) { + init_h_ = in_tensors_[SIXTH_INPUT]->data(); + } + if (init_h_ == nullptr) { + MS_LOG(ERROR) << "Built-in CustomGru's sixth-input is nullptr." << name_; + return lite::RET_NULL_PTR; + } + auto output = reinterpret_cast<float *>(out_tensors_.front()->data()); + if (output == nullptr) { + MS_LOG(ERROR) << "Built-in CustomGru's output is nullptr." << name_; + return lite::RET_NULL_PTR; + } + MallocRunBuffer(sizeof(float)); + if (run_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc running buffer failed." << name_; + return lite::RET_NULL_PTR; + } + auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_); + auto row_align = UP_ROUND(param->batch_size, row_tile_); + auto run_buffer = reinterpret_cast<float *>(run_buffer_); + float *buffer[C4NUM] = { + run_buffer, run_buffer + row_align * param->input_size, + run_buffer + row_align * param->input_size + param->batch_size * param->hidden_size * C3NUM, + run_buffer + row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C3NUM}; + CustomGru(output, input, static_cast<float *>(weight_in_), static_cast<float *>(weight_hidden_), + static_cast<float *>(bias_in_), static_cast<float *>(bias_hidden_), static_cast<float *>(init_h_), buffer, + param); + if (ms_context_->allocator != nullptr) { + ms_context_->allocator->Free(run_buffer_); + } else { + free(run_buffer_); + } + run_buffer_ = nullptr; + return RET_OK; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomGru, LiteKernelCreator<CustomGruCPUKernel>) +} // namespace mindspore::kernel +#endif diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h new file mode 100644 index 00000000..e70e408f --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h @@ -0,0 +1,51 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_
+#ifdef ENABLE_ARM64
+#include <vector>
+#include "src/runtime/lite_kernel.h"
+
+namespace mindspore::kernel {
+class CustomGruCPUKernel : public LiteKernel {
+ public:
+  CustomGruCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+      : LiteKernel(parameter, inputs, outputs, ctx) {}
+  ~CustomGruCPUKernel() override;
+  int Prepare() override;
+  int ReSize() override;
+  int Run() override;
+
+ private:
+  int InitParamter();
+
+ protected:
+  void MallocRunBuffer(size_t data_type_size);
+  virtual int InitWeightAndBias();
+  int row_tile_{C12NUM};
+  int col_tile_{C8NUM};
+  void *weight_in_{nullptr};
+  void *weight_hidden_{nullptr};
+  void *bias_in_{nullptr};
+  void *bias_hidden_{nullptr};
+  void *init_h_{nullptr};
+  void *run_buffer_{nullptr};
+};
+} // namespace mindspore::kernel
+#endif
+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_
diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc
index ef03171a..0ff780c7 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc
@@ -107,7 +107,7 @@ kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector
                << conv_param->input_channel_;
     return nullptr;
   }
-  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, true, kNumberTypeInt8);
+  auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, true, kNumberTypeInt8, ctx);
   if (group_conv_creator == nullptr) {
     MS_LOG(ERROR) << "group_conv_creator is nullptr.";
     return nullptr;
diff --git a/mindspore/lite/src/runtime/lite_session.h b/mindspore/lite/src/runtime/lite_session.h
index 2fdb1eb7..d5a672bb 100644
--- a/mindspore/lite/src/runtime/lite_session.h
+++ b/mindspore/lite/src/runtime/lite_session.h
@@ -106,6 +106,12 @@ class MS_API LiteSession {
                      std::vector<std::string> out_put_tensor_name = {}) {
     return mindspore::lite::RET_ERROR;
   }
+  virtual int ExportWeightsCollaborateWithMicro(const std::string &file_name,
+                                                lite::ModelType model_type = lite::MT_TRAIN,
+                                                lite::FormatType = lite::FT_FLATBUFFERS, bool enable_fp16 = false,
+                                                const std::vector<std::string> &changeable_weights_name = {}) {
+    return mindspore::lite::RET_ERROR;
+  }
   virtual int UpdateWeights(std::vector<lite::Tensor *> new_weights) { return mindspore::lite::RET_ERROR; }
   virtual std::vector<lite::Tensor *> GetFeatureMaps() const {
     std::vector<lite::Tensor *> features;
diff --git a/mindspore/lite/src/train/graph_fusion.cc b/mindspore/lite/src/train/graph_fusion.cc
index 1af44e45..03f5675e 100644
--- a/mindspore/lite/src/train/graph_fusion.cc
+++ b/mindspore/lite/src/train/graph_fusion.cc
@@ -22,6 +22,7 @@
 #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
 #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
+#include "src/train/optimizer/fusion/gru_fusion_pass.h"
 #include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h"
 #include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h"
 
@@ -41,6 +42,12 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) {
     MS_LOG(ERROR) << "graph is nullptr.";
     return RET_ERROR;
   }
+  auto gru_fusion = std::make_shared<GruFusionPass>();
+  MS_CHECK_TRUE_MSG(gru_fusion != nullptr, RET_NULL_PTR,
"Create GruFusion object failed."); + if (gru_fusion->Run(graph) != RET_OK) { + MS_LOG(ERROR) << "Do GruFusion failed."; + return RET_ERROR; + } auto old_nodes = GetGraphNodes(*graph); Optimizer fusion_optimizer; fusion_optimizer.AddPass(new (std::nothrow) ReshapeGatherReshapeFusionPass()); diff --git a/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc new file mode 100644 index 00000000..435686e5 --- /dev/null +++ b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc @@ -0,0 +1,809 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/train/optimizer/fusion/gru_fusion_pass.h" +#include +#include +#include +#include +#include +#include +#include +#include "src/common/log_adapter.h" +#include "include/errorcode.h" +#include "nnacl/op_base.h" + +namespace mindspore { +namespace lite { +namespace { +constexpr size_t kSplitOutSize = 3; +constexpr uint32_t kAdd0 = 0; +constexpr uint32_t kAdd1 = 1; +constexpr uint32_t kAdd2 = 2; +constexpr uint32_t kAdd3 = 3; +constexpr uint32_t kAdd4 = 4; +constexpr uint32_t kAdd5 = 5; +constexpr uint32_t kSub = 6; +constexpr uint32_t kMul0 = 7; +constexpr uint32_t kMul1 = 8; +constexpr uint32_t kTanh = 9; +constexpr uint32_t kSigmoid0 = 10; +constexpr uint32_t kSigmoid1 = 11; +constexpr uint32_t kSplit0 = 12; +constexpr uint32_t kSplit1 = 13; +constexpr uint32_t kMatmul0 = 14; +constexpr uint32_t kMatmul1 = 15; +constexpr uint32_t kInputH = 16; +constexpr uint32_t kInputI = 17; +constexpr auto kCustomGRU = "CustomGRU"; + +bool CheckCommon(schema::MetaGraphT *graph, uint32_t node_index, schema::PrimitiveType type, size_t in_nums, + size_t out_nums) { + if (graph->nodes.size() <= node_index) { + return false; + } + const auto &node = graph->nodes[node_index]; + if (node == nullptr || node->primitive == nullptr) { + return false; + } + const auto &value = node->primitive->value; + if (value.type != type) { + return false; + } + if (value.value == nullptr) { + return false; + } + if ((in_nums > 0 && node->inputIndex.size() != in_nums) || node->outputIndex.size() != out_nums) { + return false; + } + return std::all_of(node->inputIndex.begin(), node->inputIndex.end(), + [&graph](uint32_t tensor_index) { return graph->allTensors.size() > tensor_index; }) && + std::all_of(node->outputIndex.begin(), node->outputIndex.end(), + [&graph](uint32_t tensor_index) { return graph->allTensors.size() > tensor_index; }); +} + +template +bool CheckArithmetic(schema::MetaGraphT *graph, uint32_t node_index) { + if (!CheckCommon(graph, node_index, T, kInputSize1, 1)) { + return false; + } + const auto &node = graph->nodes[node_index]; + const auto &value = node->primitive->value; + const auto add_attr = static_cast(value.value); + if (add_attr->activation_type != schema::ActivationType_NO_ACTIVATION) { + return false; + } + auto tensor_indexes = node->inputIndex; + (void)tensor_indexes.insert(tensor_indexes.end(), 
node->outputIndex.begin(), node->outputIndex.end()); + std::vector shape; + for (size_t i = 0; i < tensor_indexes.size(); ++i) { + if (i == 0) { + shape = graph->allTensors[tensor_indexes[i]]->dims; + continue; + } + if (graph->allTensors[tensor_indexes[i]]->dims != shape) { + return false; + } + } + return true; +} + +template +bool CheckActivation(schema::MetaGraphT *graph, uint32_t node_index) { + if (!CheckCommon(graph, node_index, schema::PrimitiveType_Activation, 1, 1)) { + return false; + } + const auto &value = graph->nodes[node_index]->primitive->value; + const auto add_attr = static_cast(value.value); + if (add_attr->activation_type != T) { + return false; + } + return true; +} + +bool CheckBiasAdd(schema::MetaGraphT *graph, uint32_t node_index) { + if (!CheckCommon(graph, node_index, schema::PrimitiveType_AddFusion, kInputSize1, 1) && + !CheckCommon(graph, node_index, schema::PrimitiveType_BiasAdd, kInputSize1, 1)) { + return false; + } + const auto &node = graph->nodes[node_index]; + const auto &value = node->primitive->value; + if (value.type == schema::PrimitiveType_AddFusion) { + const auto add_attr = static_cast(value.value); + if (add_attr->activation_type != schema::ActivationType_NO_ACTIVATION) { + return false; + } + } + auto in_shape0 = graph->allTensors[node->inputIndex[0]]->dims; + auto in_shape1 = graph->allTensors[node->inputIndex[1]]->dims; + if (in_shape1.size() != 1 || in_shape0.empty() || in_shape0.back() != in_shape1.back()) { + return false; + } + return true; +} + +bool CheckMatmul(schema::MetaGraphT *graph, uint32_t node_index) { + if (!CheckCommon(graph, node_index, schema::PrimitiveType_MatMulFusion, kInputSize1, 1)) { + return false; + } + const auto &node = graph->nodes[node_index]; + const auto &value = node->primitive->value; + const auto matmul_attr = static_cast(value.value); + if (matmul_attr->activation_type != schema::ActivationType_NO_ACTIVATION) { + return false; + } + auto out_shape = graph->allTensors[node->outputIndex.front()]->dims; + return out_shape.size() == kInputSize1; +} + +bool CheckSplit(schema::MetaGraphT *graph, uint32_t node_index) { + if (!CheckCommon(graph, node_index, schema::PrimitiveType_Split, 1, kSplitOutSize)) { + return false; + } + const auto &node = graph->nodes[node_index]; + if (node->inputIndex.size() != 1 || node->outputIndex.size() != kSplitOutSize) { + return false; + } + auto in_shape0 = graph->allTensors[node->inputIndex[0]]->dims; + auto out_shape0 = graph->allTensors[node->outputIndex[0]]->dims; + auto out_shape1 = graph->allTensors[node->outputIndex[1]]->dims; + auto out_shape2 = graph->allTensors[node->outputIndex[kInputSize1]]->dims; + if (out_shape0 != out_shape1 || out_shape0 != out_shape2) { + return false; + } + if (in_shape0.empty() || out_shape0.empty()) { + return false; + } + if (in_shape0.back() != (out_shape0.back() + out_shape1.back() + out_shape2.back())) { + return false; + } + return true; +} + +bool CheckStack(schema::MetaGraphT *graph, uint32_t node_index) { + if (!CheckCommon(graph, node_index, schema::PrimitiveType_Stack, 0, 1)) { + return false; + } + const auto &node = graph->nodes[node_index]; + const auto &value = node->primitive->value; + const auto stack_attr = static_cast(value.value); + auto out_shape = graph->allTensors[node->outputIndex.front()]->dims; + if (out_shape.empty()) { + return false; + } + auto axis = stack_attr->axis; + if (axis < 0) { + axis += static_cast(out_shape.size()); + } + return axis == 0; +} + +bool CheckSqueeze(schema::MetaGraphT *graph, uint32_t 
node_index) { + if (!CheckCommon(graph, node_index, schema::PrimitiveType_Squeeze, 0, 1)) { + return false; + } + const auto &node = graph->nodes[node_index]; + if (node->inputIndex.size() != 1 && node->inputIndex.size() != kInputSize1) { + return false; + } + int axis = 0; + if (node->inputIndex.size() == kInputSize1) { + const auto &data = graph->allTensors[node->inputIndex[1]]->data; + if (data.size() != sizeof(int)) { + return false; + } + axis = *(reinterpret_cast(data.data())); + } else { + const auto &value = node->primitive->value; + const auto squeeze_attr = static_cast(value.value); + if (squeeze_attr->axis.size() != 1) { + return false; + } + axis = squeeze_attr->axis.front(); + } + auto in_shape = graph->allTensors[node->inputIndex[0]]->dims; + if (in_shape.empty()) { + return false; + } + if (axis < 0) { + axis += static_cast(in_shape.size()); + } + return axis == 0; +} + +std::vector GetStridedSlicePoints(const schema::TensorT *tensor, int64_t mask) { + if (tensor->data.empty()) { + return {}; + } + auto origin_data = reinterpret_cast(tensor->data.data()); + size_t num = tensor->data.size() / sizeof(int); + std::vector data; + for (size_t i = 0; i < num; ++i) { + bool ineffective = (mask & (1 << i)); + int cur_point = ineffective ? 0 : origin_data[i]; + data.push_back(cur_point); + } + return data; +} + +bool CheckStridedSlice(schema::MetaGraphT *graph, uint32_t node_index, int batch_position) { + if (!CheckCommon(graph, node_index, schema::PrimitiveType_StridedSlice, C4NUM, 1)) { + return false; + } + const auto &node = graph->nodes[node_index]; + const auto &step_tensor = graph->allTensors[node->inputIndex.back()]; + if (!step_tensor->data.empty()) { + const auto data = reinterpret_cast(step_tensor->data.data()); + auto size = step_tensor->data.size() / sizeof(int); + if (std::any_of(data, data + size, [](int val) { return val != 1; })) { + return false; + } + } + auto in_shape = graph->allTensors[node->inputIndex.front()]->dims; + auto out_shape = graph->allTensors[node->outputIndex.back()]->dims; + if (in_shape.size() != out_shape.size() || in_shape.empty()) { + return false; + } + for (size_t i = 1; i < in_shape.size(); ++i) { + if (in_shape[i] != out_shape[i]) { + return false; + } + } + const auto &value = node->primitive->value; + const auto strided_slice_attr = static_cast(value.value); + if (strided_slice_attr->ellipsis_mask != 0 || strided_slice_attr->new_axis_mask != 0 || + strided_slice_attr->shrink_axis_mask != 0) { + return false; + } + auto begin = GetStridedSlicePoints(graph->allTensors[node->inputIndex[1]].get(), strided_slice_attr->begin_mask); + if (begin.empty()) { + return false; + } + return begin.front() == batch_position; +} + +bool CheckGruCell(schema::MetaGraphT *graph, uint32_t node_index) { + if (!CheckCommon(graph, node_index, schema::PrimitiveType_Custom, C6NUM, 1)) { + return false; + } + const auto &node = graph->nodes[node_index]; + const auto &value = node->primitive->value; + const auto gru_attr = static_cast(value.value); + return gru_attr->type == kCustomGRU; +} + +std::unique_ptr CreateCustom() { + auto ConvertToAttr = [](const std::string &key, const std::vector &value) { + auto attr = std::make_unique(); + attr->name = key; + attr->data = value; + return attr; + }; + auto attrs = std::make_unique(); + MS_CHECK_TRUE_MSG(attrs != nullptr, nullptr, "Create CustomT failed."); + attrs->type = kCustomGRU; + std::vector transpose_a{false}; + std::vector transpose_b{true}; + std::vector built_in{true}; + + 
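// Attribute payloads are raw byte vectors. transpose_b=true records that the
+  // weight matrices stay row-major as [3 * hidden, depth], and builtin=true
+  // marks the node as generated by this fusion pass rather than a user op.
+  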
attrs->attr.push_back(ConvertToAttr("transpose_a", transpose_a)); + attrs->attr.push_back(ConvertToAttr("transpose_b", transpose_b)); + attrs->attr.push_back(ConvertToAttr("builtin", built_in)); + return attrs; +} + +struct InNodeInfo { + int node_index; + std::vector in_indexes; +}; + +struct OutNodeInfo { + int node_index; + uint32_t out_index; +}; + +struct camp { + bool operator()(uint32_t left, uint32_t right) const { return left > right; } +}; +} // namespace + +class LinkInfoManager { + public: + explicit LinkInfoManager(schema::MetaGraphT *graph) : graph_{graph} { + auto &all_nodes = graph->nodes; + for (int node_index = 0; node_index < static_cast(all_nodes.size()); ++node_index) { + auto in_indexes = all_nodes[node_index]->inputIndex; + for (uint32_t index = 0; index < static_cast(in_indexes.size()); ++index) { + if (link_info_manager_.find(in_indexes[index]) == link_info_manager_.end()) { + link_info_manager_[in_indexes[index]] = std::make_pair(std::vector{}, OutNodeInfo{-1, 0}); + } + auto &in_infos = link_info_manager_[in_indexes[index]].first; + auto iter = in_infos.begin(); + for (; iter != in_infos.end(); ++iter) { + if (iter->node_index == node_index) { + break; + } + } + if (iter != in_infos.end()) { + iter->in_indexes.push_back(index); + } else { + in_infos.push_back({node_index, {index}}); + } + } + + auto out_indexes = all_nodes[node_index]->outputIndex; + for (uint32_t index = 0; index < static_cast(out_indexes.size()); ++index) { + link_info_manager_[out_indexes[index]].second = OutNodeInfo{node_index, index}; + } + } + } + + const auto &GetLinkInfos() const { return link_info_manager_; } + + void Replace(uint32_t node_index, std::unique_ptr node) { graph_->nodes[node_index].swap(node); } + + void AddDeleteNodes(const std::set &node_indexes) { + delete_nodes_.insert(node_indexes.begin(), node_indexes.end()); + } + + void UpdateMetaGraph() { + auto &main_graph = graph_->subGraph.front(); + for (auto node_index : delete_nodes_) { + graph_->nodes.erase(graph_->nodes.begin() + node_index); + } + main_graph->nodeIndices.clear(); + for (uint32_t index = 0; index < static_cast(graph_->nodes.size()); ++index) { + main_graph->nodeIndices.push_back(index); + } + std::map tensor_maps; + BuildTensorMap(&tensor_maps); + auto UpdateTensorIndex = [&tensor_maps](std::vector *origin) { + auto origin_indexes = *origin; + origin->clear(); + (void)std::transform(origin_indexes.begin(), origin_indexes.end(), std::back_inserter(*origin), + [&tensor_maps](uint32_t origin_index) { return tensor_maps[origin_index]; }); + }; + UpdateTensorIndex(&graph_->inputIndex); + for (auto &node : graph_->nodes) { + UpdateTensorIndex(&node->inputIndex); + UpdateTensorIndex(&node->outputIndex); + } + UpdateTensorIndex(&graph_->outputIndex); + main_graph->inputIndices = graph_->inputIndex; + main_graph->outputIndices = graph_->outputIndex; + main_graph->tensorIndices.clear(); + for (uint32_t index = 0; index < static_cast(tensor_maps.size()); ++index) { + main_graph->tensorIndices.push_back(index); + } + std::vector> tensors; + graph_->allTensors.swap(tensors); + graph_->allTensors.resize(tensor_maps.size()); + for (auto &tensor_map : tensor_maps) { + graph_->allTensors[tensor_map.second].swap(tensors[tensor_map.first]); + } + } + + private: + void BuildTensorMap(std::map *tensor_maps) { + uint32_t new_index = 0; + auto InsertElements = [tensor_maps, &new_index](const std::vector &indexes) mutable { + for (auto index : indexes) { + if (tensor_maps->find(index) != tensor_maps->end()) { + continue; + } + 
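// First occurrence wins: tensors are renumbered densely in the order
+        // graph inputs -> node inputs/outputs -> graph outputs.
+        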
(*tensor_maps)[index] = new_index++; + } + }; + InsertElements(graph_->inputIndex); + for (auto &node : graph_->nodes) { + InsertElements(node->inputIndex); + InsertElements(node->outputIndex); + } + InsertElements(graph_->outputIndex); + } + + schema::MetaGraphT *graph_{nullptr}; + std::set delete_nodes_; + // tensor_index, + std::map, OutNodeInfo>> link_info_manager_; +}; + +class GruCellFusion { + public: + GruCellFusion() = default; + ~GruCellFusion() = default; + STATUS Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(graph->subGraph.size() == 1); + link_info_manager_ = std::make_shared(graph); + graph_ = graph; + DefinePattern(); + for (uint32_t node_index = 0; node_index < static_cast(graph->nodes.size()); ++node_index) { + if (!MatchPattern(node_index)) { + continue; + } + if (CreateCustomGruCell() != RET_OK) { + MS_LOG(ERROR) << "Create Custom-Gru failed."; + return RET_ERROR; + } + } + link_info_manager_->UpdateMetaGraph(); + return RET_OK; + } + + private: + struct NodeInfo { + struct InTensorInfo { + bool is_const{false}; + uint32_t node_index_{0}; + uint32_t tensor_index_{0}; + }; + struct OutTensorInfo { + uint32_t node_index_{0}; + uint32_t tensor_index_{0}; + }; + bool (*checker)(schema::MetaGraphT *graph, uint32_t node_index); + std::vector in_infos; + std::vector out_infos; + }; + + void DefinePattern() { + int match_order = 0; + pattern_[{match_order++, kAdd0}] = { + CheckArithmetic, {{false, kTanh, 0}, {false, kMul0, 0}}, {}}; + pattern_[{match_order++, kTanh}] = { + CheckActivation, {{false, kAdd1, 0}}, {{kSub, 1}, {kAdd0, 0}}}; + pattern_[{match_order++, kMul0}] = {CheckArithmetic, + {{false, kSigmoid0, 0}, {false, kSub, 0}}, + {{kAdd0, 1}}}; + pattern_[{match_order++, kAdd1}] = {CheckArithmetic, + {{false, kSplit0, 2}, {false, kMul1, 0}}, + {{kTanh, 0}}}; + pattern_[{match_order++, kSub}] = {CheckArithmetic, + {{false, kInputH, 0}, {false, kTanh, 0}}, + {{kMul0, 1}}}; + pattern_[{match_order++, kSigmoid0}] = { + CheckActivation, {{false, kAdd2, 0}}, {{kMul0, 0}}}; + pattern_[{match_order++, kSplit0}] = {CheckSplit, {{false, kAdd3, 0}}, {{kAdd4, 0}, {kAdd2, 0}, {kAdd1, 0}}}; + pattern_[{match_order++, kMul1}] = {CheckArithmetic, + {{false, kSigmoid1, 0}, {false, kSplit1, 2}}, + {{kAdd1, 1}}}; + pattern_[{match_order++, kAdd2}] = {CheckArithmetic, + {{false, kSplit0, 1}, {false, kSplit1, 1}}, + {{kSigmoid0, 0}}}; + pattern_[{match_order++, kSigmoid1}] = { + CheckActivation, {{false, kAdd4, 0}}, {{kMul1, 0}}}; + pattern_[{match_order++, kAdd3}] = {CheckBiasAdd, {{false, kMatmul0, 0}, {true}}, {{kSplit0, 0}}}; + pattern_[{match_order++, kSplit1}] = {CheckSplit, {{false, kAdd5, 0}}, {{kAdd4, 1}, {kAdd2, 1}, {kMul1, 1}}}; + pattern_[{match_order++, kAdd4}] = {CheckArithmetic, + {{false, kSplit0, 0}, {false, kSplit1, 0}}, + {{kSigmoid1, 0}}}; + pattern_[{match_order++, kAdd5}] = {CheckBiasAdd, {{false, kMatmul1, 0}, {true}}, {{kSplit1, 0}}}; + pattern_[{match_order++, kMatmul0}] = {CheckMatmul, {{false, kInputI, 0}, {true}}, {{kAdd3, 0}}}; + pattern_[{match_order++, kMatmul1}] = {CheckMatmul, {{false, kInputH, 0}, {true}}, {{kAdd5, 0}}}; + } + + bool FillRealPattern(uint32_t node_index, std::map *real_pattern) { + const auto &link_infos = link_info_manager_->GetLinkInfos(); + if (real_pattern->find(node_index) != real_pattern->end()) { + return false; + } + real_pattern->insert({node_index, {nullptr}}); + auto in_tensor_indexes = graph_->nodes[node_index]->inputIndex; + for (auto tensor_index : in_tensor_indexes) { + if 
(link_infos.find(tensor_index) == link_infos.end()) { + return false; + } + const auto &tensor_out_info = link_infos.at(tensor_index).second; + if (tensor_out_info.node_index < 0) { + real_pattern->at(node_index).in_infos.push_back({true}); + } else { + real_pattern->at(node_index) + .in_infos.push_back({false, static_cast(tensor_out_info.node_index), tensor_out_info.out_index}); + } + } + auto out_tensor_indexes = graph_->nodes[node_index]->outputIndex; + for (auto tensor_index : out_tensor_indexes) { + if (link_infos.find(tensor_index) == link_infos.end()) { + return false; + } + const auto &in_tensor_out_info = link_infos.at(tensor_index).first; + for (const auto &in_node_info : in_tensor_out_info) { + for (auto index : in_node_info.in_indexes) { + real_pattern->at(node_index).out_infos.push_back({static_cast(in_node_info.node_index), index}); + } + } + } + return true; + } + + bool CheckPattern(const std::map &real_pattern, + const std::pair &pattern_node_index) { + const auto &real_in_infos = real_pattern.at(real_node_map_.at(pattern_node_index.second)).in_infos; + const auto &virtual_in_infos = pattern_.at(pattern_node_index).in_infos; + if (real_in_infos.size() != virtual_in_infos.size()) { + return false; + } + for (size_t i = 0; i < virtual_in_infos.size(); ++i) { + if (virtual_in_infos[i].is_const) { + if (!real_in_infos[i].is_const) { + return false; + } + continue; + } + if (virtual_in_infos[i].tensor_index_ != real_in_infos[i].tensor_index_) { + return false; + } + if (real_node_map_.find(virtual_in_infos[i].node_index_) == real_node_map_.end()) { + real_node_map_.insert({virtual_in_infos[i].node_index_, real_in_infos[i].node_index_}); + } else if (real_node_map_.at(virtual_in_infos[i].node_index_) != real_in_infos[i].node_index_) { + return false; + } + } + const auto &real_out_infos = real_pattern.at(real_node_map_.at(pattern_node_index.second)).out_infos; + const auto &virtual_out_infos = pattern_.at(pattern_node_index).out_infos; + if (virtual_out_infos.empty()) { + return true; + } + if (real_out_infos.size() != virtual_out_infos.size()) { + return false; + } + for (size_t i = 0; i < virtual_out_infos.size(); ++i) { + if (virtual_out_infos[i].tensor_index_ != real_out_infos[i].tensor_index_) { + return false; + } + if (real_node_map_.find(virtual_out_infos[i].node_index_) == real_node_map_.end()) { + real_node_map_.insert({virtual_out_infos[i].node_index_, real_out_infos[i].node_index_}); + } else if (real_node_map_.at(virtual_out_infos[i].node_index_) != real_out_infos[i].node_index_) { + return false; + } + } + return true; + } + + bool CheckClosure(const std::map &node_map) { + std::set real_nodes; + (void)std::for_each(node_map.begin(), node_map.end(), + [&real_nodes](std::pair pair) { real_nodes.insert(pair.second); }); + if (real_nodes.size() != node_map.size()) { + return false; + } + const auto &link_infos = link_info_manager_->GetLinkInfos(); + for (uint32_t start = kAdd1; start <= kMatmul1; ++start) { + if (node_map.find(start) == node_map.end()) { + return false; + } + const auto &node = graph_->nodes[node_map.at(start)]; + auto out_tensor_indexes = node->outputIndex; + for (auto out_index : out_tensor_indexes) { + if (link_infos.find(out_index) == link_infos.end()) { + return false; + } + for (const auto &in_node_info : link_infos.at(out_index).first) { + if (real_nodes.find(in_node_info.node_index) == real_nodes.end()) { + return false; + } + } + } + } + return true; + } + + bool MatchPattern(uint32_t add_index) { + real_node_map_.clear(); + 
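// The final element-wise Add (kAdd0) anchors the match; the pattern is then
+    // walked in match order and every checker must pass while keeping the
+    // node mapping consistent across in/out links.
+    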
real_node_map_[kAdd0] = add_index; + std::map real_pattern; + for (const auto &pair : pattern_) { + if (real_node_map_.find(pair.first.second) == real_node_map_.end()) { + return false; + } + auto node_index = real_node_map_[pair.first.second]; + if (!pair.second.checker(graph_, node_index)) { + return false; + } + if (!FillRealPattern(node_index, &real_pattern)) { + return false; + } + if (!CheckPattern(real_pattern, pair.first)) { + return false; + } + } + auto weight_hidden_index = graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[1]; + auto weight_hidden_shape = graph_->allTensors[weight_hidden_index]->dims; + if (weight_hidden_shape.size() != C2NUM || weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) { + return false; + } + return CheckClosure(real_node_map_); + } + + STATUS CreateCustomGruCell() { + std::vector inputs; + inputs.push_back(graph_->nodes[real_node_map_[kMatmul0]]->inputIndex[0]); // x + inputs.push_back(graph_->nodes[real_node_map_[kMatmul0]]->inputIndex[1]); // weight_input + inputs.push_back(graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[1]); // weight_hidden + inputs.push_back(graph_->nodes[real_node_map_[kAdd3]]->inputIndex[1]); // bias_input + inputs.push_back(graph_->nodes[real_node_map_[kAdd5]]->inputIndex[1]); // bias_hidden + inputs.push_back(graph_->nodes[real_node_map_[kMatmul1]]->inputIndex[0]); // init_h + auto outputs = graph_->nodes[real_node_map_[kAdd0]]->outputIndex; + auto attrs = CreateCustom(); + MS_CHECK_TRUE_RET(attrs != nullptr, RET_NULL_PTR); + auto prim_t = std::make_unique(); + MS_CHECK_TRUE_MSG(prim_t != nullptr, RET_ERROR, "Create PrimitiveT failed."); + prim_t->value.type = schema::PrimitiveType_Custom; + prim_t->value.value = attrs.release(); + auto custom_gru = std::make_unique(); + MS_CHECK_TRUE_MSG(custom_gru != nullptr, RET_ERROR, "Create Custom-Gru failed."); + custom_gru->name = graph_->nodes[real_node_map_[kAdd0]]->name; + custom_gru->inputIndex = inputs; + custom_gru->outputIndex = outputs; + custom_gru->primitive = std::move(prim_t); + link_info_manager_->Replace(real_node_map_[kAdd0], std::move(custom_gru)); + std::set delete_nodes; + for (uint32_t i = kAdd1; i <= kMatmul1; ++i) { + delete_nodes.insert(real_node_map_[i]); + } + link_info_manager_->AddDeleteNodes(delete_nodes); + return RET_OK; + } + + std::map, NodeInfo> pattern_; + std::map real_node_map_; + schema::MetaGraphT *graph_{nullptr}; + std::shared_ptr link_info_manager_{nullptr}; +}; + +STATUS GruFusionPass::Run(schema::MetaGraphT *graph) { +#ifndef ENABLE_ARM64 + return RET_OK; +#endif + if (graph == nullptr) { + MS_LOG(ERROR) << "graph is a nullptr."; + return RET_NULL_PTR; + } + if (graph->subGraph.size() != 1) { + return RET_OK; + } + if (FuseToGruCell(graph) != RET_OK) { + return RET_ERROR; + } + return FuseGruCell(graph); +} + +STATUS GruFusionPass::FuseToGruCell(schema::MetaGraphT *graph) { + GruCellFusion gru_cell_fusion{}; + if (gru_cell_fusion.Run(graph) != RET_OK) { + MS_LOG(ERROR) << "Fuse GruCell failed."; + return RET_ERROR; + } + return RET_OK; +} + +STATUS GruFusionPass::FuseGruCell(schema::MetaGraphT *graph) { + link_info_manager_ = std::make_shared(graph); + for (uint32_t i = 0; i < static_cast(graph->nodes.size()); ++i) { + if (!CheckStack(graph, i)) { + continue; + } + std::vector strided_slices; + std::vector squeezes; + std::vector gru_cells; + if (!MatchPatten(graph, i, &strided_slices, &squeezes, &gru_cells)) { + continue; + } + if (CreateGru(graph, i, strided_slices, squeezes, gru_cells) != RET_OK) { + MS_LOG(ERROR) << "Fuse 
GruCell failed."; + return RET_ERROR; + } + } + link_info_manager_->UpdateMetaGraph(); + link_info_manager_ = nullptr; + return RET_OK; +} + +bool GruFusionPass::MatchPatten(schema::MetaGraphT *graph, uint32_t stack_index, std::vector *strided_slices, + std::vector *squeezes, std::vector *gru_cells) { + auto &link_infos = link_info_manager_->GetLinkInfos(); + auto &stack_node = graph->nodes[stack_index]; + int batch_point = 0; + auto CommonCheck = [&link_infos](uint32_t tensor_index) { + if (link_infos.find(tensor_index) == link_infos.end()) { + return std::make_pair(false, 0); + } + const auto &in_node_info = link_infos.at(tensor_index).first; + if (in_node_info.size() != 1 && in_node_info.front().in_indexes.size() != 1) { + return std::make_pair(false, 0); + } + auto node_index = link_infos.at(tensor_index).second.node_index; + if (node_index < 0) { + return std::make_pair(false, 0); + } + return std::make_pair(true, node_index); + }; + for (auto tensor_index : stack_node->inputIndex) { + auto check_info = CommonCheck(tensor_index); + if (!check_info.first || !CheckGruCell(graph, check_info.second)) { + return false; + } + gru_cells->push_back(check_info.second); + auto &gru_cell_node = graph->nodes[check_info.second]; + check_info = CommonCheck(gru_cell_node->inputIndex.front()); + if (!check_info.first || !CheckSqueeze(graph, check_info.second)) { + return false; + } + squeezes->push_back(check_info.second); + auto &squeeze_node = graph->nodes[check_info.second]; + check_info = CommonCheck(squeeze_node->inputIndex.front()); + if (!check_info.first || !CheckStridedSlice(graph, check_info.second, batch_point)) { + return false; + } + strided_slices->push_back(check_info.second); + ++batch_point; + } + if (strided_slices->empty()) { + return false; + } + uint32_t input_index = graph->nodes[strided_slices->front()]->inputIndex.front(); + if (std::any_of(strided_slices->begin(), strided_slices->end(), [input_index, graph](uint32_t strided_slice) { + return graph->nodes[strided_slice]->inputIndex.front() != input_index; + })) { + return false; + } + auto in_shape = graph->allTensors[input_index]->dims; + if (in_shape.empty() || in_shape.front() != batch_point) { + return false; + } + return CheckGruCellConnection(graph, *gru_cells); +} + +bool GruFusionPass::CheckGruCellConnection(schema::MetaGraphT *graph, const std::vector &gru_cells) { + auto &first_node = graph->nodes[gru_cells.front()]; + if (first_node->inputIndex.size() != C6NUM) { + return false; + } + auto init_h = first_node->outputIndex.front(); + for (size_t i = 1; i < gru_cells.size(); ++i) { + auto &node = graph->nodes[gru_cells[i]]; + if (node->inputIndex.size() != first_node->inputIndex.size()) { + return false; + } + for (size_t j = 1; j < C5NUM; ++j) { + if (node->inputIndex[j] != first_node->inputIndex[j]) { + return false; + } + } + if (node->inputIndex[C5NUM] != init_h) { + return false; + } + init_h = node->outputIndex.front(); + } + return true; +} + +STATUS GruFusionPass::CreateGru(schema::MetaGraphT *graph, uint32_t stack_index, + const std::vector &strided_slices, const std::vector &squeezes, + const std::vector &gru_cells) { + auto &gru_cell_node = graph->nodes[gru_cells.front()]; + gru_cell_node->inputIndex[0] = graph->nodes[strided_slices.front()]->inputIndex[0]; + gru_cell_node->outputIndex[0] = graph->nodes[stack_index]->outputIndex[0]; + std::set delete_node{stack_index}; + (void)delete_node.insert(strided_slices.begin(), strided_slices.end()); + (void)delete_node.insert(squeezes.begin(), squeezes.end()); + 
(void)delete_node.insert(gru_cells.begin() + 1, gru_cells.end());
+  link_info_manager_->AddDeleteNodes(delete_node);
+  return RET_OK;
+}
+} // namespace lite
+} // namespace mindspore
diff --git a/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h
new file mode 100644
index 00000000..5e2b705d
--- /dev/null
+++ b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_GRU_FUSION_PASS_H_
+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_GRU_FUSION_PASS_H_
+
+#include <memory>
+#include <vector>
+#include "tools/converter/optimizer.h"
+
+namespace mindspore {
+namespace lite {
+class LinkInfoManager;
+class GruFusionPass : public GraphPass {
+ public:
+  GruFusionPass() = default;
+  ~GruFusionPass() override = default;
+  STATUS Run(schema::MetaGraphT *graph) override;
+
+ private:
+  STATUS FuseToGruCell(schema::MetaGraphT *graph);
+  STATUS FuseGruCell(schema::MetaGraphT *graph);
+  bool MatchPatten(schema::MetaGraphT *graph, uint32_t stack_index, std::vector<uint32_t> *strided_slices,
+                   std::vector<uint32_t> *squeezes, std::vector<uint32_t> *gru_cells);
+  bool CheckGruCellConnection(schema::MetaGraphT *graph, const std::vector<uint32_t> &gru_cells);
+  STATUS CreateGru(schema::MetaGraphT *graph, uint32_t stack_index, const std::vector<uint32_t> &strided_slices,
+                   const std::vector<uint32_t> &squeezes, const std::vector<uint32_t> &gru_cells);
+  std::shared_ptr<LinkInfoManager> link_info_manager_{nullptr};
+};
+} // namespace lite
+} // namespace mindspore
+#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_GRU_FUSION_PASS_H_
diff --git a/mindspore/lite/src/train/static_allocator.h b/mindspore/lite/src/train/static_allocator.h
index d78e13ba..bd80651d 100644
--- a/mindspore/lite/src/train/static_allocator.h
+++ b/mindspore/lite/src/train/static_allocator.h
@@ -40,12 +40,12 @@ class StaticAllocator : public Allocator {
     if (ptr == nullptr) return STATIC_ALLOCATION;
     char *ptrc = reinterpret_cast<char *>(ptr);
     char *bufc = reinterpret_cast<char *>(start_buf_);
-    return ((ptrc < bufc) || (ptrc - bufc >= static_cast<ptrdiff_t>(size_)) ? 1 : 0);
+    return ((ptrc < bufc) || (ptrc >= bufc + size_)) ? 1 : 0;
   }
 
  private:
-  void *start_buf_;
-  size_t size_;
+  void *start_buf_{nullptr};
+  size_t size_{0};
   size_t total_size_ = 0;
 };
 }; // namespace mindspore
diff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc
index 7e504c4e..008de7c5 100644
--- a/mindspore/lite/src/train/train_export.cc
+++ b/mindspore/lite/src/train/train_export.cc
@@ -30,6 +30,10 @@
 #include "src/train/graph_fusion.h"
 #include "src/train/graph_dropout.h"
 #include "src/runtime/weight_decoder.h"
+#include "src/runtime/kernel/cpu/fp16/fp16_op_handler.h"
+#ifndef ENABLE_ARM
+#include "base/float16.h"
+#endif
 
 namespace mindspore {
 namespace lite {
@@ -645,6 +649,40 @@ int TrainExport::SaveToBuffer() {
   return RET_OK;
 }
 
+int TrainExport::SaveWeightsToFile(bool enable_fp16, const std::vector<std::string> &changeable_weights_name) {
+  const auto &all_tensors = meta_graph_->allTensors;
+  std::ofstream weights(file_name_, std::ios::out | std::ios::trunc | std::ios::binary);
+  if (!weights.is_open()) {
+    MS_LOG(ERROR) << "Open file failed: " << file_name_;
+    return RET_ERROR;
+  }
+  for (auto &tensor : all_tensors) {
+    MS_CHECK_TRUE_MSG(tensor != nullptr, RET_NULL_PTR, "Exist tensor is a nullptr.");
+    if (tensor->data.empty()) {
+      continue;
+    }
+    // a changeable weight is prefixed with its shape so the runtime can resize it
+    if (std::find(changeable_weights_name.begin(), changeable_weights_name.end(), tensor->name) !=
+        changeable_weights_name.end()) {
+      auto shape = tensor->dims;
+      weights.write(reinterpret_cast<const char *>(shape.data()), shape.size() * sizeof(uint32_t));
+    }
+    if (!enable_fp16 || tensor->dataType != kNumberTypeFloat32) {
+      weights.write(reinterpret_cast<const char *>(tensor->data.data()), tensor->data.size());
+    } else {
+      std::vector<uint16_t> data_fp16(tensor->data.size() / sizeof(float));
+#ifndef ENABLE_ARM
+      auto fp32_data = reinterpret_cast<const float *>(tensor->data.data());
+      auto fp16_data = reinterpret_cast<float16 *>(data_fp16.data());
+      CHECK_NULL_RETURN(fp32_data);
+      CHECK_NULL_RETURN(fp16_data);
+      for (size_t j = 0; j < data_fp16.size(); ++j) {
+        fp16_data[j] = float16(fp32_data[j]);
+      }
+#else
+      Float32ToFloat16_fp16_handler(tensor->data.data(), data_fp16.data(), data_fp16.size(), true);
+#endif
+      weights.write(reinterpret_cast<const char *>(data_fp16.data()), data_fp16.size() * sizeof(uint16_t));
+    }
+  }
+  weights.close();
+  return RET_OK;
+}
+
 bool TrainExport::IsInputTensor(const schema::TensorT &t) {
   int total_dims = std::accumulate(t.dims.begin(), t.dims.end(), 1, std::multiplies<int>());
diff --git a/mindspore/lite/src/train/train_export.h b/mindspore/lite/src/train/train_export.h
index 8e802021..d6f81187 100644
--- a/mindspore/lite/src/train/train_export.h
+++ b/mindspore/lite/src/train/train_export.h
@@ -52,6 +52,7 @@ class TrainExport {
   int ExportInit(const std::string model_name, std::string version);
   int SaveToFile();
   int SaveToBuffer();
+  int SaveWeightsToFile(bool enable_fp16 = false, const std::vector<std::string> &changeable_weights_name = {});
   void set_connect(const std::unordered_map &map) { connect_ = map; }
   int LoadModel(void *buf, size_t buf_size);
   int AddTransformNode();
diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc
index b40ff8c2..2f9aa99b 100644
--- a/mindspore/lite/src/train/train_session.cc
+++ b/mindspore/lite/src/train/train_session.cc
@@ -1233,10 +1233,45 @@ int TrainSession::Export(Buffer *model_buffer, ModelType model_type, Quantizatio
   return ExportInner(model_buffer, model_type, quant_type, format, out_put_tensor_name);
 }
 
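+// Exports only the weight data of the fused inference graph, for use with
+// source code produced by micro codegen (model_type must be MT_INFERENCE and
+// the format FT_FLATBUFFERS). Illustrative call, the tensor name is a placeholder:
+//   session->ExportWeightsCollaborateWithMicro("net.bin", mindspore::lite::MT_INFERENCE,
+//                                              mindspore::lite::FT_FLATBUFFERS, false, {"embedding_table"});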
RET_ERROR, "File name cannot be empty"); + struct stat path_type; + if (stat(file_name.c_str(), &path_type) == RET_OK) { + if (path_type.st_mode & S_IFDIR) { + MS_LOG(ERROR) << "Destination must be path, now is a directory"; + return RET_ERROR; + } + } + MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "Format must be `FT_FLATBUFFERS`"); + MS_CHECK_FALSE_MSG(model_type != mindspore::lite::MT_INFERENCE, RET_ERROR, + "Currently, can only export inference-model's weights."); + int status = Eval(); + TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Eval failed"); + + TrainExport texport(file_name); + status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_); + TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export"); + + status = texport.ExportNet(inference_kernels_, tensors_, eval_output_tensor_names_, model_.get(), QT_DEFAULT); + TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network."); + status = texport.TrainModelDrop(); + TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed."); + status = texport.TrainModelFusion(); + TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelFusion failed."); + status = texport.SaveWeightsToFile(enable_fp16, changeable_weights_name); + if (status != RET_OK) { + MS_LOG(ERROR) << "Failed to save to " << file_name; + return status; + } + return RET_OK; +} + std::vector TrainSession::GetFeatureMaps() const { std::vector features; for (auto cur_tensor : this->tensors_) { - if (cur_tensor->category() ==lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) { + if (cur_tensor->category() == lite::Category::CONST_TENSOR && cur_tensor->data_type() == kNumberTypeFloat32) { features.push_back(cur_tensor); } } diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index 5acff82a..edcab32d 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -106,6 +106,9 @@ class TrainSession : virtual public lite::LiteSession { std::vector out_put_tensor_name = {}) override; int Export(Buffer *model_buffer, ModelType model_type, QuantizationType quant_type, FormatType, std::vector out_put_tensor_name = {}) override; + int ExportWeightsCollaborateWithMicro(const std::string &file_name, lite::ModelType model_type, FormatType, + bool enable_fp16, + const std::vector &changeable_weights_name) override; std::vector GetFeatureMaps() const override; diff --git a/mindspore/lite/test/config_level0/micro/micro_arm64.cfg b/mindspore/lite/test/config_level0/micro/micro_arm64.cfg index 0375ebf7..765549ab 100644 --- a/mindspore/lite/test/config_level0/micro/micro_arm64.cfg +++ b/mindspore/lite/test/config_level0/micro/micro_arm64.cfg @@ -25,3 +25,10 @@ support_parallel=false # enable debug debug_mode=false + +# false indicates that only the required weights are saved. If collaborate with lite-train, the parameter must be true. +keep_original_weight=false + +# the names of those weight-tensors whose shape is changeable, only embedding-table supports change. +# the parameter is used to collaborate with lite-train. If set, `keep_original_weight` must be true. 
+#changeable_weights_name=name0,name1
diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
index ea263e64..4c6bd237 100644
--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
@@ -221,11 +221,14 @@ int ConfigFileParser::ParseAclOptionCfgString(const std::map<std::string, std::
 int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
   if (maps.find(kMicroParam) != maps.end()) {
     const auto &map = maps.at(kMicroParam);
-    std::map<std::string, std::string &> parse_map{{"target", micro_param_string_.target},
-                                                   {"codegen_mode", micro_param_string_.codegen_mode},
-                                                   {"debug_mode", micro_param_string_.debug_mode},
-                                                   {"support_parallel", micro_param_string_.support_parallel},
-                                                   {"enable_micro", micro_param_string_.enable_micro}};
+    std::map<std::string, std::string &> parse_map{
+      {"target", micro_param_string_.target},
+      {"codegen_mode", micro_param_string_.codegen_mode},
+      {"debug_mode", micro_param_string_.debug_mode},
+      {"support_parallel", micro_param_string_.support_parallel},
+      {"enable_micro", micro_param_string_.enable_micro},
+      {"keep_original_weight", micro_param_string_.keep_original_weight},
+      {"changeable_weights_name", micro_param_string_.changeable_weights_name}};
     return SetMapData(map, parse_map, kMicroParam);
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
index 0ada406e..8854e5f7 100644
--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h
+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
@@ -86,6 +86,8 @@ struct MicroParamString {
   std::string support_parallel;
   std::string debug_mode;
   std::string enable_micro;
+  std::string keep_original_weight;
+  std::string changeable_weights_name;
 };
 
 struct ThirdPartyModelString {
diff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc
index 310b2398..559bee8b 100644
--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc
+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc
@@ -61,6 +61,23 @@ STATUS MicroParamParser::ParseEnableMicro(const std::string &enable_micro, micro
   return RET_OK;
 }
 
+STATUS MicroParamParser::ParseKeepOriginalWeight(const std::string &keep_original_weight,
+                                                 micro::MicroParam *micro_param) {
+  MS_LOG(DEBUG) << "Micro keep_original_weight: " << keep_original_weight;
+  micro_param->keep_original_weight = false;  // default
+  bool is_keep_original_weight = false;
+  if (ConvertBool(keep_original_weight, &is_keep_original_weight)) {
+    micro_param->keep_original_weight = is_keep_original_weight;
+  }
+  return RET_OK;
+}
+
+STATUS MicroParamParser::ParseChangeableWeightsName(const std::string &changeable_weights_name,
+                                                    micro::MicroParam *micro_param) {
+  MS_LOG(DEBUG) << "Micro record changeable weights name: " << changeable_weights_name;
+  micro_param->changeable_weights_name = changeable_weights_name;
+  return RET_OK;
+}
+
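+// NOTE: ParseMicroParam handles keep_original_weight before
+// changeable_weights_name, because the latter is only legal when the former
+// is true.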
 STATUS MicroParamParser::ParseMicroParam(const MicroParamString &micro_param_string, micro::MicroParam *micro_param) {
   CHECK_NULL_RETURN(micro_param);
   if (!micro_param_string.target.empty()) {
@@ -93,6 +110,22 @@ STATUS MicroParamParser::ParseMicroParam(const MicroParamString &micro_param_str
       return RET_INPUT_PARAM_INVALID;
     }
   }
+  if (!micro_param_string.keep_original_weight.empty()) {
+    if (ParseKeepOriginalWeight(micro_param_string.keep_original_weight, micro_param) != RET_OK) {
+      MS_LOG(ERROR) << "Parse keep_original_weight val: " << micro_param_string.keep_original_weight;
+      return RET_INPUT_PARAM_INVALID;
+    }
+  }
+  if (!micro_param_string.changeable_weights_name.empty()) {
+    if (!micro_param->keep_original_weight) {
+      MS_LOG(ERROR) << "When changeable_weights_name is set, the keep_original_weight must be true.";
+      return RET_INPUT_PARAM_INVALID;
+    }
+    if (ParseChangeableWeightsName(micro_param_string.changeable_weights_name, micro_param) != RET_OK) {
+      MS_LOG(ERROR) << "Parse changeable_weights_name val: " << micro_param_string.changeable_weights_name;
+      return RET_INPUT_PARAM_INVALID;
+    }
+  }
   return RET_OK;
 }
 } // namespace lite
diff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.h b/mindspore/lite/tools/converter/config_parser/micro_param_parser.h
index 860182af..93a30b39 100644
--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.h
+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.h
@@ -33,6 +33,8 @@ class MicroParamParser {
   STATUS ParseCodeGenMode(const std::string &codegen_mode, micro::MicroParam *micro_param);
   STATUS ParseSupportParallel(const std::string &support_parallel, micro::MicroParam *micro_param);
   STATUS ParseDebugMode(const std::string &debug_mode, micro::MicroParam *micro_param);
+  STATUS ParseKeepOriginalWeight(const std::string &keep_original_weight, micro::MicroParam *micro_param);
+  STATUS ParseChangeableWeightsName(const std::string &changeable_weights_name, micro::MicroParam *micro_param);
 };
 } // namespace lite
 } // namespace mindspore
diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
index 944ed29c..6177d379 100644
--- a/mindspore/lite/tools/converter/converter.cc
+++ b/mindspore/lite/tools/converter/converter.cc
@@ -186,6 +186,9 @@ schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara>
     }
   }
+  if (param->fmk_type == FmkType::kFmkTypeMsLite) {
+    return nullptr;
+  }
   auto graph = BuildFuncGraph(param);
   if (graph == nullptr) {
     MS_LOG(ERROR) << "Parser/Import model return nullptr";
@@ -557,7 +560,7 @@ int CheckFmkType(const std::shared_ptr<ConverterPara> &param) {
   }
   const std::set<FmkType> kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx,
                                             FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch,
-                                            FmkType::kFmkTypeThirdParty};
+                                            FmkType::kFmkTypeThirdParty, FmkType::kFmkTypeMsLite};
   if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) {
     MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|THIRDPARTY"
                   << ", but got " << param->fmk_type;
@@ -780,6 +783,14 @@ int RunConverter(const std::shared_ptr<ConverterPara> &param, void **model_data,
   NotSupportOp::GetInstance()->PrintOps();
   status = ReturnCode::GetSingleReturnCode()->status_code();
   if (meta_graph == nullptr) {
+    if (param->fmk_type == FmkType::kFmkTypeMsLite && param->microParam.enable_micro) {
+      status = micro::Coder::MicroSourceCodeGeneration(param->model_file, param->output_file, param->microParam,
+                                                       param->weight_fp16);
+      if (status != RET_OK) {
+        CONVERTER_LOG_ERROR("MICRO CODEGEN FAILED:" << status << " " << GetErrorInfo(status));
+      }
+      return status;
+    }
     CONVERTER_LOG_ERROR("CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status));
     status = RET_ERROR;
     return status;
   }
 
   if (param->microParam.enable_micro) {
-    status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, param->output_file, param->microParam.codegen_mode,
-                                                     param->microParam.target,
param->microParam.support_parallel, - param->microParam.debug_mode, param->weight_fp16); + status = + micro::Coder::MicroSourceCodeGeneration(*meta_graph, param->output_file, param->microParam, param->weight_fp16); if (status != RET_OK) { delete meta_graph; CONVERTER_LOG_ERROR("MICRO CODEGEN FAILED:" << status << " " << GetErrorInfo(status)); diff --git a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc index 595b59ed..e30994cc 100644 --- a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc @@ -29,7 +29,7 @@ using mindspore::lite::RET_INPUT_PARAM_INVALID; using mindspore::lite::RET_OK; Flags::Flags() { - AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", ""); + AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX | MSLITE", ""); AddFlag(&Flags::modelFile, "modelFile", "Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", ""); AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", ""); @@ -120,7 +120,7 @@ int Flags::InitFmk() { // value check not here, it is in converter c++ API's CheckValueParam method. std::map StrToEnumFmkTypeMap = { {"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs}, {"TFLITE", kFmkTypeTflite}, {"ONNX", kFmkTypeOnnx}, - {"TF", kFmkTypeTf}, {"PYTORCH", kFmkTypePytorch}, {"THIRDPARTY", kFmkTypeThirdParty}}; + {"TF", kFmkTypeTf}, {"PYTORCH", kFmkTypePytorch}, {"THIRDPARTY", kFmkTypeThirdParty}, {"MSLITE", kFmkTypeMsLite}}; if (StrToEnumFmkTypeMap.find(this->fmkIn) != StrToEnumFmkTypeMap.end()) { this->fmk = StrToEnumFmkTypeMap.at(this->fmkIn); } else { diff --git a/mindspore/lite/tools/converter/micro/cmake/file_list.cmake b/mindspore/lite/tools/converter/micro/cmake/file_list.cmake index 9ae54538..589ee81a 100644 --- a/mindspore/lite/tools/converter/micro/cmake/file_list.cmake +++ b/mindspore/lite/tools/converter/micro/cmake/file_list.cmake @@ -53,6 +53,7 @@ set(CODER_OPCODERS_SRC ${MICRO_DIR}/coder/opcoders/base/reduce_base_coder.cc ${MICRO_DIR}/coder/opcoders/base/resize_base_coder.cc ${MICRO_DIR}/coder/opcoders/base/reshape_base_coder.cc + ${MICRO_DIR}/coder/opcoders/base/stack_base_coder.cc ${MICRO_DIR}/coder/opcoders/base/softmax_base_coder.cc ${MICRO_DIR}/coder/opcoders/base/detection_post_process_base_coder.cc ${MICRO_DIR}/coder/opcoders/base/strided_slice_base_coder.cc @@ -71,6 +72,7 @@ set(CODER_OPCODERS_SRC ${MICRO_DIR}/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp16/avg_pooling_fp16_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc + ${MICRO_DIR}/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc @@ -90,6 +92,7 @@ set(CODER_OPCODERS_SRC ${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_fp32_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc + ${MICRO_DIR}/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp32/full_connection_fp32_coder.cc ${MICRO_DIR}/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc 
${MICRO_DIR}/coder/opcoders/nnacl/fp32/groupnorm_fp32_coder.cc diff --git a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc index 9c5839b4..be314ed6 100644 --- a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc +++ b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.cc @@ -79,18 +79,26 @@ void MemoryAllocator::Free() { iter++; } } + for (auto &item : auxiliary_weights_) { + delete item.second.first; + } malloc_weights_addr_.clear(); for (auto &item : allocated_) { free(item); item = nullptr; } allocated_.clear(); + origin_weights_.clear(); + auxiliary_weights_.clear(); } std::map MemoryAllocator::tensors_map() const { std::map res; res.insert(tensors_addr_.begin(), tensors_addr_.end()); res.insert(malloc_weights_addr_.begin(), malloc_weights_addr_.end()); + (void)std::for_each( + auxiliary_weights_.begin(), auxiliary_weights_.end(), + [&res](const std::pair> &item) { res.insert(item.second); }); return res; } @@ -121,17 +129,25 @@ void MemoryAllocator::AssignGraphInputs(const std::vector &inputs) { } } -void MemoryAllocator::RecordOriginWeightsAddr(const std::vector> &nodes) { - for (const auto &node : nodes) { - std::vector inputs = node->input_tensors(); - for (const auto &tensor : inputs) { - if (tensor->category() == lite::Category::CONST_TENSOR || tensor->category() == lite::Category::CONST_SCALAR) { - std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_); - origin_weights_addr_.insert(std::make_pair(tensor, runtime_addr)); - weight_index_++; +int MemoryAllocator::RecordOriginWeightsAddr(const std::vector &all_tensors, + const std::string &changeable_weights_name) { + std::vector weights_name; + if (!changeable_weights_name.empty()) { + weights_name = StrSplit(changeable_weights_name, ","); + } + for (const auto &tensor : all_tensors) { + if (tensor->category() == lite::Category::CONST_TENSOR || tensor->category() == lite::Category::CONST_SCALAR) { + if (std::find(weights_name.begin(), weights_name.end(), tensor->tensor_name()) != weights_name.end()) { + if (RecordChangeableWeights(tensor) != RET_OK) { + return RET_ERROR; + } } + std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_++); + origin_weights_addr_.insert(std::make_pair(tensor, runtime_addr)); + origin_weights_.push_back(tensor); } } + return RET_OK; } int MemoryAllocator::AssignTensors(const std::vector> &nodes) { @@ -150,9 +166,13 @@ int MemoryAllocator::AssignTensors(const std::vector &inputs, - const std::vector> &nodes) { + const std::vector> &nodes, + const std::vector &all_tensors, const std::string &changeable_weights_name) { AssignGraphInputs(inputs); - RecordOriginWeightsAddr(nodes); + if (RecordOriginWeightsAddr(all_tensors, changeable_weights_name) != RET_OK) { + MS_LOG(ERROR) << "RecordOriginWeightsAddr failed."; + return RET_ERROR; + } return AssignTensors(nodes); } @@ -163,4 +183,46 @@ void MemoryAllocator::MarkSharedWeight(const Tensor *src, void *pack_weight) { void *MemoryAllocator::GetSharedWeightAddr(const Tensor *src) { return shared_pack_weights_.find(src) == shared_pack_weights_.end() ? 
nullptr : shared_pack_weights_[src]; } + +int MemoryAllocator::RecordChangeableWeights(Tensor *src) { + MS_ASSERT(src != nullptr); + auto variable_str = GetAuxiliaryWeight(src); + if (!variable_str.empty()) { + return RET_OK; + } + if (!src->IsConst()) { + MS_LOG(ERROR) << "Currently, the tensor must be a constant."; + return RET_NOT_SUPPORT; + } + auto shape = src->shape(); + auto shape_tensor = new (std::nothrow) + Tensor(kNumberTypeInt32, {static_cast(shape.size())}, src->format(), Category::CONST_TENSOR); + if (shape_tensor == nullptr) { + MS_LOG(ERROR) << "Create an assistant tensor failed."; + return RET_NULL_PTR; + } + auto data = shape_tensor->MutableData(); + if (data == nullptr) { + MS_LOG(ERROR) << "Create an assistant tensor failed."; + delete shape_tensor; + return RET_NULL_PTR; + } + if (memcpy_s(data, shape_tensor->Size(), shape.data(), shape.size() * sizeof(int)) != EOK) { + MS_LOG(ERROR) << "Create an assistant tensor failed."; + delete shape_tensor; + return RET_ERROR; + } + shape_tensor->set_tensor_name(src->tensor_name() + "_shape"); + std::string runtime_addr = kWeightPrefixName + std::to_string(weight_index_++); + auxiliary_weights_[src] = std::make_pair(shape_tensor, runtime_addr); + return RET_OK; +} + +std::string MemoryAllocator::GetAuxiliaryWeight(Tensor *src) { + auto iter = auxiliary_weights_.find(src); + if (iter != auxiliary_weights_.end()) { + return iter->second.second; + } + return {}; +} } // namespace mindspore::lite::micro diff --git a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h index 8a1331fb..f5bacf6f 100644 --- a/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h +++ b/mindspore/lite/tools/converter/micro/coder/allocator/allocator.h @@ -56,7 +56,8 @@ class MemoryAllocator { /* * assign model's input, original weights and all tensors memory addr */ - int Assign(const std::vector &inputs, const std::vector> &nodes); + int Assign(const std::vector &inputs, const std::vector> &nodes, + const std::vector &all_tensors, const std::string &changeable_weights_name = {}); // allocator holds the space malloced by opcoders, will free before session coder destroy void Free(); @@ -141,14 +142,18 @@ class MemoryAllocator { void *MallocWeightTensor(TypeId type_id, size_t size, MallocType type, const std::string &tensor_name = ""); void MarkSharedWeight(const Tensor *src, void *pack_weight); void *GetSharedWeightAddr(const Tensor *src); + std::string GetAuxiliaryWeight(Tensor *src); + std::vector origin_weights() const { return origin_weights_; } + std::map> auxiliary_weights() const { return auxiliary_weights_; } private: int AssignTensors(const std::vector> &nodes); void AssignGraphInputs(const std::vector &inputs); void AssignWorkspaces(void *addr, size_t size); - void RecordOriginWeightsAddr(const std::vector> &nodes); + int RecordOriginWeightsAddr(const std::vector &all_tensors, + const std::string &changeable_weights_name = {}); void RecordTensorsAddr(const std::map &offsets); - + int RecordChangeableWeights(Tensor *src); MemoryAllocator() = default; ~MemoryAllocator() = default; @@ -160,11 +165,13 @@ class MemoryAllocator { bool is_next_{false}; size_t offset_{0}; std::vector allocated_; + std::vector origin_weights_; std::map saved_weights_addr_; std::map origin_weights_addr_; std::map malloc_weights_addr_; std::map tensors_addr_; std::map shared_pack_weights_; + std::map> auxiliary_weights_; }; } // namespace mindspore::lite::micro #endif // 
MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_ALLOCATOR_ALLOCATOR_H_ diff --git a/mindspore/lite/tools/converter/micro/coder/coder.cc b/mindspore/lite/tools/converter/micro/coder/coder.cc index cca4687e..a94ac91b 100644 --- a/mindspore/lite/tools/converter/micro/coder/coder.cc +++ b/mindspore/lite/tools/converter/micro/coder/coder.cc @@ -93,25 +93,48 @@ bool Coder::InitPath(const std::string &output_path) { } int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path, - const std::string &codegen_mode, const std::string &device, bool support_parallel, - bool debug_mode, bool enableFp16) { + const MicroParam ¶m, bool enable_fp16) { flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize); auto offset = schema::MetaGraph::Pack(builder, &graph); builder.Finish(offset); schema::FinishMetaGraphBuffer(builder, offset); size_t size = builder.GetSize(); + if (ExecuteMicroGeneration(builder.GetBufferPointer(), size, output_path, param, enable_fp16) != RET_OK) { + MS_LOG(ERROR) << "Execute Micro failed."; + return RET_ERROR; + } + return RET_OK; +} + +int Coder::MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path, + const MicroParam ¶m, bool enable_fp16) { + size_t buffer_size; + auto model_buf = lite::ReadFile(model_file.c_str(), &buffer_size); + if (model_buf == nullptr) { + MS_LOG(ERROR) << "Read model-file failed."; + return RET_NULL_PTR; + } + auto ret = ExecuteMicroGeneration(model_buf, buffer_size, output_path, param, enable_fp16); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Execute Micro failed."; + } + delete[] model_buf; + return ret; +} +int Coder::ExecuteMicroGeneration(const void *model_buf, size_t size, const std::string &output_path, + const MicroParam ¶m, bool enable_fp16) { micro::Coder code_gen; if (!code_gen.InitPath(output_path)) { MS_LOG(ERROR) << "Init path failed"; return RET_ERROR; } // codegeneration for micro - STATUS status = code_gen.Init(codegen_mode, device, support_parallel, debug_mode); + STATUS status = code_gen.Init(param); if (status != RET_OK) { MS_LOG(ERROR) << "Codegen init Error"; return RET_ERROR; } - status = code_gen.Run(builder.GetBufferPointer(), size, enableFp16); + status = code_gen.Run(model_buf, size, enable_fp16); if (status != RET_OK) { MS_LOG(ERROR) << "Codegen Run Error"; return RET_ERROR; @@ -120,29 +143,30 @@ int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std: return RET_OK; } -int Coder::Init(const std::string code_mode, const std::string target, bool support_parallel, bool debug_mode) const { +int Coder::Init(const MicroParam ¶m) const { static const std::map kTargetMap = { {"x86", kX86}, {"Cortex-M", kCortex_M}, {"ARM32", kARM32}, {"ARM64", kARM64}, {"All", kAllTargets}}; static const std::map kCodeModeMap = {{"Inference", Inference}, {"Train", Train}}; Configurator *config = Configurator::GetInstance(); - auto target_item = kTargetMap.find(target); - MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + target); + auto target_item = kTargetMap.find(param.target); + MS_CHECK_TRUE_RET_BOOL(target_item != kTargetMap.end(), "unsupported target: " + param.target); config->set_target(target_item->second); - auto code_item = kCodeModeMap.find(code_mode); - MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + code_mode); + auto code_item = kCodeModeMap.find(param.codegen_mode); + MS_CHECK_TRUE_RET_BOOL(code_item != kCodeModeMap.end(), "unsupported code mode: " + 
param.codegen_mode); config->set_code_mode(code_item->second); - if (support_parallel && config->target() == kCortex_M) { + if (param.support_parallel && config->target() == kCortex_M) { MS_LOG(ERROR) << "Cortex-M cannot support parallel."; return RET_ERROR; } - config->set_support_parallel(support_parallel); - config->set_debug_mode(debug_mode); + config->set_support_parallel(param.support_parallel); + config->set_debug_mode(param.debug_mode); config->set_proj_dir(model_name_); - + config->set_keep_original_weight(param.keep_original_weight); + config->set_changeable_weights_name(param.changeable_weights_name); const std::string slash = std::string(kSlash); if (!save_path_.empty() && !DirExists(save_path_)) { MS_LOG(ERROR) << "code_gen code path " << save_path_ << " is not valid"; @@ -170,6 +194,7 @@ int Coder::Init(const std::string code_mode, const std::string target, bool supp print_parameter("codePath", config->code_path()); print_parameter("codeMode", config->code_mode()); print_parameter("debugMode", config->debug_mode()); + print_parameter("keepOriginalWeight", config->keep_original_weight()); return RET_OK; } } // namespace mindspore::lite::micro diff --git a/mindspore/lite/tools/converter/micro/coder/coder.h b/mindspore/lite/tools/converter/micro/coder/coder.h index 96531e6f..0753156a 100644 --- a/mindspore/lite/tools/converter/micro/coder/coder.h +++ b/mindspore/lite/tools/converter/micro/coder/coder.h @@ -31,11 +31,14 @@ class Coder final { ~Coder() = default; static int MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path, - const std::string &codegen_mode, const std::string &device, - bool support_parallel, bool debug_mode, bool enableFp16); + const MicroParam ¶m, bool enable_fp16); + static int MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path, + const MicroParam ¶m, bool enable_fp16); private: - int Init(const std::string code_mode, const std::string target, bool support_parallel, bool debug_mode_) const; + static int ExecuteMicroGeneration(const void *model_buf, size_t size, const std::string &output_path, + const MicroParam ¶m, bool enable_fp16); + int Init(const MicroParam ¶m) const; int Run(const void *model_buff, size_t size, bool enableFp16); bool InitPath(const std::string &output_path); std::shared_ptr session_{nullptr}; diff --git a/mindspore/lite/tools/converter/micro/coder/config.h b/mindspore/lite/tools/converter/micro/coder/config.h index 84285932..42e0f50e 100644 --- a/mindspore/lite/tools/converter/micro/coder/config.h +++ b/mindspore/lite/tools/converter/micro/coder/config.h @@ -26,9 +26,11 @@ enum CodeMode { Inference = 0, Train = 1, Code_Unknown = 99 }; struct MicroParam { std::string codegen_mode = "Inference"; std::string target; + std::string changeable_weights_name; bool enable_micro{false}; bool support_parallel{false}; bool debug_mode{false}; + bool keep_original_weight{false}; }; class Configurator { @@ -56,6 +58,12 @@ class Configurator { void set_proj_dir(std::string dir) { proj_dir_ = dir; } std::string proj_dir() const { return proj_dir_; } + void set_keep_original_weight(bool keep_weight) { keep_original_weight_ = keep_weight; } + bool keep_original_weight() const { return keep_original_weight_; } + + void set_changeable_weights_name(const std::string &weights_name) { changeable_weights_name_ = weights_name; } + const std::string &changeable_weights_name() const { return changeable_weights_name_; } + private: Configurator() = default; ~Configurator() = default; @@ -64,7 
+72,9 @@ class Configurator { CodeMode code_mode_{Code_Unknown}; bool support_parallel_{false}; bool debug_mode_{false}; + bool keep_original_weight_{false}; std::string proj_dir_; + std::string changeable_weights_name_; }; } // namespace mindspore::lite::micro #endif // MICRO_CODER_CONFIG_H diff --git a/mindspore/lite/tools/converter/micro/coder/context.h b/mindspore/lite/tools/converter/micro/coder/context.h index 724475fe..cec385bb 100644 --- a/mindspore/lite/tools/converter/micro/coder/context.h +++ b/mindspore/lite/tools/converter/micro/coder/context.h @@ -69,6 +69,25 @@ class CoderContext { void set_saved_weights(const std::map<std::string, Tensor *> &saved_weights) { saved_weights_ = saved_weights; } std::map<std::string, Tensor *> saved_weights() const { return saved_weights_; } + void set_origin_weights(const std::vector<Tensor *> &origin_weights) { origin_weights_ = origin_weights; } + const std::vector<Tensor *> &origin_weights() const { return origin_weights_; } + + void set_auxiliary_weights(const std::map<Tensor *, std::pair<Tensor *, std::string>> &auxiliary_weights) { + auxiliary_weights_ = auxiliary_weights; + } + const std::map<Tensor *, std::pair<Tensor *, std::string>> &auxiliary_weights() const { return auxiliary_weights_; } + + bool JudgeIsValid(bool keep_origin_weight) { + if (!keep_origin_weight) { + return true; + } + return std::all_of(saved_weights_.begin(), saved_weights_.end(), + [this](const std::pair<std::string, Tensor *> &item) { + return std::find(this->origin_weights_.begin(), this->origin_weights_.end(), item.second) != + this->origin_weights_.end(); + }); + } + void set_total_buffer_size(size_t size) { total_buffer_size_ = size; } size_t total_buffer_size() const { return total_buffer_size_; } @@ -107,7 +126,11 @@ class CoderContext { private: std::vector<Tensor *> graph_inputs_; std::vector<Tensor *> graph_outputs_; - // primitive const tensors, parsed from model, without packed. + // primitive const tensors, parsed from model, unpacked; some of these tensors may be unused. + std::vector<Tensor *> origin_weights_; + // auxiliary data (e.g. shape tensors) for origin weights, if needed. + std::map<Tensor *, std::pair<Tensor *, std::string>> auxiliary_weights_; + // primitive const tensors, parsed from model, packed; all of these tensors are actually used. std::map<std::string, Tensor *> saved_weights_;
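CoderContext::JudgeIsValid encodes the invariant behind keep_original_weight: every packed weight that an op coder saved must itself be one of the weights parsed from the model, because net.bin is then written from the original tensors only. The same check, reduced to plain STL types (a sketch, not the real API):

    #include <algorithm>
    #include <map>
    #include <string>
    #include <vector>

    // Valid iff every saved (packed) weight points back at an original model weight.
    bool JudgeIsValid(bool keep_origin_weight,
                      const std::map<std::string, const void *> &saved,
                      const std::vector<const void *> &origin) {
      if (!keep_origin_weight) {
        return true;  // nothing to enforce in the default path
      }
      return std::all_of(saved.begin(), saved.end(), [&origin](const auto &item) {
        return std::find(origin.begin(), origin.end(), item.second) != origin.end();
      });
    }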
// all tensors, include parsed from model and packed tensors. std::map<Tensor *, std::string> tensors_map_; diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc index 058f0ba0..d30e0133 100644 --- a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc +++ b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc @@ -141,6 +141,7 @@ void CodeMSModelDestory(std::ofstream &ofs, const Configurator *config) { } ofs << " MSTensorHandleArrayDestroy(micro_model->inputs);\n" " MSTensorHandleArrayDestroy(micro_model->outputs);\n" + " FreeResource();\n" " free(*model);\n" " *model = NULL;\n" " }\n"; @@ -331,10 +332,12 @@ void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, + auto &w_auxiliary = ctx->auxiliary_weights(); for (const auto &item : ctx->tensors_map()) { Tensor *tensor = item.first; std::string name = item.second; - if (tensor->data() != nullptr && !(CheckConstantTensor(tensor))) { + if (tensor->data() != nullptr && + (!(CheckConstantTensor(tensor)) || w_auxiliary.find(tensor) != w_auxiliary.end())) { ofs << " (void**)&" << name << ",\n"; num++; } diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc index 4d102391..d0824ecb 100644 --- a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc +++ b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc @@ -22,6 +22,32 @@ #include "coder/opcoders/parallel.h" namespace mindspore::lite::micro { +namespace { +struct camp { + bool operator()(const std::string &a, const std::string &b) const { return a.size() != b.size() ? a.size() < b.size() : a < b; } +}; + +std::string GenerateArrayContent(const std::vector<size_t> &contents, const std::string &prefix) { + std::string lines; + std::string line = prefix; + for (auto content : contents) { + std::string append = std::to_string(content) + ", "; + if (line == prefix) { + line += append; + continue; + } + if (line.size() + append.size() > 120) { + lines += line + "\n"; + line = prefix + append; + } else { + line += append; + } + } + lines += line + "\n"; + return lines; +} +} // namespace + void CodeWeightFileHeader(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { ofs << g_hwLicense; // include all operator header @@ -71,10 +97,11 @@ void CodeModelParamsData(std::ofstream &ofs, const std::map<std::string, Tensor *> &weights) void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { // reverse key and value of tensors_map - std::map<std::string, Tensor *> address_map; + std::map<std::string, Tensor *, camp> address_map; for (const auto &item : ctx->tensors_map()) { address_map.insert(std::make_pair(item.second, item.first)); } + auto &w_auxiliary = ctx->auxiliary_weights(); for (auto &item : address_map) { std::string name = item.first; Tensor *tensor = item.second; @@ -83,13 +110,22 @@ void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::unique_ptr<CoderContext> &ctx, } if (CheckConstantTensor(tensor)) { if (config.target() != kCortex_M) { - hofs << "extern " << GetTensorDataType(tensor->data_type()) << name << "[];\n"; - cofs << GetTensorDataType(tensor->data_type()) << name << "[" << tensor->ElementsNum() << "];\n"; + if (w_auxiliary.find(tensor) == w_auxiliary.end()) { + hofs << "extern " << GetTensorDataType(tensor->data_type()) << name << "[]; // " << tensor->tensor_name() + << std::endl; + cofs << GetTensorDataType(tensor->data_type()) << name << "[" << tensor->ElementsNum() << "];\n"; + } else { + hofs << "extern " <<
GetTensorDataType(tensor->data_type()) << "*" << name << "; // " + << tensor->tensor_name() << std::endl; + cofs << GetTensorDataType(tensor->data_type()) << "*" << name << " = NULL;\n"; + } } else { - hofs << "extern const " << GetTensorDataType(tensor->data_type()) << name << "[];\n"; + hofs << "extern const " << GetTensorDataType(tensor->data_type()) << name << "[]; // " << tensor->tensor_name() + << std::endl; } } else if (tensor->category() == lite::Category::VAR) { - hofs << "extern " << GetTensorDataType(tensor->data_type()) << "*" << name << ";\n"; + hofs << "extern " << GetTensorDataType(tensor->data_type()) << "*" << name << "; // " << tensor->tensor_name() + << std::endl; cofs << GetTensorDataType(tensor->data_type()) << "*" << name << " = NULL;\n"; } } @@ -104,6 +140,186 @@ void CodeInitWeightState(std::ofstream &ofs) { << "int Init(void *weight_buffer, int weight_size);\n\n"; } +void CodeWeightContentInit(std::ofstream &ofs, const std::unique_ptr &ctx, + const std::map &tensors_index) { + auto &w_auxiliary = ctx->auxiliary_weights(); + std::map real_need_tensors; + auto record_saved_tensors = ctx->saved_weights(); + for (auto &item : record_saved_tensors) { + real_need_tensors.insert(std::make_pair(item.first, item.second)); + } + std::string non_copy; + std::string copy_static; + std::string copy_dynamic; + int copy_static_num = 0; + int copy_dynamic_num = 0; + auto tensors_map = ctx->tensors_map(); + for (const auto &item : real_need_tensors) { + if (!CheckConstantTensor(item.second) || item.second->data() == nullptr) { + continue; + } + auto iter = tensors_map.find(item.second); + if (iter == tensors_map.end()) { + TypeId data_type = item.second->data_type(); + non_copy += " " + GetTensorDataType(data_type) + "*" + item.first + " = (weight_buffer + offsets[" + + std::to_string(tensors_index.at(item.second)) + "]);\n"; + continue; + } + if (w_auxiliary.find(item.second) == w_auxiliary.end()) { + copy_static += " {" + item.first + ", " + std::to_string(tensors_index.at(item.second)) + "},\n"; + ++copy_static_num; + } else { + copy_dynamic += " {&" + item.first + ", " + std::to_string(tensors_index.at(item.second)) + "},\n"; + ++copy_dynamic_num; + } + } + for (const auto &item : w_auxiliary) { + copy_static += " {" + item.second.second + ", " + std::to_string(tensors_index.at(item.second.first)) + "},\n"; + ++copy_static_num; + } + ofs << non_copy << "\n"; + if (copy_static_num > 0) { + ofs << " {\n struct ModelParameter static_copy[] = {\n" << copy_static << " };\n"; + ofs << " for(int i = 0; i < " << copy_static_num << "; ++i) {\n" + << " int index = static_copy[i].index;\n" + << " if (offsets[index] + tensors_size[index] > weight_size) {\n" + " return RET_ERROR;\n" + " }\n" + << " memcpy(static_copy[i].addr, (weight_buffer + offsets[index]), tensors_size[index]);\n" + << " }\n }\n\n"; + } + ofs << " size_t " << ctx->weight_size_name() << " = PackWeightSize() + dynamic_memory;\n"; + ofs << " if (" << ctx->weight_size_name() << " > 0) {\n"; + ofs << " " << ctx->weight_name() << " = malloc(" << ctx->weight_size_name() << ");\n"; + ofs << " if (" << ctx->weight_name() << " == NULL) {\n return RET_ERROR;\n }\n"; + ofs << " memset(" << ctx->weight_name() << ", 0, " << ctx->weight_size_name() << ");\n"; + ofs << " }\n"; + ofs << " size_t " << ctx->weight_offset_name() << " = 0;\n"; + if (copy_dynamic_num > 0) { + ofs << " {\n struct ModelParameter dynamic_copy[] = {\n" << copy_dynamic << " };\n"; + ofs << " for(int i = 0; i < " << copy_dynamic_num << "; ++i) {\n" + << " int index 
= dynamic_copy[i].index;\n" + << " memcpy(" << ctx->weight_name() << " + " << ctx->weight_offset_name() + << ", (weight_buffer + offsets[index]), tensors_size[index]);\n" + << " *((void **)dynamic_copy[i].addr) = " << ctx->weight_name() << " + " << ctx->weight_offset_name() + << ";\n" + << " " << ctx->weight_offset_name() << " += tensors_size[index];\n" + << " }\n }\n\n"; + } +} + +void CodeWeightInitIfKeepWeight(std::ofstream &ofs, const std::unique_ptr &ctx) { + auto &w_origin = ctx->origin_weights(); + auto &w_auxiliary = ctx->auxiliary_weights(); + std::vector tensors_size; + std::vector online_compute_index; + std::map tensors_index; + for (auto tensor : w_origin) { + if (!(CheckConstantTensor(tensor)) || tensor->data() == nullptr) { + continue; + } + auto iter = w_auxiliary.find(tensor); + if (iter == w_auxiliary.end()) { + tensors_index[tensor] = tensors_size.size(); + tensors_size.push_back(tensor->Size()); + } else { + tensors_index[iter->second.first] = tensors_size.size(); + tensors_size.push_back(iter->second.first->Size()); + tensors_index[tensor] = tensors_size.size(); + online_compute_index.push_back(tensors_size.size()); + tensors_size.push_back(DataTypeSize(tensor->data_type())); + } + } + std::vector offsets{0}; + int last = online_compute_index.empty() ? tensors_size.size() - 1 : online_compute_index.front(); + for (int i = 1; i <= last; ++i) { + offsets.push_back(offsets[i - 1] + tensors_size[i - 1]); + } + ofs << "int Init(void *weight_buffer, int weight_size) {\n" + << " if (weight_buffer == NULL) {\n" + << " return RET_ERROR;\n" + << " }\n"; + ofs << " struct ModelParameter {\n" + << " void *addr;\n" + << " int index;\n" + << " };\n"; + ofs << " int offsets[" << std::to_string(tensors_size.size()) << "] = {\n" + << GenerateArrayContent(offsets, " ") << " };\n"; + ofs << " size_t tensors_size[" << std::to_string(tensors_size.size()) << "] = {\n" + << GenerateArrayContent(tensors_size, " ") << " };\n"; + ofs << " size_t dynamic_memory = 0;\n"; + offsets.insert(offsets.end(), tensors_size.size() - offsets.size(), 0); + if (!online_compute_index.empty()) { + ofs << " int online_compute_index[] = {\n" << GenerateArrayContent(online_compute_index, " ") << " };\n"; + ofs << " for (size_t i = 0; i < " << std::to_string(online_compute_index.size()) + "; ++i) {\n"; + ofs << " int *shape = (int *)(weight_buffer + offsets[online_compute_index[i] - 1]);\n"; + ofs << " int dim_num = tensors_size[online_compute_index[i] - 1] / 4;\n"; + ofs << " size_t tensor_size = tensors_size[online_compute_index[i]];\n"; + ofs << " for (int j = 0; j < dim_num; ++j) {\n"; + ofs << " tensor_size *= shape[j];\n"; + ofs << " }\n"; + ofs << " tensors_size[online_compute_index[i]] = tensor_size;\n"; + ofs << " dynamic_memory += tensor_size;\n"; + ofs << " int next_index = (i + 1) < " << std::to_string(online_compute_index.size()) + << " ? 
online_compute_index[i + 1] : " << std::to_string(tensors_size.size()) << " - 1;\n"; + ofs << " for (int j = online_compute_index[i] + 1; j <= next_index; ++j) {\n"; + ofs << " offsets[j] = offsets[j - 1] + tensors_size[j - 1];\n"; + ofs << " }\n }\n"; + } + CodeWeightContentInit(ofs, ctx, tensors_index); +} + +void CodeWeightInitIfNonKeepWeight(std::ofstream &ofs, const std::unique_ptr &ctx) { + ofs << "int Init(void *weight_buffer, int weight_size) {\n" + << " if (weight_buffer == NULL) {\n" + << " return RET_ERROR;\n" + << " }\n"; + ofs << " struct ModelParameter {\n" + << " void *addr;\n" + << " size_t size;\n" + << " size_t offset;\n" + << " };\n"; + + ofs << " size_t " << ctx->weight_size_name() << " = PackWeightSize();\n"; + size_t params_num = 0; + size_t offset = 0; + std::string params; + std::string origins; + for (const auto &item : ctx->saved_weights()) { + std::string name = item.first; + Tensor *tensor = item.second; + if (!CheckConstantTensor(tensor)) { + continue; + } + std::map ctx_tensor_map = ctx->tensors_map(); + auto iter = ctx_tensor_map.find(tensor); + if (iter != ctx_tensor_map.end()) { + origins += " {" + name + ", " + std::to_string(tensor->Size()) + ", " + std::to_string(offset) + "},\n"; + params_num++; + } else { + TypeId data_type = tensor->data_type(); + params += + " " + GetTensorDataType(data_type) + "*" + name + " = (weight_buffer + " + std::to_string(offset) + ");\n"; + } + offset += tensor->Size(); + } + ofs << params << "\n"; + ofs << " struct ModelParameter model_params[] = {\n" << origins << " };\n"; + ofs << "\n"; + ofs << " for(int i = 0; i < " << params_num << "; ++i) {\n" + << " if (model_params[i].offset + model_params[i].size > weight_size) {\n" + " return RET_ERROR;\n" + " }\n" + << " memcpy(model_params[i].addr, (weight_buffer + model_params[i].offset), model_params[i].size);\n" + << " }\n"; + ofs << " if (" << ctx->weight_size_name() << " > 0) {\n"; + ofs << " " << ctx->weight_name() << " = malloc(" << ctx->weight_size_name() << ");\n"; + ofs << " if (" << ctx->weight_name() << " == NULL) {\n return RET_ERROR;\n }\n"; + ofs << " memset(" << ctx->weight_name() << ", 0, " << ctx->weight_size_name() << ");\n"; + ofs << " }\n"; + ofs << " size_t " << ctx->weight_offset_name() << " = 0;\n"; +} + void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr &ctx, const Configurator &config) { if (config.target() != kCortex_M) { ofs << "static size_t PackWeightSize() {\n"; @@ -114,58 +330,16 @@ void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr ofs << " return w_size;\n"; ofs << "}\n\n"; - ofs << "int Init(void *weight_buffer, int weight_size) {\n" - << " if (weight_buffer == NULL) {\n" - << " return RET_ERROR;\n" - << " }\n"; - ofs << " struct ModelParameter {\n" - << " void *addr;\n" - << " size_t size;\n" - << " size_t offset;\n" - << " };\n"; - - ofs << " size_t " << ctx->weight_size_name() << " = PackWeightSize();\n"; - size_t params_num = 0; - size_t offset = 0; - std::string params; - std::string origins; - for (const auto &item : ctx->saved_weights()) { - std::string name = item.first; - Tensor *tensor = item.second; - if (!CheckConstantTensor(tensor)) { - continue; - } - std::map ctx_tensor_map = ctx->tensors_map(); - auto iter = ctx_tensor_map.find(tensor); - if (iter != ctx_tensor_map.end()) { - origins += " {" + name + ", " + std::to_string(tensor->Size()) + ", " + std::to_string(offset) + "},\n"; - params_num++; - } else { - TypeId data_type = tensor->data_type(); - params += - " " + GetTensorDataType(data_type) + 
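Taken together, the Init routine emitted by CodeWeightInitIfKeepWeight follows a small protocol: offsets[] and tensors_size[] describe the weight file entry by entry, and for each index in online_compute_index[] the entry directly before it holds an int32 shape whose product scales that entry's element size, after which all following offsets are shifted. The loop it prints behaves like this standalone C++ sketch (names are illustrative):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Mirrors the loop the generator emits into Init(): entry idx-1 of the weight
    // file is an int32 shape, entry idx is the weight itself whose recorded size
    // starts out as the element size and is multiplied out from that shape.
    // Returns the total number of bytes that must be copied into dynamic memory.
    size_t RecomputeOffsets(const uint8_t *weight_buffer, std::vector<size_t> *offsets,
                            std::vector<size_t> *tensors_size,
                            const std::vector<size_t> &online_compute_index) {
      size_t dynamic_memory = 0;
      for (size_t i = 0; i < online_compute_index.size(); ++i) {
        size_t idx = online_compute_index[i];
        const int32_t *shape = reinterpret_cast<const int32_t *>(weight_buffer + (*offsets)[idx - 1]);
        size_t dim_num = (*tensors_size)[idx - 1] / sizeof(int32_t);
        size_t tensor_size = (*tensors_size)[idx];  // seeded with DataTypeSize(dtype)
        for (size_t j = 0; j < dim_num; ++j) {
          tensor_size *= static_cast<size_t>(shape[j]);
        }
        (*tensors_size)[idx] = tensor_size;
        dynamic_memory += tensor_size;
        // Shift every offset up to the next dynamic entry (or the end).
        size_t next = (i + 1 < online_compute_index.size()) ? online_compute_index[i + 1]
                                                            : tensors_size->size() - 1;
        for (size_t j = idx + 1; j <= next; ++j) {
          (*offsets)[j] = (*offsets)[j - 1] + (*tensors_size)[j - 1];
        }
      }
      return dynamic_memory;
    }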
"*" + name + " = (weight_buffer + " + std::to_string(offset) + ");\n"; - } - offset += tensor->Size(); - } - ofs << params << "\n"; - ofs << " struct ModelParameter model_params[] = {\n" << origins << " };\n"; - ofs << "\n"; - ofs << " for(int i = 0; i < " << params_num << "; ++i) {\n" - << " if (model_params[i].offset + model_params[i].size > weight_size) {\n" - " return RET_ERROR;\n" - " }\n" - << " memcpy(model_params[i].addr, (weight_buffer + model_params[i].offset), model_params[i].size);\n" - << " }\n"; - ofs << " if (" << ctx->weight_size_name() << " > 0) {\n"; - ofs << " " << ctx->weight_name() << " = malloc(" << ctx->weight_size_name() << ");\n"; - ofs << " if (" << ctx->weight_name() << " == NULL) {\n return RET_ERROR;\n }\n"; - ofs << " memset(" << ctx->weight_name() << ", 0, " << ctx->weight_size_name() << ");\n"; - ofs << " }\n"; + if (config.keep_original_weight()) { + CodeWeightInitIfKeepWeight(ofs, ctx); + } else { + CodeWeightInitIfNonKeepWeight(ofs, ctx); + } } else { ofs << "int Init(void *weight_buffer, int weight_size) {\n"; ofs << " const size_t w_size = " << ctx->weight_buffer_size() << ";\n"; + ofs << " size_t " << ctx->weight_offset_name() << " = 0;\n"; } - ofs << " size_t " << ctx->weight_offset_name() << " = 0;\n"; for (const auto &block : ctx->init_contents()) { ofs << "{\n" << block << "}\n"; } @@ -175,11 +349,26 @@ void CodeWeightInitFunc(std::ofstream &ofs, const std::unique_ptr ofs << "}\n\n"; } -void SaveDataToNet(const std::map &saved_weights, const std::string &net_file) { +void SaveDataToNet(const std::unique_ptr &ctx, const std::string &net_file, bool keep_weight) { std::ofstream net(net_file, std::ios::out | std::ios::trunc | std::ios::binary); MS_CHECK_TRUE_WITHOUT_RET(net.is_open(), "net file open failed!"); - for (auto &item : saved_weights) { - Tensor *tensor = item.second; + std::vector save_tensors; + if (keep_weight) { + auto &w_origin = ctx->origin_weights(); + auto &w_auxiliary = ctx->auxiliary_weights(); + (void)std::for_each(w_origin.begin(), w_origin.end(), [&save_tensors, &w_auxiliary](Tensor *tensor) { + auto iter = w_auxiliary.find(tensor); + if (iter != w_auxiliary.end()) { + save_tensors.push_back(iter->second.first); + } + save_tensors.push_back(tensor); + }); + } else { + auto recorded_saved_tensors = ctx->saved_weights(); + (void)std::transform(recorded_saved_tensors.begin(), recorded_saved_tensors.end(), std::back_inserter(save_tensors), + [](const std::pair &item) { return item.second; }); + } + for (auto tensor : save_tensors) { if ((CheckConstantTensor(tensor)) && tensor->data() != nullptr) { net.write(reinterpret_cast(tensor->data()), tensor->Size()); } diff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h index 3a68a540..98c56afd 100644 --- a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h +++ b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.h @@ -31,7 +31,7 @@ void CodeWeightFileHeader(std::ofstream &ofs, const std::unique_ptr &weights); void CodeModelParamsData(std::ofstream &ofs, const std::map &weights); -void SaveDataToNet(const std::map &saved_weights, const std::string &net_file); +void SaveDataToNet(const std::unique_ptr &ctx, const std::string &net_file, bool keep_weight); void CodeModelParamsForNet(std::ofstream &hofs, std::ofstream &cofs, const std::unique_ptr &ctx, const Configurator &config); diff --git 
a/mindspore/lite/tools/converter/micro/coder/generator/generator.cc b/mindspore/lite/tools/converter/micro/coder/generator/generator.cc index 5b29978f..8add577f 100644 --- a/mindspore/lite/tools/converter/micro/coder/generator/generator.cc +++ b/mindspore/lite/tools/converter/micro/coder/generator/generator.cc @@ -259,7 +259,7 @@ int Generator::CodeWeightFile() { cofs << "unsigned char * " << ctx_->buffer_name() << " = 0; \n"; cofs << "unsigned char * " << ctx_->weight_name() << " = 0; \n"; std::string net_file = net_src_file_path_ + "net.bin"; - SaveDataToNet(ctx_->saved_weights(), net_file); + SaveDataToNet(ctx_, net_file, config_->keep_original_weight()); } else { if (!ctx_->weight_buffer_size_code_blocks().empty()) { MS_LOG(ERROR) << "Weight init code generation error "; diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc index d25b3e6b..56b22333 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_base_coder.cc @@ -41,13 +41,18 @@ int ReshapeBaseCoder::DoCode(CoderContext *const context) { } REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Reshape, CPUOpCoderCreator) +REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Reshape, CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Reshape, CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Flatten, CPUOpCoderCreator) +REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Flatten, CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Flatten, CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_ExpandDims, CPUOpCoderCreator) +REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_ExpandDims, CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_ExpandDims, CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Squeeze, CPUOpCoderCreator) +REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Squeeze, CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Squeeze, CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Unsqueeze, CPUOpCoderCreator) +REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Unsqueeze, CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Unsqueeze, CPUOpCoderCreator) } // namespace mindspore::lite::micro diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc new file mode 100644 index 00000000..ee887342 --- /dev/null +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
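These REG_OPERATOR_CODER entries are how a coder announces itself for one (target, data type, primitive) combination, so FP16 coverage is mostly a matter of registering the existing creator under kNumberTypeFloat16 as well. The mechanism in miniature, as a toy self-registering factory (illustrative only, not the real op_coder_register API; the enum values are placeholders):

    #include <functional>
    #include <map>
    #include <memory>
    #include <tuple>
    #include <utility>

    struct OperatorCoder {
      virtual ~OperatorCoder() = default;
    };

    using CoderKey = std::tuple<int, int, int>;  // target, data type, primitive
    using CoderCreator = std::function<std::unique_ptr<OperatorCoder>()>;

    // Meyers-singleton registry avoids static-initialization-order issues.
    std::map<CoderKey, CoderCreator> &CoderRegistry() {
      static std::map<CoderKey, CoderCreator> registry;
      return registry;
    }

    // A static instance of this struct runs at program start-up and fills the
    // registry; this is roughly what the REG_OPERATOR_CODER macro expands to.
    struct CoderRegistrar {
      CoderRegistrar(CoderKey key, CoderCreator creator) {
        CoderRegistry().emplace(std::move(key), std::move(creator));
      }
    };

    struct ReshapeCoder : OperatorCoder {};
    enum : int { kTargetAll = 0, kDTypeFp16 = 1, kPrimReshape = 2 };  // placeholder ids
    static CoderRegistrar g_reshape_fp16_reg(
        CoderKey{kTargetAll, kDTypeFp16, kPrimReshape},
        []() -> std::unique_ptr<OperatorCoder> { return std::make_unique<ReshapeCoder>(); });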
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "coder/opcoders/base/stack_base_coder.h" +#include +#include +#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" +#include "coder/opcoders/file_collector.h" +#include "coder/opcoders/parallel.h" + +using mindspore::schema::PrimitiveType_Stack; + +namespace mindspore::lite::micro::nnacl { +int StackFP32Coder::Prepare(CoderContext *const context) { + stack_param_ = reinterpret_cast(parameter_); + return ReSize(); +} + +int StackFP32Coder::ReSize() { + axis_ = stack_param_->axis_ >= 0 ? stack_param_->axis_ + : static_cast(input_tensor_->shape().size()) + stack_param_->axis_ + 1; + if (axis_ < 0 || axis_ > static_cast(input_tensor_->shape().size())) { + return RET_ERROR; + } + return RET_OK; +} + +int StackFP32Coder::DoCode(CoderContext *const context) { + Collect(context, + { + "nnacl/base/stack_base.h", + }, + { + "stack_base.c", + }); + + size_t input_num = input_tensors_.size(); + + NNaclFp32Serializer code; + code << "\t\tvoid *inputs_addr[] = {"; + for (size_t i = 0; i < input_num; ++i) { + code << allocator_->GetRuntimeAddr(input_tensors_.at(i)) << ", "; + } + code << "};\n"; + + size_t copy_size = 0; + int outer_size = 1; + auto shape = input_tensor_->shape(); + if (input_tensors_.empty()) { + copy_size = 0; + outer_size = 0; + } else if (input_tensors_.size() == 1) { + copy_size = input_tensor_->ElementsNum(); + outer_size = 1; + } else { + copy_size = 1; + for (int i = axis_; i < static_cast(shape.size()); ++i) { + copy_size *= shape[i]; + } + for (int i = 0; i < axis_; ++i) { + outer_size *= shape[i]; + } + } + copy_size *= DataTypeSize(input_tensor_->data_type()); + code.CodeFunction("Stack", "inputs_addr", output_tensor_, input_num, copy_size, 0, outer_size); + context->AppendCode(code.str()); + return RET_OK; +} + +REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Stack, CPUOpCoderCreator) +REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Stack, CPUOpCoderCreator) +REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Stack, CPUOpCoderCreator) +} // namespace mindspore::lite::micro::nnacl diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h new file mode 100644 index 00000000..08074332 --- /dev/null +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
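The copy sizing in StackFP32Coder::DoCode reduces to a simple rule: everything from the stack axis outward is one contiguous copy, repeated once per slice in front of the axis. A small self-contained helper that reproduces the arithmetic (illustrative names):

    #include <cstddef>
    #include <vector>

    // copy_size: bytes copied per input tensor; outer_size: how many such copies.
    void StackCopySizes(const std::vector<int> &in_shape, int axis, size_t elem_size,
                        size_t *copy_size, int *outer_size) {
      *copy_size = 1;
      *outer_size = 1;
      for (size_t i = static_cast<size_t>(axis); i < in_shape.size(); ++i) {
        *copy_size *= static_cast<size_t>(in_shape[i]);
      }
      for (int i = 0; i < axis; ++i) {
        *outer_size *= in_shape[i];
      }
      *copy_size *= elem_size;  // DataTypeSize(input->data_type()) in the coder
    }

Stacking four 2x3 fp32 tensors along axis 0, for example, yields copy_size = 2 * 3 * 4 = 24 bytes per input and outer_size = 1.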
+ */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STACK_FP32_CODER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STACK_FP32_CODER_H_ + +#include +#include "coder/opcoders/op_coder.h" +#include "nnacl/stack_parameter.h" + +namespace mindspore::lite::micro::nnacl { +class StackFP32Coder final : public OperatorCoder { + public: + StackFP32Coder(const std::vector &in_tensors, const std::vector &out_tensors, + const LiteGraph::Node *node, size_t node_index, Target target) + : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} + ~StackFP32Coder() override = default; + + int Prepare(CoderContext *const context) override; + int DoCode(CoderContext *const context) override; + + private: + int ReSize(); + + int axis_{0}; + StackParameter *stack_param_{nullptr}; +}; +} // namespace mindspore::lite::micro::nnacl +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STACK_FP32_CODER_H_ diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc index ba9fbaa1..ffc70e1c 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_base_coder.cc @@ -33,6 +33,8 @@ size_t GetInnerSize(TypeId type_id, int inner_elements) { return inner_elements * sizeof(float); case kNumberTypeInt32: return inner_elements * sizeof(int32_t); + case kNumberTypeFloat16: + return inner_elements * sizeof(uint16_t); default: MS_LOG(ERROR) << "Not supported data type: " << type_id; return 0; @@ -142,6 +144,23 @@ int StridedSliceBaseCoder::DoFastCode(CoderContext *ctx) { } int StridedSliceBaseCoder::DoNormalCode(CoderContext *ctx) { + switch (input_tensor_->data_type()) { + case kNumberTypeInt8: + strided_slice_parameter_->data_type = ::kNumberTypeInt8; + break; + case kNumberTypeFloat32: + strided_slice_parameter_->data_type = ::kNumberTypeFloat32; + break; + case kNumberTypeInt32: + strided_slice_parameter_->data_type = ::kNumberTypeInt32; + break; + case kNumberTypeFloat16: + strided_slice_parameter_->data_type = ::kNumberTypeFloat16; + break; + default: + MS_LOG(ERROR) << "Not supported data type: " << input_tensor_->data_type(); + return RET_ERROR; + } nnacl::NNaclFp32Serializer code; code.CodeStruct("strided_slice_parameter", *strided_slice_parameter_); code.CodeFunction("DoStridedSlice", input_tensor_, output_tensor_, @@ -166,6 +185,8 @@ int StridedSliceBaseCoder::DoCode(CoderContext *ctx) { } REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_StridedSlice, CPUOpCoderCreator) +REG_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_StridedSlice, + CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_StridedSlice, CPUOpCoderCreator) REG_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_StridedSlice, CPUOpCoderCreator) } // namespace mindspore::lite::micro diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc new file mode 100644 index 00000000..5470b56a --- /dev/null +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use 
this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h" +#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" +#include "coder/opcoders/file_collector.h" + +using mindspore::schema::PrimitiveType_Custom; + +namespace mindspore::lite::micro::nnacl { +void CustomGruFP16Coder::InitNnaclFile(CoderContext *const context) { + Collect(context, {"nnacl/fp16/custom_gru_fp16.h"}, + {"custom_gru_fp16.c", "pack_fp16.c", "matmul_fp16.c", "arithmetic_fp16.c", "activation_fp16.c"}); +} + +void CustomGruFP16Coder::InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst, + int row, int col) { + init_code->CodeFunction("RowMajor2Col8MajorFp16", src, dst, row, col, false); +} + +REG_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Custom, CPUOpCoderCreator) +} // namespace mindspore::lite::micro::nnacl diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h new file mode 100644 index 00000000..eb76faf6 --- /dev/null +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h @@ -0,0 +1,44 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
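CustomGruFP16Coder shows the pattern these coders use for FP16 support: derive from the FP32 coder, retarget the type-dependent knobs in the constructor (data type, row/col tiles, kernel name), and override only the hooks that differ, such as InitPackMatrixB swapping in RowMajor2Col8MajorFp16. The shape of that design in miniature (classes and the FP32 default values below are illustrative, not the real coders):

    #include <iostream>
    #include <string>

    class GruCoderFp32 {
     public:
      virtual ~GruCoderFp32() = default;
      void EmitCall() const { std::cout << op_func_ << "(input, weights, output);\n"; }
      virtual void EmitPackMatrixB() const {
        std::cout << "RowMajor2Col8Major(src, dst, row, col);\n";  // fp32 pack routine
      }

     protected:
      std::string op_func_ = "CustomGru";  // placeholder fp32 kernel name
      int row_tile_ = 12;                  // placeholder fp32 tiling
      int col_tile_ = 8;
    };

    class GruCoderFp16 : public GruCoderFp32 {
     public:
      GruCoderFp16() {
        op_func_ = "CustomGruFp16";  // retarget the kernel...
        row_tile_ = 4;               // ...and the tiling (C4NUM / C8NUM in the patch)
        col_tile_ = 8;
      }
      void EmitPackMatrixB() const override {
        std::cout << "RowMajor2Col8MajorFp16(src, dst, row, col);\n";
      }
    };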
+ */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CUSTOM_GRU_FP16_CODER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CUSTOM_GRU_FP16_CODER_H_ + +#include +#include +#include "coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h" +#include "nnacl/custom_gru_parameter.h" + +namespace mindspore::lite::micro::nnacl { +class CustomGruFP16Coder : public CustomGruFP32Coder { + public: + CustomGruFP16Coder(const std::vector &in_tensors, const std::vector &out_tensors, + const LiteGraph::Node *node, size_t node_index, Target target) + : CustomGruFP32Coder(in_tensors, out_tensors, node, node_index, target) { + data_type_ = kNumberTypeFloat16; + row_tile_ = C4NUM; + col_tile_ = C8NUM; + op_func_ = "CustomGruFp16"; + } + ~CustomGruFP16Coder() override = default; + + protected: + void InitNnaclFile(CoderContext *const context) override; + void InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst, int row, + int col) override; +}; +} // namespace mindspore::lite::micro::nnacl +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CUSTOM_GRU_FP16_CODER_H_ diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc index f2aec9d2..37b90b65 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc @@ -30,13 +30,12 @@ int MatMulFP16BaseCoder::InitBiasData() { if (bias_ptr_) { return RET_OK; } - bias_pack_ptr_size_ = static_cast(params_->col_align_ * data_type_size_); + bias_pack_ptr_size_ = static_cast(params_->col_align_ * DataTypeSize(data_type_)); if (input_tensors_.size() == C3NUM) { - bias_ptr_ = allocator_->Malloc(kNumberTypeUInt8, kOnlineSize, kOnlinePackWeight, - bias_tensor_->tensor_name() + "_online_pack"); - } else { bias_ptr_ = - allocator_->Malloc(kNumberTypeUInt8, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack"); + allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack"); + } else { + bias_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack"); } return RET_OK; } @@ -45,18 +44,19 @@ int MatMulFP16BaseCoder::InitBufferA() { if (a_pack_ptr_ != nullptr || vec_matmul_) { return RET_OK; } - a_pack_ptr_size_ = static_cast(params_->batch * params_->row_align_ * params_->deep_ * sizeof(uint16_t)); + a_pack_ptr_size_ = + static_cast(params_->batch * params_->row_align_ * params_->deep_ * DataTypeSize(data_type_)); if (params_->a_const_) { a_pack_ptr_ = allocator_->GetSharedWeightAddr(input_tensors_.at(0)); if (a_pack_ptr_ == nullptr) { - a_pack_ptr_ = allocator_->Malloc(kNumberTypeFloat16, kOnlineSize, kOnlinePackWeight, + a_pack_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, input_tensors_.at(0)->tensor_name() + "_online_pack"); allocator_->MarkSharedWeight(input_tensors_.at(0), a_pack_ptr_); } else { a_packed_ = true; } } else { - a_pack_ptr_ = allocator_->Malloc(kNumberTypeFloat16, a_pack_ptr_size_, kWorkspace); + a_pack_ptr_ = allocator_->Malloc(data_type_, a_pack_ptr_size_, kWorkspace); } MS_CHECK_PTR(a_pack_ptr_); return RET_OK; @@ -77,7 +77,7 @@ std::string MatMulFP16BaseCoder::InitMatrixA(NNaclFp32Serializer *const code, NN return 
allocator_->GetRuntimeAddr(input_tensor_, input_tensor_->IsConst()); } std::string input_a_str = allocator_->GetRuntimeAddr(input_tensor_); - std::string input_a_pack_str = "(float16_t *)" + allocator_->GetRuntimeAddr(a_pack_ptr_); + std::string input_a_pack_str = allocator_->GetRuntimeAddr(static_cast(a_pack_ptr_)); if (params_->a_const_) { init_code->CodeBufferOffsetExpression(a_pack_ptr_, context->weight_name(), context->weight_offset_name(), context->weight_size_name(), a_pack_ptr_size_); @@ -132,7 +132,7 @@ std::string MatMulFP16BaseCoder::InitMatrixB(NNaclFp32Serializer *const code, NN return allocator_->GetRuntimeAddr(filter_tensor_, filter_tensor_->IsConst()); } std::string input_b_str = allocator_->GetRuntimeAddr(filter_tensor_); - std::string input_b_pack_str = "(float16_t *)" + allocator_->GetRuntimeAddr(b_pack_ptr_); + std::string input_b_pack_str = allocator_->GetRuntimeAddr(static_cast(b_pack_ptr_)); if (params_->b_const_) { init_code->CodeBufferOffsetExpression(b_pack_ptr_, context->weight_name(), context->weight_offset_name(), context->weight_size_name(), b_pack_ptr_size_); @@ -248,7 +248,7 @@ int MatMulFP16BaseCoder::DoCode(CoderContext *const context) { init_code.CodeBufferOffsetExpression(bias_ptr_, context->weight_name(), context->weight_offset_name(), context->weight_size_name(), bias_pack_ptr_size_); w_buf_size += bias_pack_ptr_size_; - std::string bias_str = "(float16_t *)" + allocator_->GetRuntimeAddr(bias_ptr_); + std::string bias_str = allocator_->GetRuntimeAddr(bias_ptr_); if (input_tensors_.size() == DIMENSION_3D) { auto origin_bias_str = allocator_->GetRuntimeAddr(bias_tensor_); init_code.CodeFunction("memcpy", bias_str, origin_bias_str, bias_tensor_->Size()); diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h index 864f54ae..38270456 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.h @@ -27,7 +27,9 @@ class MatMulFP16BaseCoder : public MatMulFP32BaseCoder { public: MatMulFP16BaseCoder(const std::vector &in_tensors, const std::vector &out_tensors, const LiteGraph::Node *node, size_t node_index, Target target) - : MatMulFP32BaseCoder(in_tensors, out_tensors, node, node_index, target) {} + : MatMulFP32BaseCoder(in_tensors, out_tensors, node, node_index, target) { + data_type_ = kNumberTypeFloat16; + } ~MatMulFP16BaseCoder() override = default; diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h index 3a1cb66a..c5ea36cd 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_coder.h @@ -26,9 +26,7 @@ class MatMulFP16Coder final : public MatMulFP16BaseCoder { public: MatMulFP16Coder(const std::vector &in_tensors, const std::vector &out_tensors, const LiteGraph::Node *node, size_t node_index, Target target) - : MatMulFP16BaseCoder(in_tensors, out_tensors, node, node_index, target) { - data_type_size_ = sizeof(uint16_t); - } + : MatMulFP16BaseCoder(in_tensors, out_tensors, node, node_index, target) {} ~MatMulFP16Coder() override = default; diff --git 
a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc index f46005c6..e53472ca 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.cc @@ -20,42 +20,88 @@ #include "coder/opcoders/parallel.h" #include "coder/opcoders/file_collector.h" #include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" +#include "src/common/utils.h" namespace mindspore::lite::micro::nnacl { int ConvolutionDepthwiseFP32Coder::Prepare(CoderContext *const context) { MS_CHECK_RET_CODE(Conv2DBaseCoder::Init(), "Conv2DBaseCoder::Init() failed!"); - MS_CHECK_RET_CODE(InitWeightBias(), "dwconvolution do init weightbais failed"); + MS_CHECK_RET_CODE(InitParameter(), "dwconvolution do InitParamter failed"); + if (Configurator::GetInstance()->keep_original_weight()) { + MS_CHECK_RET_CODE(InitWeightBiasOnline(), "dwconvolution do InitWeightBiasOnline failed"); + } else { + MS_CHECK_RET_CODE(InitWeightBiasOffline(), "dwconvolution do InitWeightBiasOffline failed"); + } conv_param_->thread_num_ = MSMIN(thread_num_, conv_param_->output_h_); return RET_OK; } -int ConvolutionDepthwiseFP32Coder::InitWeightBias() { +int ConvolutionDepthwiseFP32Coder::InitParameter() { + auto shape = filter_tensor_->shape(); + MS_CHECK_TRUE_MSG(shape.size() == C4NUM, RET_ERROR, "Conv: filter-weight's shape must be 4D."); + packed_weight_size_ = + filter_tensor_->Batch() * filter_tensor_->Height() * filter_tensor_->Width() * DataTypeSize(data_type_); + packed_bias_size_ = filter_tensor_->Batch() * DataTypeSize(data_type_); + return RET_OK; +} + +int ConvolutionDepthwiseFP32Coder::InitWeightBiasOffline() { auto *origin_weight = reinterpret_cast(filter_tensor_->data()); MS_CHECK_PTR(origin_weight); int channel = filter_tensor_->Batch(); - size_t pack_weight_size = filter_tensor_->Batch() * filter_tensor_->Height() * filter_tensor_->Width(); - size_t packed_weight_data_size = pack_weight_size * sizeof(float); - packed_weight_ = - reinterpret_cast(allocator_->Malloc(kNumberTypeFloat32, packed_weight_data_size, kOfflinePackWeight)); + packed_weight_ = reinterpret_cast(allocator_->Malloc(data_type_, packed_weight_size_, kOfflinePackWeight)); MS_CHECK_PTR(packed_weight_); - MS_CHECK_RET_CODE(memset_s(packed_weight_, packed_weight_data_size, 0, packed_weight_data_size), + MS_CHECK_RET_CODE(memset_s(packed_weight_, packed_weight_size_, 0, packed_weight_size_), "memset packed weight failed!"); PackNCHWToNHWCFp32(origin_weight, packed_weight_, 1, filter_tensor_->Height() * filter_tensor_->Width(), channel, kDefaultTaskId, 0); - auto bias_size = static_cast(channel * sizeof(float)); - bias_ = reinterpret_cast(allocator_->Malloc(kNumberTypeFloat32, bias_size, kOfflinePackWeight)); + bias_ = reinterpret_cast(allocator_->Malloc(data_type_, packed_bias_size_, kOfflinePackWeight)); MS_CHECK_PTR(bias_); - MS_CHECK_RET_CODE(memset_s(bias_, bias_size, 0, bias_size), "memset bias failed!"); + MS_CHECK_RET_CODE(memset_s(bias_, packed_bias_size_, 0, packed_bias_size_), "memset bias failed!"); // init bias if (input_tensors_.size() == kInputSize2) { auto *ori_bias = reinterpret_cast(bias_tensor_->data()); MS_CHECK_TRUE(bias_tensor_->ElementsNum() > 0, "invalid bias length"); - MS_CHECK_RET_CODE(memcpy_s(bias_, bias_size, ori_bias, bias_tensor_->Size()), 
"memcpy_s bias failed!"); + MS_CHECK_RET_CODE(memcpy_s(bias_, packed_bias_size_, ori_bias, bias_tensor_->Size()), "memcpy_s bias failed!"); } return RET_OK; } +int ConvolutionDepthwiseFP32Coder::InitWeightBiasOnline() { + packed_weight_ = reinterpret_cast(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight)); + MS_CHECK_PTR(packed_weight_); + bias_ = reinterpret_cast(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight)); + MS_CHECK_PTR(bias_); + return RET_OK; +} + +void ConvolutionDepthwiseFP32Coder::InitCodeOnline(CoderContext *const context) { + if (!Configurator::GetInstance()->keep_original_weight()) { + return; + } + Collect(context, + { + "nnacl/fp32/pack_fp32.h", + }, + {"pack_fp32.c"}); + NNaclFp32Serializer init_code; + init_code.CodeBufferOffsetExpression(packed_weight_, context->weight_name(), context->weight_offset_name(), + context->weight_size_name(), packed_weight_size_); + auto filter_str = allocator_->GetRuntimeAddr(filter_tensor_); + init_code.CodeFunction("PackNCHWToNHWCFp32", filter_str, packed_weight_, 1, + filter_tensor_->Height() * filter_tensor_->Width(), filter_tensor_->Batch(), 0, 0); + init_code.CodeBufferOffsetExpression(bias_, context->weight_name(), context->weight_offset_name(), + context->weight_size_name(), packed_bias_size_); + if (input_tensors_.size() == kInputSize2) { + auto bias_str = allocator_->GetRuntimeAddr(bias_tensor_); + init_code.CodeFunction("memcpy", bias_, bias_str, bias_tensor_->Size()); + } else { + init_code.CodeFunction("memcpy", bias_, 0, packed_bias_size_); + } + context->AppendInitWeightSizeCode(packed_weight_size_ + packed_bias_size_); + context->AppendInitCode(init_code.str()); +} + int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) { MS_CHECK_TRUE(conv_param_->input_channel_ == conv_param_->output_channel_, "Only support input channel equals output channel."); @@ -78,6 +124,7 @@ int ConvolutionDepthwiseFP32Coder::DoCode(CoderContext *const context) { "activation_fp32.c", }, {}); + InitCodeOnline(context); nnacl::NNaclFp32Serializer code; // call the op function std::string param_name = "conv_parameter"; diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h index a5827f4f..39757871 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_depthwise_fp32_coder.h @@ -34,9 +34,15 @@ class ConvolutionDepthwiseFP32Coder final : public Conv2DBaseCoder { int DoCode(CoderContext *const context) override; private: - int InitWeightBias(); + int InitParameter(); + int InitWeightBiasOffline(); + int InitWeightBiasOnline(); + void InitCodeOnline(CoderContext *const context); + size_t packed_weight_size_{0}; float *packed_weight_{nullptr}; + size_t packed_bias_size_{0}; float *bias_{nullptr}; + TypeId data_type_{kNumberTypeFloat32}; }; } // namespace mindspore::lite::micro::nnacl #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CONVOLUTION_DEPTHWISE_FP32_CODER_H_ diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc index 556f851a..466db21a 100644 --- 
a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc @@ -77,13 +77,6 @@ const std::array OutputTransFuncRelu6List8 = {"", "OutputTransform8x6Relu6Unit", "OutputTransform8x7Relu6Unit"}; -int ConvolutionWinogradFP32Coder::WinogradFilterTransform(const float *weight_data, float *matrix_g, - const float *matrix_gt, int oc_block) { - MS_CHECK_TRUE(oc_block, "Divide by zero!"); - return WinogradWeightTransform(weight_data, trans_weight_, matrix_g, matrix_gt, oc_block, input_unit_, kernel_unit_, - conv_param_->input_channel_, conv_param_->output_channel_, true); -} - int ConvolutionWinogradFP32Coder::InitTmpBuffer() { int channel_out = conv_param_->output_channel_; int oc8 = UP_DIV(channel_out, C8NUM); @@ -115,12 +108,16 @@ int ConvolutionWinogradFP32Coder::Prepare(CoderContext *const context) { input_unit_ = output_unit_ + kernel_unit_ - 1; conv_param_->input_unit_ = input_unit_; conv_param_->output_unit_ = output_unit_; - ret = InitWeightBias(); - MS_CHECK_RET_CODE(ret, "Init weight bias failed."); + MS_CHECK_RET_CODE(InitParameter(), "Winograd convolution do InitParameter failed"); + if (Configurator::GetInstance()->keep_original_weight()) { + MS_CHECK_RET_CODE(InitWeightBiasOnline(), "Winograd convolution do InitWeightBiasOnline failed"); + } else { + MS_CHECK_RET_CODE(InitWeightBiasOffline(), "Winograd convolution do InitWeightBiasOffline failed"); + } return ReSize(); } // namespace micro -int ConvolutionWinogradFP32Coder::InitWeightBias() { +int ConvolutionWinogradFP32Coder::InitParameter() { int in_channel = filter_tensor_->Channel(); int out_channel = filter_tensor_->Batch(); MS_CHECK_TRUE(in_channel > 0, "invalid in channel size"); @@ -132,14 +129,10 @@ int ConvolutionWinogradFP32Coder::InitWeightBias() { const int oc_block = C8NUM; int oc_block_num = UP_DIV(out_channel, C8NUM); // init weight - int trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block; - trans_weight_ = reinterpret_cast( - allocator_->Malloc(kNumberTypeFloat32, trans_matrix_data_size * sizeof(float), kOfflinePackWeight)); - MS_CHECK_PTR(trans_weight_); - int ret = memset_s(trans_weight_, trans_matrix_data_size * sizeof(float), 0, trans_matrix_data_size * sizeof(float)); - MS_CHECK_RET_CODE(ret, "memset_s failed!"); - float matrix_g[k64]; - float matrix_gt[k64]; + trans_weight_size_ = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * DataTypeSize(data_type_); + packed_bias_size_ = oc4 * C4NUM * DataTypeSize(data_type_); + matrix_g_.resize(k64); + matrix_gt_.resize(k64); float matrix_a[k64]; float matrix_at[k64]; float matrix_b[k64]; @@ -148,31 +141,41 @@ int ConvolutionWinogradFP32Coder::InitWeightBias() { if (input_unit_ == DIMENSION_8D) { coef = 0.5f; } - ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, coef, output_unit_, kernel_unit_); + auto ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g_.data(), matrix_gt_.data(), coef, + output_unit_, kernel_unit_); MS_CHECK_RET_CODE(ret, "CookToomFilter failed!"); - auto out_channel_size = static_cast(out_channel); + return RET_OK; +} + +int ConvolutionWinogradFP32Coder::InitWeightBiasOffline() { + trans_weight_ = reinterpret_cast(allocator_->Malloc(data_type_, trans_weight_size_, kOfflinePackWeight)); + MS_CHECK_PTR(trans_weight_); + int ret = memset_s(trans_weight_, trans_weight_size_, 0, trans_weight_size_); auto 
weight_data = reinterpret_cast(filter_tensor_->data()); MS_CHECK_PTR(weight_data); - ret = WinogradFilterTransform(weight_data, matrix_g, matrix_gt, oc_block); + WinogradWeightTransform(weight_data, trans_weight_, matrix_g_.data(), matrix_gt_.data(), C8NUM, input_unit_, + kernel_unit_, conv_param_->input_channel_, conv_param_->output_channel_, true); MS_CHECK_RET_CODE(ret, "winograd filter transform failed!"); - // init bias - int new_bias_ele_num = oc4 * C4NUM; - auto new_bias_ele_size = static_cast(new_bias_ele_num * sizeof(float)); - new_bias_ = reinterpret_cast(allocator_->Malloc(kNumberTypeFloat32, new_bias_ele_size, kOfflinePackWeight)); + new_bias_ = reinterpret_cast(allocator_->Malloc(data_type_, packed_bias_size_, kOfflinePackWeight)); MS_CHECK_PTR(new_bias_); - ret = memset_s(new_bias_, new_bias_ele_size, 0, new_bias_ele_size); + ret = memset_s(new_bias_, packed_bias_size_, 0, packed_bias_size_); MS_CHECK_RET_CODE(ret, "memset_s failed!"); if (input_tensors_.size() == kInputSize2) { auto ori_bias_addr = reinterpret_cast(bias_tensor_->data()); MS_CHECK_PTR(ori_bias_addr); - MS_CHECK_RET_CODE(memcpy_s(new_bias_, new_bias_ele_size, ori_bias_addr, out_channel_size * sizeof(float)), - "memcpy_s failed!"); - } else { - MS_CHECK_RET_CODE(memset_s(new_bias_, new_bias_ele_size, 0, new_bias_ele_size), "memset_s failed!"); + MS_CHECK_RET_CODE(memcpy_s(new_bias_, packed_bias_size_, ori_bias_addr, bias_tensor_->Size()), "memcpy_s failed!"); } return RET_OK; } +int ConvolutionWinogradFP32Coder::InitWeightBiasOnline() { + trans_weight_ = reinterpret_cast(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight)); + MS_CHECK_PTR(trans_weight_); + new_bias_ = reinterpret_cast(allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight)); + MS_CHECK_PTR(new_bias_); + return RET_OK; +} + int ConvolutionWinogradFP32Coder::ConfigInputOutput() { trans_func_str_.in_func_ = GetInputTransFunc(input_unit_); MS_CHECK_TRUE(!trans_func_str_.in_func_.empty(), "Get input_trans_func failed."); @@ -217,6 +220,36 @@ std::string ConvolutionWinogradFP32Coder::GetOutputTransFunc(int input_unit, int } } +void ConvolutionWinogradFP32Coder::InitCodeOnline(CoderContext *const context) { + if (!Configurator::GetInstance()->keep_original_weight()) { + return; + } + Collect(context, + { + "nnacl/base/minimal_filtering_generator.h", + "nnacl/fp32/pack_fp32.h", + }, + {"minimal_filtering_generator.c", "nnacl/fp32/pack_fp32.h"}); + NNaclFp32Serializer init_code; + init_code.CodeBufferOffsetExpression(trans_weight_, context->weight_name(), context->weight_offset_name(), + context->weight_size_name(), trans_weight_size_); + auto filter_str = allocator_->GetRuntimeAddr(filter_tensor_); + init_code.CodeArray("matrix_g", matrix_g_.data(), k64); + init_code.CodeArray("matrix_gt", matrix_gt_.data(), k64); + init_code.CodeFunction("WinogradWeightTransform", filter_str, trans_weight_, "matrix_g", "matrix_gt", C8NUM, + input_unit_, kernel_unit_, conv_param_->input_channel_, conv_param_->output_channel_, true); + init_code.CodeBufferOffsetExpression(new_bias_, context->weight_name(), context->weight_offset_name(), + context->weight_size_name(), packed_bias_size_); + if (input_tensors_.size() == kInputSize2) { + auto bias_str = allocator_->GetRuntimeAddr(bias_tensor_); + init_code.CodeFunction("memcpy", new_bias_, bias_str, bias_tensor_->Size()); + } else { + init_code.CodeFunction("memcpy", new_bias_, 0, packed_bias_size_); + } + context->AppendInitWeightSizeCode(trans_weight_size_ + packed_bias_size_); + 
context->AppendInitCode(init_code.str()); +} + int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) { Collect(context, { @@ -253,6 +286,7 @@ int ConvolutionWinogradFP32Coder::DoCode(CoderContext *const context) { } else if (target_ == kARM64) { Collect(context, {}, {}, { + "BigMatmulFp32Opt.S", "MatmulFp32.S", "MatmulFp32Opt.S", "PreSum4x16Int8Peroc.S", @@ -263,14 +297,14 @@ "MatmulInt8.S", }); } - + InitCodeOnline(context); NNaclFp32Serializer code; // call the op function code.CodeFunction("memset", trans_input_, "0", tile_buffer_size_); code.CodeFunction("memset", gemm_out_, "0", gemm_out_size_); code.CodeFunction("memset", tmp_data_, "0", tmp_data_size_); code.CodeFunction("memset", col_buffer_, "0", col_buffer_size_); - code << "\t\tfloat *tmp_buffer_address_list[4] = {" << allocator_->GetRuntimeAddr(trans_input_) << ", " + code << " float *tmp_buffer_address_list[4] = {" << allocator_->GetRuntimeAddr(trans_input_) << ", " << allocator_->GetRuntimeAddr(gemm_out_) << ", " << allocator_->GetRuntimeAddr(tmp_data_) << ", " << allocator_->GetRuntimeAddr(col_buffer_) << "};\n"; code.CodeStruct("conv_parameter", *conv_param_); diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h index d583312a..a4a0438f 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h @@ -38,7 +38,13 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder { ~ConvolutionWinogradFP32Coder() override = default; private: - int InitWeightBias(); + int InitParameter(); + + int InitWeightBiasOffline(); + + int InitWeightBiasOnline(); + + void InitCodeOnline(CoderContext *const context); int ConfigInputOutput(); @@ -46,13 +52,13 @@ class ConvolutionWinogradFP32Coder : public Conv2DBaseCoder { int ReSize(); - int WinogradFilterTransform(const float *weight_data, float *matrix_g, const float *matrix_gt, int oc_block); - std::string GetInputTransFunc(int input_unit); std::string GetOutputTransFunc(int input_unit, int output_unit, ActType act_type); + size_t trans_weight_size_{0}; float *trans_weight_{nullptr}; + size_t packed_bias_size_{0}; float *new_bias_{nullptr}; int kernel_unit_{0}; @@ -70,6 +76,9 @@ float *col_buffer_{nullptr}; TransFuncStr trans_func_str_; + TypeId data_type_{kNumberTypeFloat32}; + std::vector<float> matrix_g_; + std::vector<float> matrix_gt_; }; } // namespace mindspore::lite::micro::nnacl #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CONVOLUTION_WINOGRAD_FP32_CODER_H_
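Note: with keep_original_weight enabled, the Winograd filter transform moves from converter time into the generated net-init code. A rough sketch of the fragment InitCodeOnline() emits follows; the names g_weight, g_offset and g_filter_data and the concrete shape parameters are illustrative assumptions, not text copied from a real generated file:

    /* trans_weight lives at an offset inside the packed weight buffer (CodeBufferOffsetExpression) */
    float *trans_weight = (float *)(g_weight + g_offset);
    /* matrix_g / matrix_gt are serialized by CodeArray from the CookToomFilter result */
    float matrix_g[64] = { /* ... */ };
    float matrix_gt[64] = { /* ... */ };
    WinogradWeightTransform(g_filter_data, trans_weight, matrix_g, matrix_gt, 8 /* C8NUM */,
                            input_unit, kernel_unit, in_channel, out_channel, true);

The offline path keeps the previous behaviour: the transform still runs inside the converter and only the packed result is embedded in the weight file.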
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc new file mode 100644 index 00000000..50146e72 --- /dev/null +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc @@ -0,0 +1,214 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h" +#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" +#include "coder/opcoders/file_collector.h" +#include "nnacl/custom_gru_parameter.h" + +using mindspore::schema::PrimitiveType_Custom; + +namespace mindspore::lite::micro::nnacl { +namespace { +constexpr size_t kOutputNum = 3; +constexpr size_t kInputDims = 3; +constexpr size_t kWeightDims = 2; +constexpr size_t kInputSize = 6; +} // namespace +int CustomGruFP32Coder::Prepare(CoderContext *const context) { + if (input_tensors_.size() != kInputSize) { + MS_LOG(ERROR) << "built-in CustomGru must have 6 inputs." << node_->name_; + return RET_ERROR; + } + for (size_t i = 1; i < kInputSize - 1; ++i) { + if (!input_tensors_[i]->IsConst()) { + MS_LOG(ERROR) << "built-in CustomGru only supports variable first and last inputs; the others must be const." << node_->name_; + return RET_NOT_SUPPORT; + } + } + if (InitParameter() != RET_OK) { + MS_LOG(ERROR) << "Init built-in CustomGru Parameter failed." << node_->name_; + return RET_ERROR; + } + return ReSize(); +} + +int CustomGruFP32Coder::InitParameter() { + param_ = reinterpret_cast<CustomGruParameter *>(parameter_); + param_->op_parameter_.thread_num_ = 1; + auto weight_in_shape = input_tensors_[1]->shape(); + auto weight_hidden_shape = input_tensors_[C2NUM]->shape(); + if (weight_in_shape.size() != kWeightDims || weight_hidden_shape.size() != kWeightDims) { + MS_LOG(ERROR) << "built-in CustomGru's weight must be 2D." << node_->name_; + return RET_ERROR; + } + if (weight_in_shape[0] != weight_hidden_shape[0]) { + MS_LOG(ERROR) << "Built-in CustomGru's weight-in and weight-hidden first-dim must be the same." << node_->name_; + return RET_ERROR; + } + if (weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) { + MS_LOG(ERROR) << "Built-in CustomGru's weight-hidden first-dim must be 3 * second-dim." << node_->name_; + return RET_ERROR; + } + auto bias_in_shape = input_tensors_[C3NUM]->shape(); + auto bias_hidden_shape = input_tensors_[C4NUM]->shape(); + if (bias_in_shape.size() != 1) { + MS_LOG(ERROR) << "built-in CustomGru's bias must be 1D." << node_->name_; + return RET_ERROR; + } + if (bias_in_shape != bias_hidden_shape) { + MS_LOG(ERROR) << "built-in CustomGru's bias-in and bias-hidden must have the same shape." << node_->name_; + return RET_ERROR; + } + if (bias_in_shape.back() != weight_in_shape.front()) { + MS_LOG(ERROR) << "built-in CustomGru's bias-in shape doesn't match the first-dim of weight." << node_->name_; + return RET_ERROR; + } + if (bias_in_shape.front() % C3NUM != 0) { + MS_LOG(ERROR) << "The first-dim of CustomGru's weight must be 3 * hidden."; + return RET_ERROR; + } + param_->input_size = weight_in_shape.back(); + param_->hidden_size = bias_in_shape.front() / C3NUM; + return RET_OK; +} + 
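Read together, the checks above pin down the layout of the six inputs (writing hidden for hidden_size and input for input_size; the tensor roles are inferred from how the shapes are validated):

    input 0: x, 3-D [num_step, batch_size, input], variable (checked in ReSize below)
    input 1: weight_in, 2-D [3 * hidden, input], const
    input 2: weight_hidden, 2-D [3 * hidden, hidden], const
    input 3: bias_in, 1-D [3 * hidden], const
    input 4: bias_hidden, 1-D [3 * hidden], const
    input 5: initial hidden state, variable

input_size is read off weight_in's second dim and hidden_size is bias_in's length divided by 3, matching the three packed GRU gates.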
+int CustomGruFP32Coder::ReSize() { + auto in_shape = input_tensor_->shape(); + if (in_shape.size() != kInputDims) { + MS_LOG(ERROR) << "built-in CustomGru's first-input must be 3D." << node_->name_; + return RET_ERROR; + } + param_->num_step = in_shape[0]; + param_->batch_size = in_shape[1]; + if (in_shape.back() != param_->input_size) { + MS_LOG(ERROR) << "built-in CustomGru's first-input doesn't match its weight." << node_->name_; + return RET_ERROR; + } + return InitWeightAndBias(); +} + +int CustomGruFP32Coder::InitWeightAndBias() { + auto col_align = UP_ROUND(param_->hidden_size, col_tile_); + auto data_type_size = DataTypeSize(data_type_); + bias_pack_size_ = col_align * data_type_size; + weight_in_pack_size_ = static_cast<size_t>(col_align * param_->input_size) * data_type_size; + weight_input_ = allocator_->Malloc(data_type_, weight_in_pack_size_ * C3NUM, kOnlinePackWeight, + input_tensors_.at(1)->tensor_name() + "_online_pack"); + MS_CHECK_TRUE_MSG(weight_input_ != nullptr, RET_NULL_PTR, "Init weight-in failed."); + weight_hidden_pack_size_ = static_cast<size_t>(col_align * param_->hidden_size) * data_type_size; + weight_hidden_ = allocator_->Malloc(data_type_, weight_hidden_pack_size_ * C3NUM, kOnlinePackWeight, + input_tensors_.at(C2NUM)->tensor_name() + "_online_pack"); + MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, RET_NULL_PTR, "Init weight-hidden failed."); + bias_input_ = allocator_->Malloc(data_type_, bias_pack_size_ * C3NUM, kOnlinePackWeight, + input_tensors_.at(C3NUM)->tensor_name() + "_online_pack"); + MS_CHECK_TRUE_MSG(bias_input_ != nullptr, RET_NULL_PTR, "Init bias-in failed."); + bias_hidden_ = allocator_->Malloc(data_type_, bias_pack_size_ * C3NUM, kOnlinePackWeight, + input_tensors_.at(C4NUM)->tensor_name() + "_online_pack"); + MS_CHECK_TRUE_MSG(bias_hidden_ != nullptr, RET_NULL_PTR, "Init bias-hidden failed."); + auto row_align = UP_ROUND(param_->batch_size, row_tile_); + auto work_space = + (row_align * (param_->input_size + param_->hidden_size) + param_->batch_size * param_->hidden_size * C6NUM) * + data_type_size; + run_buffer_ = allocator_->Malloc(data_type_, work_space, kWorkspace); + MS_CHECK_TRUE_MSG(run_buffer_ != nullptr, RET_NULL_PTR, "Init run_buffer failed."); + return RET_OK; +} + +void CustomGruFP32Coder::InitNnaclFile(CoderContext *const context) { + Collect(context, {"nnacl/fp32/custom_gru_fp32.h"}, + {"custom_gru_fp32.c", "pack_fp32.c", "matmul_fp32.c", "arithmetic_fp32.c", "activation_fp32.c"}); +} + +void CustomGruFP32Coder::InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst, + int row, int col) { + init_code->CodeFunction("RowMajor2Col8Major", src, dst, row, col); +} + +void CustomGruFP32Coder::InitBiasCode(CoderContext *const context, NNaclFp32Serializer *init_code) { + auto data_type_size = DataTypeSize(data_type_); + init_code->CodeBufferOffsetExpression(bias_input_, context->weight_name(), context->weight_offset_name(), + context->weight_size_name(), bias_pack_size_ * C3NUM); + auto bias_in_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_input_); + auto bias_in_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[C3NUM]); + for (int i = 0; i < C3NUM; ++i) { + auto dst_bias_in = bias_in_str + " + " + std::to_string(i * bias_pack_size_ / data_type_size); + auto src_bias_in = bias_in_tensor + " + " + std::to_string(i * param_->hidden_size); + init_code->CodeFunction("memcpy", dst_bias_in, src_bias_in, param_->hidden_size * data_type_size); + } + init_code->CodeBufferOffsetExpression(bias_hidden_, context->weight_name(), context->weight_offset_name(), + context->weight_size_name(), bias_pack_size_ * C3NUM); + auto bias_hidden_str = 
MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_hidden_); + auto bias_hidden_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[C4NUM]); + for (int i = 0; i < C3NUM; ++i) { + auto dst_bias_hidden = bias_hidden_str + " + " + std::to_string(i * bias_pack_size_ / data_type_size); + auto src_bias_hidden = bias_hidden_tensor + " + " + std::to_string(i * param_->hidden_size); + init_code->CodeFunction("memcpy", dst_bias_hidden, src_bias_hidden, param_->hidden_size * data_type_size); + } +} + +void CustomGruFP32Coder::InitWeightCode(CoderContext *const context, NNaclFp32Serializer *init_code) { + auto data_type_size = DataTypeSize(data_type_); + init_code->CodeBufferOffsetExpression(weight_input_, context->weight_name(), context->weight_offset_name(), + context->weight_size_name(), weight_in_pack_size_ * C3NUM); + auto weight_input_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_input_); + auto weight_in_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[1]); + for (int i = 0; i < C3NUM; ++i) { + auto dst_weight_in = weight_input_str + " + " + std::to_string(i * weight_in_pack_size_ / data_type_size); + auto src_weight_in = weight_in_tensor + " + " + std::to_string(i * param_->hidden_size * param_->input_size); + InitPackMatrixB(init_code, src_weight_in, dst_weight_in, param_->hidden_size, param_->input_size); + } + + init_code->CodeBufferOffsetExpression(weight_hidden_, context->weight_name(), context->weight_offset_name(), + context->weight_size_name(), weight_hidden_pack_size_ * C3NUM); + auto weight_hidden_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_hidden_); + auto weight_hidden_tensor = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[C2NUM]); + for (int i = 0; i < C3NUM; ++i) { + auto dst_weight_hidden = weight_hidden_str + " + " + std::to_string(i * weight_hidden_pack_size_ / data_type_size); + auto src_weight_hidden = + weight_hidden_tensor + " + " + std::to_string(i * param_->hidden_size * param_->hidden_size); + InitPackMatrixB(init_code, src_weight_hidden, dst_weight_hidden, param_->hidden_size, param_->hidden_size); + } +} + +int CustomGruFP32Coder::DoCode(CoderContext *const context) { + NNaclFp32Serializer code, init_code; + code.CodeStruct("custom_gru_parm", *param_); + InitNnaclFile(context); + InitWeightCode(context, &init_code); + InitBiasCode(context, &init_code); + auto row_align = UP_ROUND(param_->batch_size, row_tile_); + auto data_type_str = GetTensorDataType(data_type_); + auto buffer_name = "( " + data_type_str + "*)" + MemoryAllocator::GetInstance()->GetRuntimeAddr(run_buffer_); + int offset1 = row_align * param_->input_size; + int offset2 = offset1 + param_->batch_size * param_->hidden_size * C3NUM; + int offset3 = offset2 + row_align * param_->hidden_size; + code << data_type_str + "*buffer[4] = {" << buffer_name << ", " << buffer_name + " + " << offset1 << ", " + << buffer_name + " + " << offset2 << ", " << buffer_name + " + " << offset3 << "};\n"; + auto weight_input_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_input_); + auto weight_hidden_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(weight_hidden_); + auto bias_input_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_input_); + auto bias_hidden_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(bias_hidden_); + code.CodeFunction(op_func_, output_tensor_, input_tensor_, weight_input_str, weight_hidden_str, bias_input_str, + bias_hidden_str, input_tensors_[C5NUM], "buffer", "&custom_gru_parm"); + 
context->AppendInitWeightSizeCode((weight_in_pack_size_ + weight_hidden_pack_size_) * C3NUM + + bias_pack_size_ * C6NUM); + context->AppendInitCode(init_code.str()); + context->AppendCode(code.str()); + return RET_OK; +} + +REG_BUILIN_CUSTOM_CODER(kARM64, kNumberTypeFloat32, "CustomGRU", CPUOpCoderCreator<CustomGruFP32Coder>) +} // namespace mindspore::lite::micro::nnacl diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h new file mode 100644 index 00000000..27db0f94 --- /dev/null +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h @@ -0,0 +1,64 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CUSTOM_GRU_FP32_CODER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CUSTOM_GRU_FP32_CODER_H_ + +#include <string> +#include <vector> +#include "coder/opcoders/op_coder.h" +#include "nnacl/custom_gru_parameter.h" +#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" + +namespace mindspore::lite::micro::nnacl { +class CustomGruFP32Coder : public OperatorCoder { + public: + CustomGruFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, + const LiteGraph::Node *node, size_t node_index, Target target) + : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} + ~CustomGruFP32Coder() override = default; + + int Prepare(CoderContext *const context) override; + + int DoCode(CoderContext *const context) override; + + protected: + virtual void InitNnaclFile(CoderContext *const context); + virtual void InitPackMatrixB(NNaclFp32Serializer *init_code, const std::string &src, const std::string &dst, int row, + int col); + TypeId data_type_{kNumberTypeFloat32}; + int row_tile_{C12NUM}; + int col_tile_{C8NUM}; + void *weight_input_{nullptr}; + void *weight_hidden_{nullptr}; + void *bias_input_{nullptr}; + void *bias_hidden_{nullptr}; + size_t weight_in_pack_size_{0}; + size_t weight_hidden_pack_size_{0}; + size_t bias_pack_size_{0}; + std::string op_func_{"CustomGru"}; + CustomGruParameter *param_{nullptr}; + + private: + int InitParameter(); + int InitWeightAndBias(); + int ReSize(); + void InitWeightCode(CoderContext *const context, NNaclFp32Serializer *init_code); + void InitBiasCode(CoderContext *const context, NNaclFp32Serializer *init_code); + void *run_buffer_{nullptr}; +}; +} // namespace mindspore::lite::micro::nnacl +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_CUSTOM_GRU_FP32_CODER_H_
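To make the workspace math concrete, assume batch_size = 1, input_size = 64 and hidden_size = 32, with row_tile_ = C12NUM (12) and fp32 data; these numbers are only an example. Then row_align = UP_ROUND(1, 12) = 12 and the run buffer allocated in InitWeightAndBias() holds 12 * (64 + 32) + 1 * 32 * 6 = 1344 floats. DoCode() splits it into the four regions of buffer[4] at float offsets 0, offset1 = 12 * 64 = 768, offset2 = 768 + 1 * 32 * 3 = 864 and offset3 = 864 + 12 * 32 = 1248; the 1344 - 1248 = 96 floats left over equal batch_size * hidden_size * 3, so the partitioning exactly covers the allocation.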
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc index 3c31479c..c6b93abf 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.cc @@ -27,11 +27,30 @@ using mindspore::schema::PrimitiveType_Gather; namespace mindspore::lite::micro::nnacl { int GatherFP32Coder::Prepare(CoderContext *const context) { return RET_OK; } +void GatherFP32Coder::InitCodeInChange(CoderContext *const context, std::string *auxiliary_variable) { + auto input0_shape_str = allocator_->GetAuxiliaryWeight(input_tensor_); + if (input0_shape_str.empty()) { + return; + } + *auxiliary_variable = input0_shape_str; + NNaclFp32Serializer init_code; + auto in_shape = input_tensor_->shape(); + int in_rank = static_cast<int>(in_shape.size()); + init_code.CodeArray("shape", in_shape.data(), in_rank); + init_code << " for (int i = 0; i < " << in_rank << "; ++i) {\n"; + init_code << " if (i != " << axis_ << " && " << input0_shape_str << "[i] != shape[i]) {\n"; + init_code << " return RET_ERROR;\n"; + init_code << " }\n"; + init_code << " }\n"; + context->AppendInitCode(init_code.str()); +} + int GatherFP32Coder::DoCode(CoderContext *context) { Tensor *input0 = input_tensors_.at(0); Tensor *input1 = input_tensors_.at(1); MS_CHECK_PTR(input0); MS_CHECK_PTR(input1); + MS_CHECK_PTR(parameter_); MS_CHECK_TRUE_MSG(input1->data_type() == kNumberTypeInt32 || input1->data_type() == kNumberTypeInt, RET_ERROR, "index's data-type is not int32"); // generate code .h .c @@ -44,18 +63,16 @@ int GatherFP32Coder::DoCode(CoderContext *context) { }); NNaclFp32Serializer code; - std::vector<int> in_shape = input0->shape(); + auto in_shape = input0->shape(); int in_rank = static_cast<int>(in_shape.size()); - MS_CHECK_PTR(parameter_); - int axis = (reinterpret_cast<GatherParameter *>(parameter_))->axis_; - MS_CHECK_TRUE(static_cast<int>(in_shape.size()) >= axis, "invalid axis in gather parameter"); - const int limit = in_shape.at(axis); - - int outer_size = 1, inner_size = 1; - for (int i = 0; i < axis; ++i) { + axis_ = *(reinterpret_cast<int *>(input_tensors_.at(THIRD_INPUT)->data())); + MS_CHECK_TRUE(static_cast<int>(in_shape.size()) >= axis_, "invalid axis in gather parameter"); + int outer_size = 1; + for (int i = 0; i < axis_; ++i) { outer_size *= in_shape.at(i); } - for (int i = axis + 1; i < in_rank; ++i) { + int inner_size = 1; + for (int i = axis_ + 1; i < in_rank; ++i) { inner_size *= in_shape.at(i); } auto data_size = static_cast<int>(lite::DataTypeSize(input0->data_type())); @@ -67,22 +84,22 @@ int start = stride * kDefaultTaskId; int count = MSMIN(stride, outer_size - stride * kDefaultTaskId); std::string input0_data = MemoryAllocator::GetInstance()->GetRuntimeAddr(input0, true); - if (input0_data.empty()) { - MS_LOG(ERROR) << "pointer is not allocated by the allocator"; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(!input0_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); std::string input1_data = MemoryAllocator::GetInstance()->GetRuntimeAddr(input1, true); - if (input1_data.empty()) { - MS_LOG(ERROR) << "pointer is not allocated by the allocator"; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(!input1_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); std::string output_data = MemoryAllocator::GetInstance()->GetRuntimeAddr(output_tensor_, true); - if (output_data.empty()) { - MS_LOG(ERROR) << "pointer is not allocated by the allocator"; - return RET_ERROR; + MS_CHECK_TRUE_MSG(!output_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); + + std::string limit = std::to_string(in_shape[axis_]); + std::string in_offset = 
std::to_string(start * in_shape[axis_] * byte_inner_size); + std::string auxiliary_variable; + InitCodeInChange(context, &auxiliary_variable); + if (!auxiliary_variable.empty()) { + limit = auxiliary_variable + "[" + std::to_string(axis_) + "]"; + in_offset = std::to_string(start) + " * " + limit + " * " + std::to_string(byte_inner_size); } code << "\t\tconst int8_t *int8_in = (const int8_t *)" << input0_data << ";\n"; - code << "\t\tint8_in += " << std::to_string(start * limit * byte_inner_size) << ";\n"; + code << "\t\tint8_in += " << in_offset << ";\n"; code << "\t\tconst int *index_data = (const int *)" << input1_data << ";\n"; code << "\t\tint8_t *int8_out = (int8_t *)" << output_data << ";\n"; code << "\t\tint8_out += " << std::to_string(start * byte_out_stride) << ";\n"; diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h index a14d9c3c..a175d694 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_fp32_coder.h @@ -35,7 +35,9 @@ class GatherFP32Coder final : public OperatorCoder { int DoCode(CoderContext *const context) override; private: + void InitCodeInChange(CoderContext *const context, std::string *auxiliary_variable); int32_t *indices_{nullptr}; + int axis_{0}; }; } // namespace mindspore::lite::micro::nnacl #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GATHER_FP32_CODER_H_ diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc index 790a142e..6115edb5 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc @@ -50,13 +50,13 @@ int MatMulFP32BaseCoder::ReSize() { int MatMulFP32BaseCoder::InitBiasData() { if (input_tensors_.size() == DIMENSION_3D) { int max_bias_data = params_->col_align_; - bias_pack_ptr_size_ = static_cast<size_t>(max_bias_data * sizeof(float)); + bias_pack_ptr_size_ = static_cast<size_t>(max_bias_data * DataTypeSize(data_type_)); if (bias_tensor_->ElementsNum() == 1) { is_bias_broadcast_ = true; } - ori_bias_pack_ptr_size_ = bias_tensor_->ElementsNum() * sizeof(float); - bias_ptr_ = allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight, - bias_tensor_->tensor_name() + "_online_pack"); + ori_bias_pack_ptr_size_ = bias_tensor_->ElementsNum() * DataTypeSize(data_type_); + bias_ptr_ = + allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack"); MS_CHECK_PTR(bias_ptr_); } return RET_OK; } @@ -83,18 +83,19 @@ int MatMulFP32BaseCoder::InitBufferA() { if (a_pack_ptr_ != nullptr) { return RET_OK; } - a_pack_ptr_size_ = static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * sizeof(float)); + a_pack_ptr_size_ = + static_cast<size_t>(params_->batch * params_->row_align_ * params_->deep_ * DataTypeSize(data_type_)); if (params_->a_const_) { - a_pack_ptr_ = reinterpret_cast<float *>(allocator_->GetSharedWeightAddr(input_tensors_.at(0))); + a_pack_ptr_ = allocator_->GetSharedWeightAddr(input_tensors_.at(0)); if (a_pack_ptr_ == nullptr) { - a_pack_ptr_ = reinterpret_cast<float *>(allocator_->Malloc(kNumberTypeFloat32, kOnlineSize, kOnlinePackWeight, - input_tensors_.at(0)->tensor_name() + 
"_online_pack")); + a_pack_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, + input_tensors_.at(0)->tensor_name() + "_online_pack"); allocator_->MarkSharedWeight(input_tensors_.at(0), a_pack_ptr_); } else { a_packed_ = true; } } else { - a_pack_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeFloat32, a_pack_ptr_size_, kWorkspace)); + a_pack_ptr_ = allocator_->Malloc(data_type_, a_pack_ptr_size_, kWorkspace); } MS_CHECK_PTR(a_pack_ptr_); return RET_OK; @@ -104,18 +105,19 @@ int MatMulFP32BaseCoder::InitBufferB() { if (b_pack_ptr_ != nullptr) { return RET_OK; } - b_pack_ptr_size_ = static_cast(params_->batch * params_->col_align_ * params_->deep_ * data_type_size_); + b_pack_ptr_size_ = + static_cast(params_->batch * params_->col_align_ * params_->deep_ * DataTypeSize(data_type_)); if (params_->b_const_) { - b_pack_ptr_ = reinterpret_cast(allocator_->GetSharedWeightAddr(input_tensors_.at(1))); + b_pack_ptr_ = allocator_->GetSharedWeightAddr(input_tensors_.at(1)); if (b_pack_ptr_ == nullptr) { - b_pack_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeUInt8, b_pack_ptr_size_, kOnlinePackWeight, - input_tensors_.at(1)->tensor_name() + "_online_pack")); + b_pack_ptr_ = allocator_->Malloc(data_type_, b_pack_ptr_size_, kOnlinePackWeight, + input_tensors_.at(1)->tensor_name() + "_online_pack"); allocator_->MarkSharedWeight(input_tensors_.at(1), b_pack_ptr_); } else { b_packed_ = true; } } else { - b_pack_ptr_ = reinterpret_cast(allocator_->Malloc(kNumberTypeUInt8, b_pack_ptr_size_, kWorkspace)); + b_pack_ptr_ = allocator_->Malloc(data_type_, b_pack_ptr_size_, kWorkspace); } MS_CHECK_PTR(b_pack_ptr_); return RET_OK; @@ -194,7 +196,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) { NNaclFp32Serializer code, init_code; size_t w_buf_size = 0; std::string param_name = "mat_mul_parameter"; - std::string bias_ptr_str = "((float *)(" + allocator_->GetRuntimeAddr(bias_ptr_) + "))"; + code.CodeStruct(param_name, *params_); if (support_parallel_) { code << " " << param_name << ".op_parameter_.thread_num_ = 1;\n"; @@ -207,6 +209,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) { int max_bias_data = params_->col_align_; if (is_bias_broadcast_) { float broad_cast_data = (reinterpret_cast(bias_tensor_->data()))[0]; + std::string bias_ptr_str = allocator_->GetRuntimeAddr(bias_ptr_); init_code << "\t for (int i = 0; i < " << max_bias_data << "; ++i) {\n"; init_code << "\t\t " << bias_ptr_str << "[i] = " << broad_cast_data << ";\n"; init_code << " }\n"; @@ -219,8 +222,8 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) { std::string a_str = allocator_->GetRuntimeAddr(input_tensor_); std::string b_str = allocator_->GetRuntimeAddr(filter_tensor_); std::string c_str = allocator_->GetRuntimeAddr(output_tensor_); - std::string a_pack_str = allocator_->GetRuntimeAddr(a_pack_ptr_); - std::string b_pack_str = allocator_->GetRuntimeAddr(b_pack_ptr_); + std::string a_pack_str = allocator_->GetRuntimeAddr(static_cast(a_pack_ptr_)); + std::string b_pack_str = allocator_->GetRuntimeAddr(static_cast(b_pack_ptr_)); // do const value packing to init if ((params_->a_const_ && !a_packed_) || (params_->b_const_ && !b_packed_)) { init_code.CodeStruct("mat_mul_parameter", *params_); @@ -271,7 +274,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) { code << " const float *batch_b_ptr = " << b_pack_str << " + i * " << params_->deep_ * params_->col_ << ";\n"; code << " float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ 
<< ";\n "; - code.CodeFunction("MatVecMulFp32", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_str, params_->act_type_, + code.CodeFunction("MatVecMulFp32", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_, params_->act_type_, params_->deep_, cur_oc); } else { code << " const float *batch_a_ptr = " << a_pack_str << " + i * " << params_->row_align_ * params_->deep_ @@ -280,7 +283,7 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) { << ";\n"; code << " float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n "; - code.CodeFunction("MatMulOpt", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_str, params_->act_type_, + code.CodeFunction("MatMulOpt", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_, params_->act_type_, params_->deep_, params_->row_, cur_oc, params_->col_, "OutType_Nhwc"); } code << " }\n"; diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h index 68b2658a..a5ef9277 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.h @@ -69,7 +69,7 @@ class MatMulFP32BaseCoder : public OperatorCoder { size_t a_pack_ptr_size_{0}; size_t b_pack_ptr_size_{0}; bool is_bias_broadcast_{false}; - size_t data_type_size_{C4NUM}; + TypeId data_type_{kNumberTypeFloat32}; }; } // namespace mindspore::lite::micro::nnacl #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_MATMUL_FP32_BASE_CODER_H_ diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc index a107c3cf..45b2e37f 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc @@ -27,6 +27,14 @@ std::unique_ptr OpCoderBuilder::build(int schema_version) { MS_CHECK_PTR_RET_NULL(node_->primitive_); int primitive_type = GetPrimitiveType(node_->primitive_, schema_version); CoderKey coder_key(target_, data_type_, primitive_type); + if (builtin_custom_) { + auto custom_type = reinterpret_cast(node_->primitive_)->value_as_Custom()->type(); + if (custom_type == nullptr || custom_type->str().empty()) { + MS_LOG(ERROR) << "Builtin custom-op has no type."; + return nullptr; + } + coder_key = CoderKey(target_, data_type_, schema::PrimitiveType_Custom, custom_type->str()); + } CoderCreatorFunc creator_func = OpCoderFactory::GetInstance()->FindOpCoder(coder_key); if (creator_func == nullptr) { MS_LOG(ERROR) << "caught unsupported layer: " << node_->name_; @@ -112,5 +120,10 @@ OpCoderBuilder &OpCoderBuilder::support_parallel(bool parallel) { return *this; } +OpCoderBuilder &OpCoderBuilder::is_builtin_custom(bool builtin_custom) { + builtin_custom_ = builtin_custom; + return *this; +} + void OpCoderBuilder::Reset() {} } // namespace mindspore::lite::micro diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h index adce6c73..d85f1c32 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h @@ -46,6 +46,8 @@ class OpCoderBuilder { OpCoderBuilder &support_parallel(bool parallel); + OpCoderBuilder 
&is_builtin_custom(bool builtin_custom); + void Reset(); private: @@ -70,6 +72,8 @@ std::vector<uint32_t> output_indices_; bool support_parallel_{false}; + + bool builtin_custom_{false}; }; } // namespace mindspore::lite::micro #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_BUILDER_H_ diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc index 031df2e7..cf26d51d 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc @@ -37,9 +37,9 @@ OpCoderFactory *OpCoderFactory::GetInstance() { } int OpCoderFactory::RegistOpCoder(Target target, TypeId data_type, schema::PrimitiveType operator_type, - const CoderCreatorFunc &creator_func) { + const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func) { // check key - CoderKey key(target, data_type, operator_type); + CoderKey key(target, data_type, operator_type, builtin_custom_type); // insert pair to registry if (this->opcoder_sets_.find(key) != this->opcoder_sets_.end()) { MS_LOG(ERROR) << "coder already exist: " << key.ToString(); @@ -63,7 +63,7 @@ CoderCreatorFunc OpCoderFactory::FindOpCoder(const CoderKey &key) { } OpCoderRegister::OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, - const CoderCreatorFunc &creatorFunc) { - OpCoderFactory::GetInstance()->RegistOpCoder(target, data_type, operator_type, creatorFunc); + const std::string &builtin_custom_type, const CoderCreatorFunc &creatorFunc) { + OpCoderFactory::GetInstance()->RegistOpCoder(target, data_type, operator_type, builtin_custom_type, creatorFunc); } } // namespace mindspore::lite::micro diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h index 9a1aed63..acbd3a22 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_REGISTER_H_ #include +#include <string> #include #include #include @@ -34,10 +35,14 @@ class CoderKey { public: CoderKey() = delete; - CoderKey(Target target, TypeId data_type, int op_type) : target_(target), data_type_(data_type), op_type_(op_type) {} + CoderKey(Target target, TypeId data_type, int op_type, std::string builtin_custom_type = "") + : target_(target), + data_type_(data_type), + op_type_(op_type), + builtin_custom_type_(std::move(builtin_custom_type)) {} CoderKey AllKey() const { - CoderKey key(kAllTargets, data_type_, op_type_); + CoderKey key(kAllTargets, data_type_, op_type_, builtin_custom_type_); return key; } @@ -50,6 +55,7 @@ Target target_ = kTargetUnknown; TypeId data_type_ = kTypeUnknown; int op_type_ = schema::PrimitiveType_NONE; + std::string builtin_custom_type_; }; class OpCoderFactory { @@ -59,7 +65,7 @@ static OpCoderFactory *GetInstance(); int RegistOpCoder(Target target, TypeId data_type, schema::PrimitiveType operator_type, - const CoderCreatorFunc &creator_func); + const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func); CoderCreatorFunc FindOpCoder(const CoderKey &key); @@ -75,11 +81,16 @@ OpCoderRegister() = delete; OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, - const CoderCreatorFunc &creator_func); + const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func); ~OpCoderRegister() = default; }; -#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func) \ - static OpCoderRegister g_##target##data_type##operator_type##Creator(target, data_type, operator_type, creator_func); +#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func) \ + static OpCoderRegister g_##target##data_type##operator_type##Creator(target, data_type, operator_type, "", \ + creator_func); + +#define REG_BUILIN_CUSTOM_CODER(target, data_type, custom_type, creator_func) \ + static OpCoderRegister g_##target##data_type##operator_type##Creator( \ + target, data_type, schema::PrimitiveType_Custom, custom_type, creator_func); } // namespace mindspore::lite::micro #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_REGISTER_H_
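With the widened key, a built-in custom op registers through the same factory as ordinary ops. A hypothetical pair of registrations for contrast (the GRU line mirrors the one in custom_gru_fp32_coder.cc above; the MatMul line is illustrative):

    REG_OPERATOR_CODER(kARM64, kNumberTypeFloat32, PrimitiveType_MatMulFusion, CPUOpCoderCreator<MatMulFP32BaseCoder>)
    REG_BUILIN_CUSTOM_CODER(kARM64, kNumberTypeFloat32, "CustomGRU", CPUOpCoderCreator<CustomGruFP32Coder>)

The first expands to a CoderKey with an empty builtin_custom_type_, the second to one additionally keyed on the "CustomGRU" string, and AllKey() now carries that string along, so a kAllTargets fallback lookup still distinguishes the two. Note that REG_BUILIN_CUSTOM_CODER pastes the literal token operator_type into the generated symbol name, so each translation unit can hold at most one such registration per target and data type.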
diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc index c333b621..cde08fd8 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc @@ -196,6 +196,11 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const BroadcastSha ToString(op_param.output_shape_), op_param.output_shape_size_); } +void NNaclFp32Serializer::CodeStruct(const std::string &name, const CustomGruParameter &op_param) { + CodeBaseStruct("CustomGruParameter", name, op_param.op_parameter_, op_param.num_step, op_param.batch_size, + op_param.input_size, op_param.hidden_size); +} + void NNaclFp32Serializer::CodeArrayStruct(const std::string &name, TensorC *tensorC, std::vector<Tensor *> tensor) { std::vector<std::string> tensor_names; int size = tensor.size(); diff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h index f52ced20..797a9574 100644 --- a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h +++ b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h @@ -44,6 +44,7 @@ #include "nnacl/layer_norm_parameter.h" #include "nnacl/broadcast_to_parameter.h" #include "nnacl/split_parameter.h" +#include "nnacl/custom_gru_parameter.h" namespace mindspore::lite::micro::nnacl { class NNaclFp32Serializer : public Serializer { @@ -74,6 +75,7 @@ void CodeStruct(const std::string &name, const SplitParameter &split_parameter); void CodeStruct(const std::string &name, const LayerNormParameter &param); void CodeStruct(const std::string &name, const BroadcastShapeInfo &param); + void CodeStruct(const std::string &name, const CustomGruParameter &param); void CodeArrayStruct(const std::string &name, TensorC *tensorC, std::vector<Tensor *> tensor); private: diff --git a/mindspore/lite/tools/converter/micro/coder/session.cc b/mindspore/lite/tools/converter/micro/coder/session.cc index 471f1491..756b7222 100644 --- a/mindspore/lite/tools/converter/micro/coder/session.cc +++ b/mindspore/lite/tools/converter/micro/coder/session.cc @@ -40,11 +40,38 @@ #include 
"coder/opcoders/nnacl/dequant/de_quant.h" namespace mindspore::lite::micro { +namespace { +bool IsBuiltInCustomNode(const void *primitive, int schema_version) { + if (!IsCustomNode(primitive, schema_version)) { + return false; + } + const auto &custom = reinterpret_cast(primitive)->value_as_Custom(); + if (custom == nullptr) { + return false; + } + const auto &attrs = custom->attr(); + if (attrs == nullptr) { + return false; + } + for (size_t i = 0; i < attrs->size(); ++i) { + if (attrs->Get(i) == nullptr || attrs->Get(i)->name() == nullptr) { + continue; + } + if (attrs->Get(i)->name()->str() == "builtin") { + return true; + } + } + return false; +} +} // namespace + CoderSession::CoderSession() { allocator_ = MemoryAllocator::GetInstance(); } -void CoderSession::EndCode() { +int CoderSession::PassArgsToContext() { context_->set_tensor_map(allocator_->tensors_map()); context_->set_saved_weights(allocator_->saved_weights()); + context_->set_origin_weights(allocator_->origin_weights()); + context_->set_auxiliary_weights(allocator_->auxiliary_weights()); size_t de_quant_max_workspace_size = nnacl::Dequant::GetInstance()->de_quant_max_workspace(); size_t final_total_size = allocator_->total_buffer_size() > de_quant_max_workspace_size ? allocator_->total_buffer_size() @@ -61,13 +88,20 @@ void CoderSession::EndCode() { if (config->code_mode() == Train) { Train::TransformGraphForTrain(context_.get(), op_coders_, schema_version_); } + if (!context_->JudgeIsValid(Configurator::GetInstance()->keep_original_weight())) { + MS_LOG(ERROR) << "Current model cannot keep-original-weight, due to existing generated tensor-data, please set " + "'keep_original_weight' to false."; + return RET_NOT_SUPPORT; + } + return RET_OK; } int CoderSession::Run() { MS_LOG(INFO) << "start run opcoders"; // 1. assign memory std::vector inputs = coder_graph_->input_tensors(); - int ret = allocator_->Assign(inputs, op_coders_); + int ret = allocator_->Assign(inputs, op_coders_, coder_graph_->all_tensors(), + Configurator::GetInstance()->changeable_weights_name()); MS_CHECK_RET_CODE(ret, "assign memory failed"); // 2. 
prepare, init model parameters for (const auto &op_coder : op_coders_) { @@ -84,10 +118,10 @@ ret = op_coder->DoCode(this->context_.get()); MS_CHECK_RET_CODE(ret, "do coder " << op_coder->name() << " failed"); } - - this->EndCode(); + ret = PassArgsToContext(); + MS_CHECK_RET_CODE(ret, "PassArgsToContext failed"); MS_LOG(INFO) << "run opcoders success"; - return RET_OK; + return ret; } int CoderSession::GenerateCode() { @@ -269,7 +303,9 @@ int CoderSession::CreateOpCoders() { } OpParameter *parameter = nullptr; - if (IsCustomNode(node->primitive_, schema_version_)) { + bool is_custom_op = IsCustomNode(node->primitive_, schema_version_); + bool is_built_in_custom_op = IsBuiltInCustomNode(node->primitive_, schema_version_); + if (is_custom_op && !is_built_in_custom_op) { KernelRegistry::GetInstance()->RegisterKernel(schema::PrimitiveType_Custom); } else { parameter = GenParameterAndInfer(node, inputs, &outputs); // built-in ops infer @@ -287,6 +323,7 @@ .mode(code_mode) .input_indices(input_indices) .output_indices(output_indices) + .is_builtin_custom(is_built_in_custom_op) .build(schema_version_); if (op_coder == nullptr) { coder_graph_->DumpUnSupportLayer(code_target); diff --git a/mindspore/lite/tools/converter/micro/coder/session.h b/mindspore/lite/tools/converter/micro/coder/session.h index 3a8f7290..20f6b2b5 100644 --- a/mindspore/lite/tools/converter/micro/coder/session.h +++ b/mindspore/lite/tools/converter/micro/coder/session.h @@ -50,7 +50,7 @@ class CoderSession { int CreateOpCoders(); int InitCodeGraph(); int CompileGraph(); - void EndCode(); + int PassArgsToContext(); std::unique_ptr<CoderGraph> coder_graph_{nullptr}; std::unique_ptr<CoderContext> context_{nullptr}; diff --git a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc index a552da05..3b868b41 100644 --- a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc +++ b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc @@ -103,6 +103,8 @@ std::string GetTensorDataType(TypeId type) { return "uint32_t "; case kNumberTypeInt64: return "int64_t "; + case kNumberTypeFloat16: + return "float16_t "; default: MS_LOG(ERROR) << "unsupported data type: " << EnumNameDataType(type); return ""; @@ -152,7 +154,6 @@ std::string EnumMicroTensorDataType(TypeId type) { case kNumberTypeUInt16: return "DataType_DT_UINT16"; case kNumberTypeFloat16: - MS_LOG(WARNING) << "unsupported data type: kNumberTypeFloat16"; return "DataType_DT_FLOAT16"; default: MS_LOG(WARNING) << "unsupported data type: " << type << ", reference: " << kNumberTypeInt; diff --git a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h index 7753e123..61c7c923 100644 --- a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h +++ b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.h @@ -27,6 +27,7 @@ #include "src/common/log_adapter.h" #include "nnacl/op_base.h" #include "tools/converter/micro/coder/config.h" +#include "base/float16.h" namespace mindspore::lite::micro { std::string EnumNameDataType(TypeId type); @@ -63,7 +64,8 @@ std::string GetVariableTypeName() { {std::type_index(typeid(int16_t *)), "int16_t *"}, {std::type_index(typeid(int8_t *)), "int8_t *"}, {std::type_index(typeid(uint8_t *)), "uint8_t *"}, - {std::type_index(typeid(float *)), "float *"}}; + {std::type_index(typeid(float *)), "float *"}, + {std::type_index(typeid(float16 *)), "float16_t 
*"}}; auto item = types_name.find(std::type_index(typeid(T))); if (item != types_name.end()) { return item->second; -- 2.17.1