From 46c42f4387b8280a25b16592aa25179093d99fa2 Mon Sep 17 00:00:00 2001
From: z00574805
Date: Wed, 24 May 2023 11:10:00 +0800
Subject: [PATCH 3/5] xiaoyi-0003
---
.jenkins/check/config/filter_cppcheck.txt | 1 +
cmake/package_micro.cmake | 4 +-
include/api/serialization.h | 26 +
.../nnacl/base/minimal_filtering_generator.c | 2 +-
.../nnacl/base/minimal_filtering_generator.h | 1 -
.../cpu/kernel/nnacl/custom_gru_parameter.h | 31 +
.../cpu/kernel/nnacl/fp16/custom_gru_fp16.c | 70 ++
.../cpu/kernel/nnacl/fp16/custom_gru_fp16.h | 32 +
.../cpu/kernel/nnacl/fp32/custom_gru_fp32.c | 71 ++
.../cpu/kernel/nnacl/fp32/custom_gru_fp32.h | 32 +
.../device/cpu/kernel/nnacl/fp32/lstm_fp32.h | 2 +-
.../cpu/kernel/nnacl/infer/custom_gru_infer.c | 45 +
.../cpu/kernel/nnacl/infer/custom_gru_infer.h | 30 +
.../plugin/device/cpu/kernel/nnacl/op_base.h | 1 +
.../lite/include/registry/converter_context.h | 3 +-
mindspore/lite/src/CMakeLists.txt | 1 +
mindspore/lite/src/common/graph_util.cc | 7 +-
.../common/ops/populate/custom_populate.cc | 14 +
.../lite/src/runtime/cxx_api/serialization.cc | 31 +
.../kernel/cpu/base/group_convolution_base.cc | 34 +-
.../cpu/base/group_convolution_creator.cc | 24 +-
.../cpu/base/group_convolution_creator.h | 8 +-
.../cpu/fp16/convolution_delegate_fp16.cc | 2 +-
.../kernel/cpu/fp16/custom_gru_fp16.cc | 132 +++
.../runtime/kernel/cpu/fp16/custom_gru_fp16.h | 40 +
.../cpu/fp32/convolution_delegate_fp32.cc | 2 +-
.../kernel/cpu/fp32/custom_gru_fp32.cc | 251 ++++++
.../runtime/kernel/cpu/fp32/custom_gru_fp32.h | 51 ++
.../cpu/int8/convolution_int8_creator.cc | 2 +-
mindspore/lite/src/runtime/lite_session.h | 6 +
mindspore/lite/src/train/graph_fusion.cc | 7 +
.../train/optimizer/fusion/gru_fusion_pass.cc | 809 ++++++++++++++++++
.../train/optimizer/fusion/gru_fusion_pass.h | 45 +
mindspore/lite/src/train/static_allocator.h | 6 +-
mindspore/lite/src/train/train_export.cc | 38 +
mindspore/lite/src/train/train_export.h | 1 +
mindspore/lite/src/train/train_session.cc | 37 +-
mindspore/lite/src/train/train_session.h | 3 +
.../test/config_level0/micro/micro_arm64.cfg | 7 +
.../config_parser/config_file_parser.cc | 13 +-
.../config_parser/config_file_parser.h | 2 +
.../config_parser/micro_param_parser.cc | 33 +
.../config_parser/micro_param_parser.h | 2 +
mindspore/lite/tools/converter/converter.cc | 18 +-
.../converter_lite/converter_flags.cc | 4 +-
.../converter/micro/cmake/file_list.cmake | 3 +
.../micro/coder/allocator/allocator.cc | 82 +-
.../micro/coder/allocator/allocator.h | 13 +-
.../lite/tools/converter/micro/coder/coder.cc | 51 +-
.../lite/tools/converter/micro/coder/coder.h | 9 +-
.../lite/tools/converter/micro/coder/config.h | 10 +
.../tools/converter/micro/coder/context.h | 25 +-
.../generator/component/common_component.cc | 5 +-
.../generator/component/weight_component.cc | 301 +++++--
.../generator/component/weight_component.h | 2 +-
.../micro/coder/generator/generator.cc | 2 +-
.../coder/opcoders/base/reshape_base_coder.cc | 5 +
.../coder/opcoders/base/stack_base_coder.cc | 85 ++
.../coder/opcoders/base/stack_base_coder.h | 42 +
.../opcoders/base/strided_slice_base_coder.cc | 21 +
.../nnacl/fp16/custom_gru_fp16_coder.cc | 34 +
.../nnacl/fp16/custom_gru_fp16_coder.h | 44 +
.../nnacl/fp16/matmul_fp16_base_coder.cc | 22 +-
.../nnacl/fp16/matmul_fp16_base_coder.h | 4 +-
.../opcoders/nnacl/fp16/matmul_fp16_coder.h | 4 +-
.../fp32/convolution_depthwise_fp32_coder.cc | 69 +-
.../fp32/convolution_depthwise_fp32_coder.h | 8 +-
.../fp32/convolution_winograd_fp32_coder.cc | 98 ++-
.../fp32/convolution_winograd_fp32_coder.h | 15 +-
.../nnacl/fp32/custom_gru_fp32_coder.cc | 214 +++++
.../nnacl/fp32/custom_gru_fp32_coder.h | 64 ++
.../opcoders/nnacl/fp32/gather_fp32_coder.cc | 59 +-
.../opcoders/nnacl/fp32/gather_fp32_coder.h | 2 +
.../nnacl/fp32/matmul_fp32_base_coder.cc | 41 +-
.../nnacl/fp32/matmul_fp32_base_coder.h | 2 +-
.../micro/coder/opcoders/op_coder_builder.cc | 13 +
.../micro/coder/opcoders/op_coder_builder.h | 4 +
.../micro/coder/opcoders/op_coder_register.cc | 8 +-
.../micro/coder/opcoders/op_coder_register.h | 23 +-
.../nnacl_serializer/nnacl_fp32_serializer.cc | 5 +
.../nnacl_serializer/nnacl_fp32_serializer.h | 2 +
.../tools/converter/micro/coder/session.cc | 49 +-
.../tools/converter/micro/coder/session.h | 2 +-
.../converter/micro/coder/utils/type_cast.cc | 3 +-
.../converter/micro/coder/utils/type_cast.h | 4 +-
85 files changed, 3163 insertions(+), 267 deletions(-)
create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h
create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c
create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h
create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c
create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h
create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c
create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h
create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc
create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h
create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc
create mode 100644 mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h
create mode 100644 mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc
create mode 100644 mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.h
create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.cc
create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/base/stack_base_coder.h
create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc
create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.h
create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.cc
create mode 100644 mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/custom_gru_fp32_coder.h
diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt
index 6aeb515a..6f6a08f5 100644
--- a/.jenkins/check/config/filter_cppcheck.txt
+++ b/.jenkins/check/config/filter_cppcheck.txt
@@ -56,6 +56,7 @@
"mindspore/mindspore/lite/tools/converter/quantizer/quantize_util.cc" "useStlAlgorithm"
"mindspore/mindspore/lite/src/runtime/kernel/opencl/kernel/" "unreadVariable"
"mindspore/mindspore/lite/src/runtime/kernel/opencl/cl/" "unreadVariable"
+"mindspore/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc" "stlFindInsert"
"mindspore/mindspore/lite/examples/quick_start_micro/" "syntaxError"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental" "unreadVariable"
"mindspore/mindspore/lite/python/src/pybind_module.cc" "syntaxError"
diff --git a/cmake/package_micro.cmake b/cmake/package_micro.cmake
index 3c6da3db..0481e3c3 100644
--- a/cmake/package_micro.cmake
+++ b/cmake/package_micro.cmake
@@ -10,6 +10,8 @@ function(__install_micro_wrapper)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${NNACL_DIR}/fp32 DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
+ install(DIRECTORY ${NNACL_DIR}/fp16 DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
+ COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${NNACL_DIR}/kernel DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${NNACL_DIR}/infer DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl
@@ -34,4 +36,4 @@ function(__install_micro_codegen)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(TARGETS cmsis_nn ARCHIVE DESTINATION ${CODEGEN_ROOT_DIR}/third_party/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
-endfunction()
\ No newline at end of file
+endfunction()
diff --git a/include/api/serialization.h b/include/api/serialization.h
index 1a0c1f57..76d5dbec 100644
--- a/include/api/serialization.h
+++ b/include/api/serialization.h
@@ -105,6 +105,21 @@ class MS_API Serialization {
QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
                                   std::vector<std::string> output_tensor_name = {});
+  /// \brief Export the model's weights, which can be used in micro only.
+  ///
+  /// \param[in] model The model data.
+  /// \param[in] model_type The model file type.
+  /// \param[in] weight_file The path of the exported weight file.
+  /// \param[in] is_inference Whether to export weights from an inference model. Currently, only `true` is supported.
+  /// \param[in] enable_fp16 Whether to save float weights in float16 format.
+  /// \param[in] changeable_weights_name The names of the weight tensors whose shapes are changeable.
+  ///
+  /// \return Status.
+  inline static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                         const std::string &weight_file, bool is_inference = true,
+                                                         bool enable_fp16 = false,
+                                                         const std::vector<std::string> &changeable_weights_name = {});
+
private:
  static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
                     const std::vector<char> &dec_mode);
@@ -119,6 +134,10 @@ class MS_API Serialization {
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
QuantizationType quantization_type, bool export_inference_only,
                            const std::vector<std::vector<char>> &output_tensor_name);
+  static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                  const std::vector<char> &weight_file, bool is_inference,
+                                                  bool enable_fp16,
+                                                  const std::vector<std::vector<char>> &changeable_weights_name);
};
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
@@ -150,5 +169,12 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, Buff
VectorStringToChar(output_tensor_name));
}
+Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+ const std::string &weight_file, bool is_inference,
+ bool enable_fp16,
+                                                        const std::vector<std::string> &changeable_weights_name) {
+ return ExportWeightsCollaborateWithMicro(model, model_type, StringToChar(weight_file), is_inference, enable_fp16,
+ VectorStringToChar(changeable_weights_name));
+}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
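
For reference, a minimal caller-side sketch of the new export API. This is not part of the patch: the model object, file name "net.bin", and the changeable-weight tensor name "gru/init_h" are illustrative assumptions; only the declarations above are authoritative.

    #include "include/api/model.h"
    #include "include/api/serialization.h"

    // Sketch: export only the weights of an inference model for micro codegen.
    int ExportForMicro(const mindspore::Model &model) {
      auto status = mindspore::Serialization::ExportWeightsCollaborateWithMicro(
        model, mindspore::kMindIR, "net.bin",
        /*is_inference=*/true,    // only `true` is supported for now
        /*enable_fp16=*/false,
        {"gru/init_h"});          // hypothetical changeable-shape weight name
      return status == mindspore::kSuccess ? 0 : -1;
    }
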
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c
index 3796e47b..81bf8ddf 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.c
@@ -16,9 +16,9 @@
#include "nnacl/base/minimal_filtering_generator.h"
#include <string.h>
#include <math.h>
-#include "nnacl/fp32/winograd_utils.h"
#include "nnacl/errorcode.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
+#include "nnacl/fp32/pack_fp32.h"
void Polynomial(const float *interval, float *m, int degree) {
for (int i = 0; i < degree; ++i) {
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h
index 01b013e8..fc0fa0e6 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/minimal_filtering_generator.h
@@ -21,7 +21,6 @@
#include <arm_neon.h>
#endif
#include <string.h>
-#include "nnacl/pack.h"
#ifdef __cplusplus
extern "C" {
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h
new file mode 100644
index 00000000..3bb8a444
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gru_parameter.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_
+#define MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_
+
+#include "nnacl/op_base.h"
+
+typedef struct CustomGruParameter {
+ // Primitive parameter
+ OpParameter op_parameter_;
+ // shape correlative
+ int num_step;
+ int batch_size;
+ int input_size;
+ int hidden_size;
+} CustomGruParameter;
+
+#endif // MINDSPORE_NNACL_CUSTOM_GRU_PARAMETER_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c
new file mode 100644
index 00000000..6e754569
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c
@@ -0,0 +1,70 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/fp16/custom_gru_fp16.h"
+#include "nnacl/fp16/activation_fp16.h"
+#include "nnacl/fp16/arithmetic_fp16.h"
+#include "nnacl/fp16/matmul_fp16.h"
+
+void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input,
+ const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden,
+ const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param) {
+ int num_step = gru_param->num_step;
+ int batch_size = gru_param->batch_size;
+ int input_size = gru_param->input_size;
+ int hidden_size = gru_param->hidden_size;
+ int output_size = batch_size * hidden_size;
+ int double_output_size = output_size * C2NUM;
+ int col_align = UP_ROUND(hidden_size, C8NUM);
+ int weight_in_offset = col_align * input_size;
+ int weight_hidden_offset = col_align * hidden_size;
+ float16_t *input_gate = buffer[1];
+ float16_t *hidden_gate = buffer[C3NUM];
+ for (int i = 0; i < num_step; ++i) {
+ if (batch_size != 1) {
+ RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size);
+ for (int j = 0; j < C3NUM; ++j) {
+ MatmulBaseFp16Neon(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size,
+ bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size,
+ OutType_Nhwc);
+ }
+ RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size);
+ for (int j = 0; j < C3NUM; ++j) {
+ MatmulBaseFp16Neon(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+ bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size,
+ OutType_Nhwc);
+ }
+ } else {
+ for (int j = 0; j < C3NUM; ++j) {
+ VecMatmulFp16(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size,
+ bias_input + j * col_align, ActType_No, input_size, hidden_size);
+ VecMatmulFp16(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+ bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size);
+ }
+ }
+ ElementAddFp16(input_gate, hidden_gate, input_gate, double_output_size);
+ SigmoidFp16(input_gate, input_gate, double_output_size);
+ ElementMulFp16(input_gate, hidden_gate + double_output_size, input_gate, output_size);
+ ElementAddFp16(input_gate, input_gate + double_output_size, input_gate, output_size);
+ TanhFp16(input_gate, input_gate, output_size);
+ ElementSubFp16(init_h, input_gate, hidden_gate, output_size);
+ ElementMulFp16(input_gate + output_size, hidden_gate, hidden_gate, output_size);
+ ElementAddFp16(input_gate, hidden_gate, output, output_size);
+ init_h = output;
+ output += output_size;
+ }
+}
+#endif
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h
new file mode 100644
index 00000000..67008f03
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_
+#define MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_
+#ifdef ENABLE_ARM64
+#include "nnacl/custom_gru_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input,
+ const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden,
+ const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif
+#endif // MINDSPORE_NNACL_FP16_CUSTOM_GRU_FP16_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c
new file mode 100644
index 00000000..caeece4a
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.c
@@ -0,0 +1,71 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "nnacl/fp32/custom_gru_fp32.h"
+#include "nnacl/fp32/activation_fp32.h"
+#include "nnacl/fp32/arithmetic_fp32.h"
+#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl/fp32/pack_fp32.h"
+
+void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden,
+ const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4],
+ const CustomGruParameter *gru_param) {
+ int num_step = gru_param->num_step;
+ int batch_size = gru_param->batch_size;
+ int input_size = gru_param->input_size;
+ int hidden_size = gru_param->hidden_size;
+ int output_size = batch_size * hidden_size;
+ int double_output_size = output_size * C2NUM;
+ int col_align = UP_ROUND(hidden_size, C8NUM);
+ int weight_in_offset = col_align * input_size;
+ int weight_hidden_offset = col_align * hidden_size;
+ float *input_gate = buffer[1];
+ float *hidden_gate = buffer[C3NUM];
+ for (int i = 0; i < num_step; ++i) {
+ if (batch_size != 1) {
+ RowMajor2Col12Major(input + i * batch_size * input_size, buffer[0], batch_size, input_size);
+ for (int j = 0; j < C3NUM; ++j) {
+ MatMulOpt(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size,
+ bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size,
+ OutType_Nhwc);
+ }
+ RowMajor2Col12Major(init_h, buffer[C2NUM], batch_size, hidden_size);
+ for (int j = 0; j < C3NUM; ++j) {
+ MatMulOpt(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+ bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size,
+ OutType_Nhwc);
+ }
+ } else {
+ for (int j = 0; j < C3NUM; ++j) {
+ MatVecMulFp32Neon64(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size,
+ bias_input + j * col_align, ActType_No, input_size, hidden_size, col_align);
+ MatVecMulFp32Neon64(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size,
+ bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size, col_align);
+ }
+ }
+ ElementAdd(input_gate, hidden_gate, input_gate, double_output_size);
+ Sigmoid(input_gate, double_output_size, input_gate);
+ ElementMul(input_gate, hidden_gate + double_output_size, input_gate, output_size);
+ ElementAdd(input_gate, input_gate + double_output_size, input_gate, output_size);
+ Tanh(input_gate, output_size, input_gate);
+ ElementSub(init_h, input_gate, hidden_gate, output_size);
+ ElementMul(input_gate + output_size, hidden_gate, hidden_gate, output_size);
+ ElementAdd(input_gate, hidden_gate, output, output_size);
+ init_h = output;
+ output += output_size;
+ }
+}
+#endif
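
Reading off the element-wise sequence above: after the two batched matmuls, input_gate holds the packed pre-activations [r, z, n] and hidden_gate holds their hidden-state counterparts, each chunk of size batch_size * hidden_size, and the in-place updates compute a standard GRU cell. As a sketch in LaTeX (gate order r, z, n inferred from the buffer offsets):

    r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr})
    z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz})
    n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn}))
    h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}

The final ElementSub/ElementMul/ElementAdd triple computes h_t as n_t + z_t \odot (h_{t-1} - n_t), which is algebraically the same update.
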
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h
new file mode 100644
index 00000000..576726c5
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/custom_gru_fp32.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_
+#define MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_
+#ifdef ENABLE_ARM64
+#include "nnacl/custom_gru_parameter.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden,
+ const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4],
+ const CustomGruParameter *gru_param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif
+#endif // MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h
index b608e1e0..8e217d02 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h
@@ -43,7 +43,7 @@ int ElementOptMulAcc(const float *input0, const float input1, float *output, con
void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate,
const float *state_weight, const float *state_bias, float *hidden_state, float *cell_state,
- float *buffer[6], const LstmParameter *lstm_param);
+ float *buffer[C6NUM], const LstmParameter *lstm_param);
void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias,
const float *state_bias, float *hidden_state, float *cell_state, float *buffer[7],
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c
new file mode 100644
index 00000000..060d04cf
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.c
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl/infer/custom_gru_infer.h"
+#include "nnacl/infer/infer_register.h"
+
+int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+ OpParameter *parameter) {
+ int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C6NUM, 1);
+ if (check_ret != NNACL_OK) {
+ return check_ret;
+ }
+
+ const TensorC *input = inputs[0];
+ TensorC *output = outputs[0];
+ SetDataTypeFormat(output, input);
+ if (!InferFlag(inputs, inputs_size)) {
+ return NNACL_INFER_INVALID;
+ }
+ if (input->shape_size_ != C3NUM) {
+ return NNACL_INPUT_TENSOR_ERROR;
+ }
+ SetShapeTensor(output, input);
+ const TensorC *weight_in = inputs[1];
+ if (weight_in->shape_size_ != C2NUM || weight_in->shape_[0] % C3NUM != 0) {
+ return NNACL_INPUT_TENSOR_ERROR;
+ }
+  output->shape_[C2NUM] = weight_in->shape_[0] / C3NUM;
+ return NNACL_OK;
+}
+
+REG_INFER(CustomGru, PrimType_Inner_CustomGru, CustomGruInferShape)
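
A worked instance of the shape rule above (numbers are illustrative only): the output keeps the first two dims of the 3-D input and replaces the last one with weight_in.shape[0] / 3.

    input      [8, 2, 64]   // num_step = 8, batch = 2, input_size = 64
    weight_in  [96, 64]     // 96 = 3 * hidden_size, so hidden_size = 32
    output     [8, 2, 32]
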
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h
new file mode 100644
index 00000000..830150d5
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gru_infer.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_NNACL_CUSTOM_GRU_INFER_H
+#define MINDSPORE_NNACL_CUSTOM_GRU_INFER_H
+#include "nnacl/infer/common_infer.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
+ OpParameter *parameter);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // MINDSPORE_NNACL_CUSTOM_GRU_INFER_H
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
index 5876bdf6..8c219212 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
@@ -520,6 +520,7 @@ enum PrimType {
PrimType_Inner_ShapeFusion = 10003,
PrimType_Inner_GraphKernel = 10004,
PrimType_Inner_ThirdPartyModel = 10005,
+ PrimType_Inner_CustomGru = 10006,
PrimType_InnerOpMax,
PrimType_InnerOpMin = PrimType_Inner_ToFormat
};
diff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h
index dd6e6d08..0d2de256 100644
--- a/mindspore/lite/include/registry/converter_context.h
+++ b/mindspore/lite/include/registry/converter_context.h
@@ -34,7 +34,8 @@ enum MS_API FmkType : int {
kFmkTypeTflite = 4,
kFmkTypePytorch = 5,
kFmkTypeThirdParty = 6,
- kFmkTypeEnd = 7, // For range check purpose, valid range: [0, kFmkTypeEnd)
+ kFmkTypeMsLite = 7,
+ kFmkTypeEnd = 8, // For range check purpose, valid range: [0, kFmkTypeEnd)
};
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
index 48e0fe7c..a80656f2 100644
--- a/mindspore/lite/src/CMakeLists.txt
+++ b/mindspore/lite/src/CMakeLists.txt
@@ -380,6 +380,7 @@ set(TRAIN_SRC
${CMAKE_CURRENT_SOURCE_DIR}/train/train_export.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/opt_allocator.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/common/fusion_utils.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/gru_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_activation_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/common/storage.cc
diff --git a/mindspore/lite/src/common/graph_util.cc b/mindspore/lite/src/common/graph_util.cc
index 1688fb79..d5cb2a98 100644
--- a/mindspore/lite/src/common/graph_util.cc
+++ b/mindspore/lite/src/common/graph_util.cc
@@ -23,6 +23,7 @@
#include "src/common/log_adapter.h"
#include "src/common/version_manager.h"
#include "include/errorcode.h"
+#include "nnacl/op_base.h"
namespace mindspore {
namespace lite {
@@ -86,9 +87,9 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, size_t tensor
// only support op_type from current schema
bool IsPackedOp(int op_type) {
-  static const std::vector<schema::PrimitiveType> packed_ops = {schema::PrimitiveType_Conv2DFusion,
-                                                                schema::PrimitiveType_Conv2dTransposeFusion,
-                                                                schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMulFusion};
+  static const std::vector<int> packed_ops = {
+    schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_Conv2dTransposeFusion,
+    schema::PrimitiveType_FullConnection, schema::PrimitiveType_MatMulFusion, PrimType::PrimType_Inner_CustomGru};
return IsContain(packed_ops, op_type);
}
diff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc
index f1506ece..391a587b 100644
--- a/mindspore/lite/src/common/ops/populate/custom_populate.cc
+++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc
@@ -17,10 +17,22 @@
#include "src/common/log_adapter.h"
#include "src/tensor.h"
#include "nnacl/custom_parameter.h"
+#include "nnacl/custom_gru_parameter.h"
using mindspore::schema::PrimitiveType_Custom;
namespace mindspore {
namespace lite {
+OpParameter *CreateCustomGruParameter() {
+  auto *param = static_cast<CustomGruParameter *>(malloc(sizeof(CustomGruParameter)));
+ if (param == nullptr) {
+ MS_LOG(ERROR) << "malloc CustomGruParameter failed.";
+ return nullptr;
+ }
+ memset(param, 0, sizeof(CustomGruParameter));
+ param->op_parameter_.type_ = PrimType_Inner_CustomGru;
+  return reinterpret_cast<OpParameter *>(param);
+}
+
OpParameter *PopulateCustomParameter(const void *prim) {
MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
  auto primitive = static_cast<const schema::Primitive *>(prim);
@@ -62,6 +74,8 @@ OpParameter *PopulateCustomParameter(const void *prim) {
// Just use the attr_data pointer to save the prim directly, the inner value is parsed as necessary.
  param->attr_data[0] = static_cast<void *>(const_cast<void *>(prim));
return reinterpret_cast(param);
+ } else if (type == "CustomGRU") {
+ return CreateCustomGruParameter();
} else {
MS_LOG(ERROR) << "Unsupported custom type: " << type;
}
diff --git a/mindspore/lite/src/runtime/cxx_api/serialization.cc b/mindspore/lite/src/runtime/cxx_api/serialization.cc
index 8405f4b2..ddf69d23 100644
--- a/mindspore/lite/src/runtime/cxx_api/serialization.cc
+++ b/mindspore/lite/src/runtime/cxx_api/serialization.cc
@@ -212,4 +212,35 @@ Status Serialization::ExportModel(const Model &model, ModelType model_type, cons
return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
}
+
+Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
+                                                        const std::vector<char> &weight_file, bool is_inference,
+                                                        bool enable_fp16,
+                                                        const std::vector<std::vector<char>> &changeable_weights_name) {
+ if (model.impl_ == nullptr) {
+ MS_LOG(ERROR) << "Model implement is null.";
+ return kLiteUninitializedObj;
+ }
+ if (!model.impl_->IsTrainModel()) {
+ MS_LOG(ERROR) << "Model is not TrainModel.";
+ return kLiteError;
+ }
+ if (model_type != kMindIR && model_type != kMindIR_Lite) {
+ MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
+ return kLiteParamInvalid;
+ }
+ if (model.impl_->session_ == nullptr) {
+ MS_LOG(ERROR) << "Model session is nullptr.";
+ return kLiteError;
+ }
+ if (!is_inference) {
+ MS_LOG(ERROR) << "Currently, can only export inference-model's weights.";
+ return kLiteNotSupport;
+ }
+ auto ret = model.impl_->session_->ExportWeightsCollaborateWithMicro(CharToString(weight_file), lite::MT_INFERENCE,
+ lite::FT_FLATBUFFERS, enable_fp16,
+ VectorCharToString(changeable_weights_name));
+
+ return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
+}
} // namespace mindspore
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
index b5370ddd..0352ad19 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_base.cc
@@ -35,6 +35,8 @@ int GroupConvolutionBaseCPUKernel::Prepare() {
return ret;
}
}
+ conv_param_->input_channel_ *= group_num_;
+ conv_param_->output_channel_ *= group_num_;
// if infer shape is done, resize func will be invoked in sub kernels
return RET_OK;
}
@@ -99,11 +101,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front();
CHECK_NULL_RETURN(sub_kernel_in_tensor);
sub_kernel_in_tensor->set_shape(in_shape);
- ret = sub_kernel_in_tensor->MallocData();
- if (ret != RET_OK) {
- MS_LOG(ERROR) << "sub kernel in tensor malloc data failed.";
- return ret;
- }
// out
auto out_tensor = out_tensors_.front();
CHECK_NULL_RETURN(out_tensor);
@@ -113,11 +110,6 @@ int GroupConvolutionBaseCPUKernel::PreProcess() {
for (auto tensor : sub_kernel_out_tensors) {
CHECK_NULL_RETURN(tensor);
tensor->set_shape(out_shape);
- ret = tensor->MallocData();
- if (ret != RET_OK) {
- MS_LOG(ERROR) << "sub kernel out tensor malloc data failed.";
- return ret;
- }
}
}
ret = ReSize();
@@ -177,7 +169,22 @@ int GroupConvolutionBaseCPUKernel::Run() {
ori_out_data_ = out_tensors_[0]->data();
CHECK_NULL_RETURN(ori_out_data_);
for (int i = 0; i < group_num_; ++i) {
- // first, separate group conv input into several parts. This step must be in runtime stage.
+ // first, malloc data for sub_kernel's tensors
+ auto sub_kernel_in_tensor = group_convs_.at(i)->in_tensors().front();
+ ret = sub_kernel_in_tensor->MallocData();
+ if (ret != RET_OK) {
+ MS_LOG(ERROR) << "sub kernel in tensor malloc data failed.";
+ return ret;
+ }
+ auto sub_kernel_out_tensors = group_convs_.at(i)->out_tensors();
+ for (auto tensor : sub_kernel_out_tensors) {
+ ret = tensor->MallocData();
+ if (ret != RET_OK) {
+ MS_LOG(ERROR) << "sub kernel out tensor malloc data failed.";
+ return ret;
+ }
+ }
+ // second, separate group conv input into several parts. This step must be in runtime stage.
ret = SeparateInput(i);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Separate input failed.";
@@ -195,6 +202,11 @@ int GroupConvolutionBaseCPUKernel::Run() {
MS_LOG(ERROR) << "Concat output failed.";
return ret;
}
+ // free data
+ sub_kernel_in_tensor->FreeData();
+ for (auto tensor : sub_kernel_out_tensors) {
+ tensor->FreeData();
+ }
}
return RET_OK;
}
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
index fc78a887..81b0aac2 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.cc
@@ -53,15 +53,6 @@ void FreeCurrentConv(ConvParameter *conv_param, std::vector<lite::Tensor *> *new
}
}
-static inline lite::Tensor *TensorMalloc(lite::Tensor *tensor) {
- if (tensor->MallocData() != lite::RET_OK) {
- delete tensor;
- MS_LOG(ERROR) << "malloc tensor data failed.";
- return nullptr;
- }
- return tensor;
-}
-
lite::Tensor *CreateConstTensor(const lite::Tensor *tensor, const std::vector<int> &shape, const int index) {
auto new_tensor =
new (std::nothrow) lite::Tensor(tensor->data_type(), shape, mindspore::NHWC, lite::Category::CONST_TENSOR);
@@ -87,7 +78,7 @@ lite::Tensor *CreateConstTensor(const lite::Tensor *tensor, const std::vector<int
   tensor->set_format(tensor_info.format_);
tensor->set_category(tensor_info.tensor_type_);
tensor->set_shape(tensor_info.shape_);
-
- if (inferred) {
- // set shape of out tensor
- return TensorMalloc(tensor);
- }
+ tensor->set_allocator(tensor_info.allocator_);
return tensor;
}
@@ -129,7 +116,8 @@ void GroupConvCreator::FreeGroupConvs() {
}
int GroupConvCreator::NewInputTensor(std::vector<lite::Tensor *> *tensors) {
- auto in_tensor = CreateVarTensor({input_shape_, mindspore::NHWC, data_type_, lite::Category::VAR, true}, infered_);
+ auto allocator = ms_context_ != nullptr ? ms_context_->allocator : nullptr;
+ auto in_tensor = CreateVarTensor({input_shape_, allocator, mindspore::NHWC, data_type_, lite::Category::VAR, true});
if (in_tensor == nullptr) {
return lite::RET_ERROR;
}
@@ -138,7 +126,9 @@ int GroupConvCreator::NewInputTensor(std::vector *tensors) {
}
int GroupConvCreator::NewOutputTensor(std::vector<lite::Tensor *> *tensors, const lite::Tensor *output) const {
- auto out_tensor = CreateVarTensor({output_shape_, output->format(), data_type_, output->category(), false}, infered_);
+ auto allocator = ms_context_ != nullptr ? ms_context_->allocator : nullptr;
+ auto out_tensor =
+ CreateVarTensor({output_shape_, allocator, output->format(), data_type_, output->category(), false});
if (out_tensor == nullptr) {
return lite::RET_ERROR;
}
diff --git a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h
index b4a4f768..27aa0cc8 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h
+++ b/mindspore/lite/src/runtime/kernel/cpu/base/group_convolution_creator.h
@@ -22,10 +22,12 @@
#include "src/runtime/lite_kernel.h"
#include "nnacl/conv_parameter.h"
#include "src/runtime/tensor_category.h"
+#include "include/api/allocator.h"
namespace mindspore::kernel {
struct TensorInfo {
  std::vector<int> shape_;
+ AllocatorPtr allocator_;
mindspore::Format format_;
TypeId data_type_;
lite::Category tensor_type_;
@@ -35,8 +37,9 @@ struct TensorInfo {
class GroupConvCreator {
public:
  GroupConvCreator(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs, OpParameter *op_parameter,
- bool is_quant, TypeId data_type)
- : origin_inputs_(std::move(inputs)),
+ bool is_quant, TypeId data_type, const lite::InnerContext *ctx)
+ : ms_context_(ctx),
+ origin_inputs_(std::move(inputs)),
origin_outputs_(std::move(outputs)),
is_quant_(is_quant),
data_type_(data_type) {
@@ -64,6 +67,7 @@ class GroupConvCreator {
  int NewOutputTensor(std::vector<lite::Tensor *> *tensors, const lite::Tensor *output) const;
private:
+ const lite::InnerContext *ms_context_ = nullptr;
  std::vector<lite::Tensor *> origin_inputs_;
  std::vector<lite::Tensor *> origin_outputs_;
  std::vector<kernel::LiteKernel *> group_convs_;
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc
index 7aa823b0..17ba38ff 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/convolution_delegate_fp16.cc
@@ -202,7 +202,7 @@ kernel::LiteKernel *CpuGroupConvFp16KernelCreator(const std::vector &outputs, OpParameter *op_parameter,
const InnerContext *ctx) {
auto *group_conv_creator =
- new (std::nothrow) GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat16);
+ new (std::nothrow) GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat16, ctx);
if (group_conv_creator == nullptr) {
MS_LOG(ERROR) << "new GroupConvCreator fail";
free(op_parameter);
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc
new file mode 100644
index 00000000..7851eecb
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.cc
@@ -0,0 +1,132 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/runtime/kernel/cpu/fp16/custom_gru_fp16.h"
+#include <vector>
+#include "src/runtime/kernel_registry.h"
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#include "src/runtime/pack_weight_manager.h"
+#include "nnacl/custom_gru_parameter.h"
+#include "nnacl/fp16/custom_gru_fp16.h"
+#include "nnacl/fp16/matmul_fp16.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NOT_SUPPORT;
+using mindspore::lite::RET_OK;
+
+namespace mindspore::kernel {
+int CustomGruFp16CPUKernel::InitWeightAndBias() {
+ auto weight_shape = in_tensors_[1]->shape();
+ auto hidden_size = weight_shape[0] / C3NUM;
+ auto col_align = UP_ROUND(hidden_size, col_tile_);
+  auto weight_in_pack_size = static_cast<size_t>(col_align * weight_shape[1]) * sizeof(float16_t);
+ bool is_packed = false;
+  weight_in_ = lite::PackWeightManager::GetInstance()->GetPackData(
+    in_tensors_[SECOND_INPUT]->data(), static_cast<size_t>(weight_in_pack_size * C3NUM), &is_packed);
+ MS_CHECK_TRUE_MSG(weight_in_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-in failed.");
+ if (!is_packed) {
+    auto weight_in_src = static_cast<const float16_t *>(in_tensors_[SECOND_INPUT]->data());
+ for (int i = 0; i < C3NUM; ++i) {
+      RowMajor2Col8MajorFp16(weight_in_src + i * hidden_size * weight_shape[1],
+                             static_cast<float16_t *>(weight_in_) + i * col_align * weight_shape[1], hidden_size,
+                             weight_shape[1], false);
+ }
+ }
+  auto weight_hidden_pack_size = static_cast<size_t>(col_align * hidden_size) * sizeof(float16_t);
+ is_packed = false;
+  weight_hidden_ = lite::PackWeightManager::GetInstance()->GetPackData(
+    in_tensors_[THIRD_INPUT]->data(), static_cast<size_t>(weight_hidden_pack_size * C3NUM), &is_packed);
+ MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-hidden failed.");
+ if (!is_packed) {
+    auto weight_hidden_src = static_cast<const float16_t *>(in_tensors_[THIRD_INPUT]->data());
+ for (int i = 0; i < C3NUM; ++i) {
+      RowMajor2Col8MajorFp16(weight_hidden_src + i * hidden_size * weight_shape[1],
+                             static_cast<float16_t *>(weight_hidden_) + i * col_align * weight_shape[1], hidden_size,
+                             hidden_size, false);
+ }
+ }
+  auto bias_pack_size = static_cast<size_t>(col_align) * sizeof(float16_t);
+  auto bias = reinterpret_cast<float16_t *>(malloc(bias_pack_size * C6NUM));
+ if (bias == nullptr) {
+ MS_LOG(ERROR) << "malloc for packing bias failed.";
+ return lite::RET_NULL_PTR;
+ }
+ (void)memset(bias, 0, bias_pack_size * C6NUM);
+ bias_in_ = bias;
+ bias_hidden_ = bias + col_align * C3NUM;
+  auto bias_in_src = static_cast<const float16_t *>(in_tensors_[FOURTH_INPUT]->data());
+ for (int i = 0; i < C3NUM; ++i) {
+ (void)memcpy(bias + i * col_align, bias_in_src + i * hidden_size, hidden_size * sizeof(float16_t));
+ }
+  auto bias_hidden_src = static_cast<const float16_t *>(in_tensors_[FIFTH_INPUT]->data());
+ for (int i = 0; i < C3NUM; ++i) {
+ (void)memcpy(bias + (C3NUM + i) * col_align, bias_hidden_src + i * hidden_size, hidden_size * sizeof(float16_t));
+ }
+ if (in_tensors_[SIXTH_INPUT]->IsConst()) {
+ init_h_ = malloc(in_tensors_[SIXTH_INPUT]->Size());
+ MS_CHECK_TRUE_MSG(init_h_ != nullptr, lite::RET_NULL_PTR, "malloc for init-h failed.");
+ (void)memcpy(init_h_, in_tensors_[SIXTH_INPUT]->data(), in_tensors_[SIXTH_INPUT]->Size());
+ }
+ return RET_OK;
+}
+
+int CustomGruFp16CPUKernel::Run() {
+  auto input = reinterpret_cast<float16_t *>(in_tensors_[FIRST_INPUT]->data());
+  if (input == nullptr) {
+    MS_LOG(ERROR) << "Built-in CustomGru's first-input is nullptr." << name_;
+ return lite::RET_NULL_PTR;
+ }
+ if (!in_tensors_[SIXTH_INPUT]->IsConst()) {
+ init_h_ = in_tensors_[SIXTH_INPUT]->data();
+ }
+ if (init_h_ == nullptr) {
+ MS_LOG(ERROR) << "Built-in CustomGru's six-input is nullptr." << name_;
+ return lite::RET_NULL_PTR;
+ }
+  auto output = reinterpret_cast<float16_t *>(out_tensors_.front()->data());
+ if (output == nullptr) {
+ MS_LOG(ERROR) << "Built-in CustomGru's output is nullptr." << name_;
+ return lite::RET_NULL_PTR;
+ }
+ MallocRunBuffer(sizeof(float16_t));
+ if (run_buffer_ == nullptr) {
+ MS_LOG(ERROR) << "malloc running buffer failed." << name_;
+ return lite::RET_NULL_PTR;
+ }
+ auto param = reinterpret_cast(op_parameter_);
+ auto row_align = UP_ROUND(param->batch_size, row_tile_);
+  auto run_buffer = reinterpret_cast<float16_t *>(run_buffer_);
+ float16_t *buffer[C4NUM] = {
+ run_buffer, run_buffer + row_align * param->input_size,
+ run_buffer + row_align * param->input_size + param->batch_size * param->hidden_size * C3NUM,
+ run_buffer + row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C3NUM};
+  CustomGruFp16(output, input, static_cast<float16_t *>(weight_in_), static_cast<float16_t *>(weight_hidden_),
+                static_cast<float16_t *>(bias_in_), static_cast<float16_t *>(bias_hidden_),
+                static_cast<float16_t *>(init_h_), buffer, param);
+ if (ms_context_->allocator != nullptr) {
+ ms_context_->allocator->Free(run_buffer_);
+ } else {
+ free(run_buffer_);
+ }
+ run_buffer_ = nullptr;
+ return RET_OK;
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimType_Inner_CustomGru, LiteKernelCreator<CustomGruFp16CPUKernel>)
+} // namespace mindspore::kernel
+#endif
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h
new file mode 100644
index 00000000..d7ed313a
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp16/custom_gru_fp16.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_
+#ifdef ENABLE_ARM64
+#include <vector>
+#include "src/runtime/kernel/cpu/fp32/custom_gru_fp32.h"
+
+namespace mindspore::kernel {
+class CustomGruFp16CPUKernel : public CustomGruCPUKernel {
+ public:
+  CustomGruFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                         const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+ : CustomGruCPUKernel(parameter, inputs, outputs, ctx) {
+ row_tile_ = C4NUM;
+ col_tile_ = C8NUM;
+ }
+ ~CustomGruFp16CPUKernel() override = default;
+ int Run() override;
+
+ protected:
+ int InitWeightAndBias() override;
+};
+} // namespace mindspore::kernel
+#endif
+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_CUSTOM_GRU_FP16_H_
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc
index bbbf8488..3514e5b4 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/convolution_delegate_fp32.cc
@@ -359,7 +359,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
 kernel::LiteKernel *CpuGroupConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                   const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
                                                   const lite::InnerContext *ctx) {
- auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat32);
+ auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, false, kNumberTypeFloat32, ctx);
auto group_kernel = new (std::nothrow) GroupConvolutionFp32CPUKernel(
op_parameter, inputs, outputs, ctx, group_conv_creator, reinterpret_cast(op_parameter)->group_);
if (group_kernel == nullptr) {
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc
new file mode 100644
index 00000000..c85a1283
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.cc
@@ -0,0 +1,251 @@
+#ifdef ENABLE_ARM64
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/runtime/kernel/cpu/fp32/custom_gru_fp32.h"
+#include <vector>
+#include "src/runtime/kernel_registry.h"
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#include "src/runtime//pack_weight_manager.h"
+#include "nnacl/fp32/pack_fp32.h"
+#include "nnacl/custom_gru_parameter.h"
+#include "nnacl/fp32/custom_gru_fp32.h"
+
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NOT_SUPPORT;
+using mindspore::lite::RET_OK;
+
+namespace mindspore::kernel {
+CustomGruCPUKernel::~CustomGruCPUKernel() {
+ if (weight_in_) {
+ lite::PackWeightManager::GetInstance()->Free(weight_in_);
+ weight_in_ = nullptr;
+ }
+ if (weight_hidden_) {
+ lite::PackWeightManager::GetInstance()->Free(weight_hidden_);
+ weight_hidden_ = nullptr;
+ }
+ if (bias_in_) {
+ free(bias_in_);
+ bias_in_ = nullptr;
+ bias_hidden_ = nullptr;
+ }
+ if (in_tensors_[SIXTH_INPUT]->IsConst() && init_h_) {
+ free(init_h_);
+ init_h_ = nullptr;
+ }
+}
+
+int CustomGruCPUKernel::Prepare() {
+ CHECK_LESS_RETURN(in_tensors_.size(), C6NUM);
+ CHECK_LESS_RETURN(out_tensors_.size(), 1);
+ if (in_tensors_[FIRST_INPUT]->IsConst()) {
+ MS_LOG(ERROR) << "Built-in CustomGru first-input must be a variable." << name_;
+ return RET_NOT_SUPPORT;
+ }
+ for (size_t i = 1; i < C5NUM; ++i) {
+ if (!in_tensors_[i]->IsConst()) {
+ MS_LOG(ERROR) << "Built-in CustomGru only support first-input and last-input is variable." << name_;
+ return RET_NOT_SUPPORT;
+ }
+ }
+ if (InitParamter() != RET_OK) {
+ MS_LOG(ERROR) << "Init Built-in CustomGru Parameter failed." << name_;
+ return RET_ERROR;
+ }
+ if (InitWeightAndBias() != RET_OK) {
+ MS_LOG(ERROR) << "Init Built-in CustomGru Weight and bias failed." << name_;
+ return RET_ERROR;
+ }
+ if (!InferShapeDone()) {
+ return RET_OK;
+ }
+ return ReSize();
+}
+
+int CustomGruCPUKernel::InitParamter() {
+  auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
+ thread_num_ = 1;
+ auto weight_in_shape = in_tensors_[1]->shape();
+ auto weight_hidden_shape = in_tensors_[C2NUM]->shape();
+ if (weight_in_shape.size() != C2NUM || weight_hidden_shape.size() != C2NUM) {
+ MS_LOG(ERROR) << "Built-in CustomGru's weight must be 2D." << name_;
+ return RET_ERROR;
+ }
+ if (weight_in_shape[0] != weight_hidden_shape[0]) {
+ MS_LOG(ERROR) << "Built-in CustomGru's weight-in and weight-hidden first-dim must be same." << name_;
+ return RET_ERROR;
+ }
+ if (weight_hidden_shape[0] != weight_hidden_shape[1] * C3NUM) {
+ MS_LOG(ERROR) << "Built-in CustomGru's weight-hidden first-dim must be 3 * second-dim." << name_;
+ return RET_ERROR;
+ }
+ auto bias_in_shape = in_tensors_[C3NUM]->shape();
+ auto bias_hidden_shape = in_tensors_[C4NUM]->shape();
+ if (bias_in_shape.size() != 1) {
+ MS_LOG(ERROR) << "Built-in CustomGru's bias must be 1D." << name_;
+ return RET_ERROR;
+ }
+ if (bias_in_shape != bias_hidden_shape) {
+ MS_LOG(ERROR) << "Built-in CustomGru's bias-in and bias-hidden must have same shape." << name_;
+ return RET_ERROR;
+ }
+ if (bias_in_shape.back() != weight_in_shape.front()) {
+ MS_LOG(ERROR) << "Built-in CustomGru's bias-in shape don't match with the first-dim of weight." << name_;
+ return RET_ERROR;
+ }
+ if (bias_in_shape.front() % C3NUM != 0) {
+ MS_LOG(ERROR) << "The first-dim of CustomGru's weight must be 3 * hidden.";
+ return RET_ERROR;
+ }
+ param->input_size = weight_in_shape.back();
+ param->hidden_size = bias_in_shape.front() / C3NUM;
+ return RET_OK;
+}
+
+int CustomGruCPUKernel::InitWeightAndBias() {
+ auto weight_shape = in_tensors_[1]->shape();
+ auto hidden_size = weight_shape[0] / C3NUM;
+ auto col_align = UP_ROUND(hidden_size, col_tile_);
+  auto weight_in_pack_size = static_cast<size_t>(col_align * weight_shape[1]) * sizeof(float);
+ bool is_packed = false;
+  weight_in_ = lite::PackWeightManager::GetInstance()->GetPackData(
+    in_tensors_[SECOND_INPUT]->data(), static_cast<size_t>(weight_in_pack_size * C3NUM), &is_packed);
+ MS_CHECK_TRUE_MSG(weight_in_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-in failed.");
+ if (!is_packed) {
+    auto weight_in_src = static_cast<const float *>(in_tensors_[SECOND_INPUT]->data());
+ for (int i = 0; i < C3NUM; ++i) {
+      RowMajor2Col8Major(weight_in_src + i * hidden_size * weight_shape[1],
+                         static_cast<float *>(weight_in_) + i * col_align * weight_shape[1], hidden_size,
+                         weight_shape[1]);
+ }
+ }
+  auto weight_hidden_pack_size = static_cast<size_t>(col_align * hidden_size) * sizeof(float);
+ is_packed = false;
+  weight_hidden_ = lite::PackWeightManager::GetInstance()->GetPackData(
+    in_tensors_[THIRD_INPUT]->data(), static_cast<size_t>(weight_hidden_pack_size * C3NUM), &is_packed);
+ MS_CHECK_TRUE_MSG(weight_hidden_ != nullptr, lite::RET_NULL_PTR, "malloc for packing weight-hidden failed.");
+ if (!is_packed) {
+    auto weight_hidden_src = static_cast<const float *>(in_tensors_[THIRD_INPUT]->data());
+ for (int i = 0; i < C3NUM; ++i) {
+      RowMajor2Col8Major(weight_hidden_src + i * hidden_size * weight_shape[1],
+                         static_cast<float *>(weight_hidden_) + i * col_align * weight_shape[1], hidden_size,
+                         hidden_size);
+ }
+ }
+  auto bias_pack_size = static_cast<size_t>(col_align) * sizeof(float);
+  auto bias = reinterpret_cast<float *>(malloc(bias_pack_size * C6NUM));
+ if (bias == nullptr) {
+ MS_LOG(ERROR) << "malloc for packing bias failed.";
+ return lite::RET_NULL_PTR;
+ }
+ (void)memset(bias, 0, bias_pack_size * C6NUM);
+ bias_in_ = bias;
+ bias_hidden_ = bias + col_align * C3NUM;
+  auto bias_in_src = static_cast<const float *>(in_tensors_[FOURTH_INPUT]->data());
+ for (int i = 0; i < C3NUM; ++i) {
+ (void)memcpy(bias + i * col_align, bias_in_src + i * hidden_size, hidden_size * sizeof(float));
+ }
+  auto bias_hidden_src = static_cast<const float *>(in_tensors_[FIFTH_INPUT]->data());
+ for (int i = 0; i < C3NUM; ++i) {
+ (void)memcpy(bias + (C3NUM + i) * col_align, bias_hidden_src + i * hidden_size, hidden_size * sizeof(float));
+ }
+ if (in_tensors_[SIXTH_INPUT]->IsConst()) {
+ init_h_ = malloc(in_tensors_[SIXTH_INPUT]->Size());
+ MS_CHECK_TRUE_MSG(init_h_ != nullptr, lite::RET_NULL_PTR, "malloc for init-h failed.");
+ (void)memcpy(init_h_, in_tensors_[SIXTH_INPUT]->data(), in_tensors_[SIXTH_INPUT]->Size());
+ }
+ return RET_OK;
+}
+
+int CustomGruCPUKernel::ReSize() {
+ auto in_shape = in_tensors_.front()->shape();
+ if (in_shape.size() != C3NUM) {
+ MS_LOG(ERROR) << "Built-in CustomGru's first-input must be 3D." << name_;
+ return RET_ERROR;
+ }
+  auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
+ param->num_step = in_shape[0];
+ param->batch_size = in_shape[1];
+ if (in_shape.back() != param->input_size) {
+ MS_LOG(ERROR) << "Built-in CustomGru's fisrt-input don't match its weight." << name_;
+ return RET_ERROR;
+ }
+ return RET_OK;
+}
+
+void CustomGruCPUKernel::MallocRunBuffer(size_t data_type_size) {
+ if (run_buffer_ != nullptr) {
+ return;
+ }
+ auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
+ auto row_align = UP_ROUND(param->batch_size, row_tile_);
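+ // Scratch layout: packed input (row_align * input_size), packed hidden state (row_align * hidden_size) and six batch_size * hidden_size gate buffers.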
+ auto run_buffer_size =
+ (row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C6NUM) *
+ data_type_size;
+ if (ms_context_->allocator != nullptr) {
+ run_buffer_ = ms_context_->allocator->Malloc(run_buffer_size);
+ } else {
+ run_buffer_ = malloc(run_buffer_size);
+ }
+}
+
+int CustomGruCPUKernel::Run() {
+ auto input = reinterpret_cast<float *>(in_tensors_[FIRST_INPUT]->data());
+ if (input == nullptr) {
+ MS_LOG(ERROR) << "Built-in CustomGru's fisrt-input is nullptr." << name_;
+ return lite::RET_NULL_PTR;
+ }
+ if (!in_tensors_[SIXTH_INPUT]->IsConst()) {
+ init_h_ = in_tensors_[SIXTH_INPUT]->data();
+ }
+ if (init_h_ == nullptr) {
+ MS_LOG(ERROR) << "Built-in CustomGru's six-input is nullptr." << name_;
+ return lite::RET_NULL_PTR;
+ }
+ auto output = reinterpret_cast<float *>(out_tensors_.front()->data());
+ if (output == nullptr) {
+ MS_LOG(ERROR) << "Built-in CustomGru's output is nullptr." << name_;
+ return lite::RET_NULL_PTR;
+ }
+ MallocRunBuffer(sizeof(float));
+ if (run_buffer_ == nullptr) {
+ MS_LOG(ERROR) << "malloc running buffer failed." << name_;
+ return lite::RET_NULL_PTR;
+ }
+ auto param = reinterpret_cast<CustomGruParameter *>(op_parameter_);
+ auto row_align = UP_ROUND(param->batch_size, row_tile_);
+ auto run_buffer = reinterpret_cast<float *>(run_buffer_);
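+ // buffer[0]: packed input; buffer[1]: the 3 input-gate matmul outputs;
+ // buffer[2]: packed hidden state; buffer[3]: the 3 hidden-gate matmul outputs.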
+ float *buffer[C4NUM] = {
+ run_buffer, run_buffer + row_align * param->input_size,
+ run_buffer + row_align * param->input_size + param->batch_size * param->hidden_size * C3NUM,
+ run_buffer + row_align * (param->input_size + param->hidden_size) + param->batch_size * param->hidden_size * C3NUM};
+ CustomGru(output, input, static_cast<float *>(weight_in_), static_cast<float *>(weight_hidden_),
+ static_cast<float *>(bias_in_), static_cast<float *>(bias_hidden_), static_cast<float *>(init_h_), buffer,
+ param);
+ if (ms_context_->allocator != nullptr) {
+ ms_context_->allocator->Free(run_buffer_);
+ } else {
+ free(run_buffer_);
+ }
+ run_buffer_ = nullptr;
+ return RET_OK;
+}
+
+ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomGru, LiteKernelCreator<CustomGruCPUKernel>)
+} // namespace mindspore::kernel
+#endif
diff --git a/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h
new file mode 100644
index 00000000..e70e408f
--- /dev/null
+++ b/mindspore/lite/src/runtime/kernel/cpu/fp32/custom_gru_fp32.h
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_
+#ifdef ENABLE_ARM64
+#include <vector>
+#include "src/runtime/lite_kernel.h"
+
+namespace mindspore::kernel {
+class CustomGruCPUKernel : public LiteKernel {
+ public:
+ CustomGruCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
+ : LiteKernel(parameter, inputs, outputs, ctx) {}
+ ~CustomGruCPUKernel() override;
+ int Prepare() override;
+ int ReSize() override;
+ int Run() override;
+
+ private:
+ int InitParamter();
+
+ protected:
+ void MallocRunBuffer(size_t data_type_size);
+ virtual int InitWeightAndBias();
+ int row_tile_{C12NUM};
+ int col_tile_{C8NUM};
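+ // Tiling used when packing for the ARM64 fp32 matmul kernels (12-row, 8-column blocks).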
+ void *weight_in_{nullptr};
+ void *weight_hidden_{nullptr};
+ void *bias_in_{nullptr};
+ void *bias_hidden_{nullptr};
+ void *init_h_{nullptr};
+ void *run_buffer_{nullptr};
+};
+} // namespace mindspore::kernel
+#endif
+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_CUSTOM_GRU_FP32_H_
diff --git a/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc b/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc
index ef03171a..0ff780c7 100644
--- a/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc
+++ b/mindspore/lite/src/runtime/kernel/cpu/int8/convolution_int8_creator.cc
@@ -107,7 +107,7 @@ kernel::LiteKernel *CpuGroupConvInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
     MS_LOG(ERROR) << ... << conv_param->input_channel_;
return nullptr;
}
- auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, true, kNumberTypeInt8);
+ auto *group_conv_creator = new GroupConvCreator(inputs, outputs, op_parameter, true, kNumberTypeInt8, ctx);
if (group_conv_creator == nullptr) {
MS_LOG(ERROR) << "group_conv_creator is nullptr.";
return nullptr;
diff --git a/mindspore/lite/src/runtime/lite_session.h b/mindspore/lite/src/runtime/lite_session.h
index 2fdb1eb7..d5a672bb 100644
--- a/mindspore/lite/src/runtime/lite_session.h
+++ b/mindspore/lite/src/runtime/lite_session.h
@@ -106,6 +106,12 @@ class MS_API LiteSession {
std::vector<std::string> out_put_tensor_name = {}) {
return mindspore::lite::RET_ERROR;
}
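+ // Default stub: overridden (e.g. by the train session) to export weights for micro codegen.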
+ virtual int ExportWeightsCollaborateWithMicro(const std::string &file_name,
+ lite::ModelType model_type = lite::MT_TRAIN,
+ lite::FormatType = lite::FT_FLATBUFFERS, bool enable_fp16 = false,
+ const std::vector<std::string> &changeable_weights_name = {}) {
+ return mindspore::lite::RET_ERROR;
+ }
virtual int UpdateWeights(std::vector<lite::Tensor *> new_weights) { return mindspore::lite::RET_ERROR; }
virtual std::vector<lite::Tensor *> GetFeatureMaps() const {
std::vector<lite::Tensor *> features;
diff --git a/mindspore/lite/src/train/graph_fusion.cc b/mindspore/lite/src/train/graph_fusion.cc
index 1af44e45..03f5675e 100644
--- a/mindspore/lite/src/train/graph_fusion.cc
+++ b/mindspore/lite/src/train/graph_fusion.cc
@@ -22,6 +22,7 @@
#include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
+#include "src/train/optimizer/fusion/gru_fusion_pass.h"
#include "src/train/optimizer/fusion/matmul_activation_fusion_pass.h"
#include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h"
@@ -41,6 +42,12 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) {
MS_LOG(ERROR) << "graph is nullptr.";
return RET_ERROR;
}
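+ // Fuse decomposed GRU subgraphs into a single custom GRU node before the generic fusion passes run.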
+ auto gru_fusion = std::make_shared<GruFusionPass>();
+ MS_CHECK_TRUE_MSG(gru_fusion != nullptr, RET_NULL_PTR, "Create GruFusion object failed.");
+ if (gru_fusion->Run(graph) != RET_OK) {
+ MS_LOG(ERROR) << "Do GruFusion failed.";
+ return RET_ERROR;
+ }
auto old_nodes = GetGraphNodes(*graph);
Optimizer fusion_optimizer;
fusion_optimizer.AddPass(new (std::nothrow) ReshapeGatherReshapeFusionPass());
diff --git a/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc
new file mode 100644
index 00000000..435686e5
--- /dev/null
+++ b/mindspore/lite/src/train/optimizer/fusion/gru_fusion_pass.cc
@@ -0,0 +1,809 @@
+/**
+ * Copyright 2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/train/optimizer/fusion/gru_fusion_pass.h"
+#include <memory>
+#include <vector>